Example #1
    def test_merge_datasets(self):
        data = create_test_data()

        actual = xr.merge([data[['var1']], data[['var2']]])
        expected = data[['var1', 'var2']]
        assert actual.identical(expected)

        actual = xr.merge([data, data])
        assert actual.identical(data)
Example #2
    def test_merge_no_conflicts_broadcast(self):
        datasets = [xr.Dataset({'x': ('y', [0])}), xr.Dataset({'x': np.nan})]
        actual = xr.merge(datasets)
        expected = xr.Dataset({'x': ('y', [0])})
        assert expected.identical(actual)

        datasets = [xr.Dataset({'x': ('y', [np.nan])}), xr.Dataset({'x': 0})]
        actual = xr.merge(datasets)
        assert expected.identical(actual)
Example #3
    def test_merge_no_conflicts_single_var(self):
        ds1 = xr.Dataset({'a': ('x', [1, 2]), 'x': [0, 1]})
        ds2 = xr.Dataset({'a': ('x', [2, 3]), 'x': [1, 2]})
        expected = xr.Dataset({'a': ('x', [1, 2, 3]), 'x': [0, 1, 2]})
        assert expected.identical(xr.merge([ds1, ds2],
                                           compat='no_conflicts'))
        assert expected.identical(xr.merge([ds2, ds1],
                                           compat='no_conflicts'))
        assert ds1.identical(xr.merge([ds1, ds2],
                                      compat='no_conflicts',
                                      join='left'))
        assert ds2.identical(xr.merge([ds1, ds2],
                                      compat='no_conflicts',
                                      join='right'))
        expected = xr.Dataset({'a': ('x', [2]), 'x': [1]})
        assert expected.identical(xr.merge([ds1, ds2],
                                           compat='no_conflicts',
                                           join='inner'))

        with pytest.raises(xr.MergeError):
            ds3 = xr.Dataset({'a': ('x', [99, 3]), 'x': [1, 2]})
            xr.merge([ds1, ds3], compat='no_conflicts')

        with pytest.raises(xr.MergeError):
            ds3 = xr.Dataset({'a': ('y', [2, 3]), 'y': [1, 2]})
            xr.merge([ds1, ds3], compat='no_conflicts')
Example #4
File: dual.py Project: landlab/landlab
    def __init__(self, **kwds):
        node_at_cell = kwds.pop("node_at_cell", None)
        nodes_at_face = kwds.pop("nodes_at_face", None)

        update_node_at_cell(self.ds, node_at_cell)
        update_nodes_at_face(self.ds, nodes_at_face)

        rename = {
            "mesh": "dual",
            "node": "corner",
            "link": "face",
            "patch": "cell",
            "x_of_node": "x_of_corner",
            "y_of_node": "y_of_corner",
            "nodes_at_link": "corners_at_face",
            "links_at_patch": "faces_at_cell",
            "max_patch_links": "max_cell_faces",
        }
        self._ds = xr.merge([self._ds, self._dual.ds.rename(rename)])

        self._origin = (0.0, 0.0)

        self._frozen = False
        self.freeze()

        if kwds.get("sort", True):
            self.sort()
Example #5
def rinexobs2(fn: Path,
              use: Sequence[str] = None,
              tlim: Tuple[datetime, datetime] = None,
              useindicators: bool = False,
              meas: Sequence[str] = None,
              verbose: bool = False,
              *,
              fast: bool = True,
              interval: Union[float, int, timedelta] = None) -> xarray.Dataset:

    if isinstance(use, str):
        use = [use]

    if use is None or not use[0].strip():
        use = ('C', 'E', 'G', 'J', 'R', 'S')

    obs = xarray.Dataset({}, coords={'time': [], 'sv': []})
    attrs: Dict[str, Any] = {}
    for u in use:
        o = rinexsystem2(fn, system=u, tlim=tlim,
                         useindicators=useindicators, meas=meas,
                         verbose=verbose,
                         fast=fast, interval=interval)
        if len(o) > 0:
            attrs = o.attrs
            obs = xarray.merge((obs, o))

    obs.attrs = attrs

    return obs
Example #6
    def __init__(self, **kwds):
        node_at_cell = kwds.pop('node_at_cell', None)
        nodes_at_face = kwds.pop('nodes_at_face', None)

        update_node_at_cell(self.ds, node_at_cell)
        update_nodes_at_face(self.ds, nodes_at_face)

        rename = {
            'mesh': 'dual',
            'node': 'corner',
            'link': 'face',
            'patch': 'cell',
            'x_of_node': 'x_of_corner',
            'y_of_node': 'y_of_corner',
            'nodes_at_link': 'corners_at_face',
            'links_at_patch': 'faces_at_cell',
            'max_patch_links': 'max_cell_faces',
        }
        self._ds = xr.merge([self._ds, self._dual.ds.rename(rename)])

        self._origin = (0., 0.)

        self._frozen = False
        self.freeze()

        if kwds.get('sort', True):
            self.sort()
Example #7
 def test_merge_returns_merged_array(self):
     test_array = self.array.copy()
     test_array.name = 'test'
     test_dataset = xr.merge([test_array, self.array])
     merged_array = test_dataset.to_array(name='merged_array')
     self.array.pp.grid = self.grid
     returned_array = self.array.pp.merge(test_array)
     np.testing.assert_equal(merged_array.values, returned_array.values)
Example #8
    def test_merge_no_conflicts_multi_var(self):
        data = create_test_data()
        data1 = data.copy(deep=True)
        data2 = data.copy(deep=True)

        expected = data[['var1', 'var2']]
        actual = xr.merge([data1.var1, data2.var2], compat='no_conflicts')
        assert expected.identical(actual)

        data1['var1'][:, :5] = np.nan
        data2['var1'][:, 5:] = np.nan
        data1['var2'][:4, :] = np.nan
        data2['var2'][4:, :] = np.nan
        del data2['var3']

        actual = xr.merge([data1, data2], compat='no_conflicts')
        assert data.equals(actual)
Example #9
    def test_merge_attr_retention(self):
        da1 = create_test_dataarray_attrs(var='var1')
        da2 = create_test_dataarray_attrs(var='var2')
        da2.attrs = {'wrong': 'attributes'}
        original_attrs = da1.attrs

        # merge currently discards attrs, and the global keep_attrs
        # option doesn't affect this
        result = merge([da1, da2])
        assert result.attrs == original_attrs
Example #10
    def convert_species_bit_flags(data: Dict) -> xr.Dataset:
        """
        Convert the int32 species flags to a dataset of distinct flags

        Parameters
        ----------
        data
            Dictionary of input data as returned by `load_data`

        Returns
        -------
            Dataset of the index bit flags
        """
        flags = dict()
        flags['separation_method'] = [0, 1, 2]
        flags['one_chan_aerosol_corr'] = 3
        flags['no_935_aerosol_corr'] = 4
        flags['Large_1020_OD'] = 5
        flags['NO2_Extrap'] = 6
        flags['Water_vapor_ratio'] = [7, 8, 9, 10]
        flags['Cloud_Bit_1'] = 11
        flags['Cloud_Bit_2'] = 12
        flags['No_H2O_Corr'] = 13
        flags['In_Troposphere'] = 14

        separation_method = dict()
        separation_method['no_aerosol_method'] = 0
        separation_method['trans_no_aero_to_five_chan'] = 1
        separation_method['standard_method'] = 2
        separation_method['trans_five_chan_to_low'] = 3
        separation_method['four_chan_method'] = 4
        separation_method['trans_four_chan_to_three_chan'] = 5
        separation_method['three_chan_method'] = 6
        separation_method['extension_method'] = 7

        f = dict()
        for key in flags.keys():
            if hasattr(flags[key], '__len__'):
                if key == 'separation_method':
                    for k in separation_method.keys():
                        temp = data['ProfileInfVec'] & np.sum([2 ** k for k in flags[key]])
                        f[k] = temp == separation_method[k]
                else:
                    temp = data['ProfileInfVec'] & np.sum([2 ** k for k in flags[key]])
                    f[key] = temp >> flags[key][0]  # shift flag to save only significant bits
            else:
                f[key] = (data['ProfileInfVec'] & 2 ** flags[key]) > 0

        xr_data = []
        time = pd.to_timedelta(data['mjd'], 'D') + pd.Timestamp('1858-11-17')
        for key in f.keys():
            xr_data.append(xr.DataArray(f[key], coords=[time, data['Alt_Grid'][0:140]], dims=['time', 'Alt_Grid'],
                                        name=key))

        return xr.merge(xr_data)
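The docstring above describes unpacking an int32 flag word into one variable per flag and merging the results into a Dataset. The short sketch below shows the same mask-and-shift idea in a self-contained form; the bit layout, names and values are illustrative assumptions, not the flag format used above.

import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical packed flag word: bits 0-2 hold a 3-bit method code and
# bit 3 is a single on/off flag (illustrative layout only).
packed = np.array([0b0000, 0b0011, 0b1010, 0b1111], dtype=np.int32)
time = pd.date_range('2000-01-01', periods=packed.size)

method_bits = [0, 1, 2]
mask = np.sum([2 ** b for b in method_bits])        # 0b0111
method = (packed & mask) >> method_bits[0]          # multi-bit field: mask, then shift
single = (packed & 2 ** 3) > 0                      # single bit: mask, then test

flags = xr.merge([
    xr.DataArray(method, coords=[time], dims=['time'], name='method'),
    xr.DataArray(single, coords=[time], dims=['time'], name='single_flag'),
])
print(flags)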
Example #11
def nan_ds():
    def make_data(a, b):
        return np.cos(a) + b, a + np.sin(b)
    r = Runner(make_data, ('a10sum', 'b10sum'))
    ds1 = r.run_combos((('a', np.linspace(1, 3, 10)),
                        ('b', np.linspace(1, 3, 10))))
    ds2 = r.run_combos((('a', np.linspace(1.5, 3.5, 10)),
                        ('b', np.linspace(1, 3, 10))))
    ds3 = r.run_combos((('a', np.linspace(4, 6, 10)),
                        ('b', np.linspace(4, 6, 10))))
    return xr.merge([ds1, ds2, ds3])
Example #12
def merge(ds_1: DatasetLike.TYPE,
          ds_2: DatasetLike.TYPE,
          ds_3: DatasetLike.TYPE = None,
          ds_4: DatasetLike.TYPE = None,
          join: str = 'outer',
          compat: str = 'no_conflicts') -> xr.Dataset:
    """
    Merge up to four datasets to produce a new dataset with combined variables from each input dataset.

    This is a wrapper for the ``xarray.merge()`` function.

    For documentation refer to xarray documentation at
    http://xarray.pydata.org/en/stable/generated/xarray.Dataset.merge.html#xarray.Dataset.merge

    The *compat* argument indicates how to compare variables of the same name for potential conflicts:

    * "broadcast_equals": all values must be equal when variables are broadcast
      against each other to ensure common dimensions.
    * "equals": all values and dimensions must be the same.
    * "identical": all values, dimensions and attributes must be the same.
    * "no_conflicts": only values which are not null in both datasets must be equal.
      The returned dataset then contains the combination of all non-null values.

    :param ds_1: The first input dataset.
    :param ds_2: The second input dataset.
    :param ds_3: An optional 3rd input dataset.
    :param ds_4: An optional 4th input dataset.
    :param join: How to combine objects with different indexes.
    :param compat: How to compare variables of the same name for potential conflicts.
    :return: A new dataset with combined variables from each input dataset.
    """

    ds_1 = DatasetLike.convert(ds_1)
    ds_2 = DatasetLike.convert(ds_2)
    ds_3 = DatasetLike.convert(ds_3)
    ds_4 = DatasetLike.convert(ds_4)

    datasets = []
    for ds in (ds_1, ds_2, ds_3, ds_4):
        if ds is not None:
            included = False
            for ds2 in datasets:
                if ds is ds2:
                    included = True
            if not included:
                datasets.append(ds)

    if len(datasets) == 0:
        raise ValidationError('At least two different datasets must be given')
    elif len(datasets) == 1:
        return datasets[0]
    else:
        return xr.merge(datasets, compat=compat, join=join)
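The docstring above lists the compat modes accepted by xarray.merge. The sketch below, using plain xr.merge rather than the wrapper, shows why 'no_conflicts' is a forgiving choice: overlapping values merge as long as the non-null entries agree, and a genuine disagreement still raises MergeError.

import numpy as np
import xarray as xr

# Two datasets overlap at x=1; the overlapping values agree wherever both
# are non-null, so compat='no_conflicts' combines them into one series.
a = xr.Dataset({'v': ('x', [1.0, 2.0, np.nan])}, coords={'x': [0, 1, 2]})
b = xr.Dataset({'v': ('x', [np.nan, 2.0, 3.0])}, coords={'x': [0, 1, 2]})

merged = xr.merge([a, b], compat='no_conflicts')
print(merged.v.values)          # [1. 2. 3.]

# A genuine conflict (different non-null values at the same index) still
# raises MergeError instead of silently picking one side.
c = xr.Dataset({'v': ('x', [9.0, 2.0, 3.0])}, coords={'x': [0, 1, 2]})
try:
    xr.merge([a, c], compat='no_conflicts')
except xr.MergeError as err:
    print('conflict detected:', err)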
Example #13
 def test_merge_fill_value(self, fill_value):
     ds1 = xr.Dataset({'a': ('x', [1, 2]), 'x': [0, 1]})
     ds2 = xr.Dataset({'b': ('x', [3, 4]), 'x': [1, 2]})
     if fill_value == dtypes.NA:
         # if we supply the default, we expect the missing value for a
         # float array
         fill_value = np.nan
     expected = xr.Dataset({'a': ('x', [1, 2, fill_value]),
                            'b': ('x', [fill_value, 3, 4])},
                           {'x': [0, 1, 2]})
     assert expected.identical(ds1.merge(ds2, fill_value=fill_value))
     assert expected.identical(ds2.merge(ds1, fill_value=fill_value))
     assert expected.identical(xr.merge([ds1, ds2], fill_value=fill_value))
Example #14
def __get_maximum_storage_and_corresponding_dates(start_year:int, end_year:int, data_manager:DataManager, storage_varname=""):
    cache_file_current = "cache_{}-{}_calculate_flood_storage_{}.nc".format(start_year, end_year, storage_varname)
    cache_file_current = Path(cache_file_current)

    # if the variables were calculated already
    if cache_file_current.exists():
        ds = xarray.open_dataset(str(cache_file_current))
    else:
        data_current = data_manager.get_min_max_avg_for_period(
            start_year=start_year, end_year=end_year, varname_internal=storage_varname
        )

        ds = xarray.merge([da for da in data_current.values()])
        ds.to_netcdf(str(cache_file_current))

    return ds
Example #15
File: manage.py Project: jcmgray/xyzpy
def merge_sync_conflict_datasets(base_name, engine='h5netcdf',
                                 combine_first=False):
    """Glob files based on `base_name`, merge them, save this new dataset if
    it contains new info, then clean up the conflicts.

    Parameters
    ----------
        base_name : str
            Base file name to glob on - should include '*'.
        engine : str, optional
            Load and save engine used by xarray.
        combine_first : bool, optional
            If True, fold the datasets together with ``Dataset.combine_first``
            instead of ``xr.merge``.
    """
    fnames = glob(base_name)
    if len(fnames) < 2:
        print('Nothing to do - need multiple files to merge.')
        return

    # make sure first filename is the shortest -> assumed original
    fnames.sort(key=len)

    print("Merging:\n{}\ninto ->\n{}\n".format(fnames, fnames[0]))

    def load_dataset(fname):
        return load_ds(fname, engine=engine)

    datasets = list(map(load_dataset, fnames))

    # combine all the conflicts
    if combine_first:
        full_dataset = datasets[0]
        for ds in datasets[1:]:
            full_dataset = full_dataset.combine_first(ds)
    else:
        full_dataset = xr.merge(datasets)

    # save new dataset?
    if full_dataset.identical(datasets[0]):
        # nothing to do
        pass
    else:
        save_ds(full_dataset, fnames[0], engine=engine)

    # clean up conflicts
    for fname in fnames[1:]:
        os.remove(fname)
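The function above either merges the conflict files or folds them in with combine_first. The toy sketch below contrasts the two: combine_first silently prefers the first dataset wherever both have data, while xr.merge (default compat='no_conflicts') refuses conflicting non-null values.

import numpy as np
import xarray as xr

first = xr.Dataset({'y': ('x', [1.0, np.nan])}, coords={'x': [0, 1]})
conflict = xr.Dataset({'y': ('x', [5.0, 2.0])}, coords={'x': [0, 1]})

# combine_first keeps the first dataset's values wherever they exist and
# only fills its gaps from the other dataset -- no error on disagreement.
print(first.combine_first(conflict).y.values)   # [1. 2.]

# xr.merge with the default compat='no_conflicts' refuses the same pair,
# because x=0 carries two different non-null values.
try:
    xr.merge([first, conflict])
except xr.MergeError:
    print('xr.merge rejects conflicting values')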
Example #16
def _epoch(data: xarray.Dataset, raw: str,
           hdr: Dict[str, Any],
           time: datetime,
           sv: List[str],
           useindicators: bool,
           verbose: bool) -> xarray.Dataset:
    """
    block processing of each epoch (time step)
    """
    darr = np.atleast_2d(np.genfromtxt(io.BytesIO(raw.encode('ascii')),
                                       delimiter=(14, 1, 1) * hdr['Fmax']))
# %% assign data for each time step
    for sk in hdr['fields']:  # for each satellite system type (G,R,S, etc.)
        # satellite indices "si" to extract from this time's measurements
        si = [i for i, s in enumerate(sv) if s[0] in sk]
        if len(si) == 0:  # no SV of this system "sk" at this time
            continue

        # measurement indices "di" to extract at this time step
        di = hdr['fields_ind'][sk]
        garr = darr[si, :]
        garr = garr[:, di]

        gsv = np.array(sv)[si]

        dsf: Dict[str, tuple] = {}
        for i, k in enumerate(hdr['fields'][sk]):
            dsf[k] = (('time', 'sv'), np.atleast_2d(garr[:, i*3]))

            if useindicators:
                dsf = _indicators(dsf, k, garr[:, i*3+1:i*3+3])

        if verbose:
            print(time, '\r', end='')

        epoch_data = xarray.Dataset(dsf, coords={'time': [time], 'sv': gsv})
        if len(data) == 0:
            data = epoch_data
        elif len(hdr['fields']) == 1:  # one satellite system selected, faster to process
            data = xarray.concat((data, epoch_data), dim='time')
        else:  # general case, slower for different satellite systems all together
            data = xarray.merge((data, epoch_data))

    return data
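The comments above note that a single satellite system lets each epoch be concatenated along time, whereas mixed systems need a merge. The toy epochs below illustrate the difference; the variable name and values are made up.

import xarray as xr

e1 = xr.Dataset({'C1': (('time', 'sv'), [[20.0e6]])},
                coords={'time': ['2020-01-01T00:00:00'], 'sv': ['G01']})
e2 = xr.Dataset({'C1': (('time', 'sv'), [[20.1e6]])},
                coords={'time': ['2020-01-01T00:00:30'], 'sv': ['G01']})
e3 = xr.Dataset({'C1': (('time', 'sv'), [[20.2e6]])},
                coords={'time': ['2020-01-01T00:00:30'], 'sv': ['R01']})

# Same satellite set, new time step: concat simply stacks along 'time'.
print(xr.concat([e1, e2], dim='time').sizes)

# Different satellite sets at the same time: merge aligns the union of
# 'sv' values and pads the gaps with NaN.
print(xr.merge([e2, e3]).sizes)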
Example #17
    def convert_index_bit_flags(data: Dict) -> xr.Dataset:
        """
        Convert the int32 index flags to a dataset of distinct flags

        Parameters
        ----------
        data
            Dictionary of input data as returned by ``load_data``

        Returns
        -------
            Dataset of the index bit flags
        """
        flags = dict()
        flags['pmc_present'] = 0
        flags['h2o_zero_found'] = 1
        flags['h2o_slow_convergence'] = 2
        flags['h2o_ega_failure'] = 3
        flags['default_nmc_temp_errors'] = 4
        flags['ch2_aero_model_A'] = 5
        flags['ch2_aero_model_B'] = 6
        flags['ch2_new_wavelength'] = 7
        flags['incomplete_nmc_data'] = 8
        flags['mirror_model'] = 15
        flags['twomey_non_conv_rayleigh'] = 19
        flags['twomey_non_conv_386_Aero'] = 20
        flags['twomey_non_conv_452_Aero'] = 21
        flags['twomey_non_conv_525_Aero'] = 22
        flags['twomey_non_conv_1020_Aero'] = 23
        flags['twomey_non_conv_NO2'] = 24
        flags['twomey_non_conv_ozone'] = 25
        flags['no_shock_correction'] = 30

        f = dict()
        for key in flags.keys():
            f[key] = (data['InfVec'] & 2 ** flags[key]) > 0

        xr_data = []
        time = pd.to_timedelta(data['mjd'], 'D') + pd.Timestamp('1858-11-17')
        for key in f.keys():
            xr_data.append(xr.DataArray(f[key], coords=[time], dims=['time'], name=key))

        return xr.merge(xr_data)
Example #18
File: _xarray.py Project: valpasq/yatsm
def merge_data(data, merge_attrs=True):
    """ Combine multiple Datasets or DataArrays into one Dataset

    Args:
        data (dict[name, xr.DataArray or xr.Dataset]): xr.DataArray or
            xr.Dataset objects to merge
        merge_attrs (bool): Attempt to merge DataArray attributes. In order for
            these attributes to be able to merge, they must be pd.Series
            and have compatible indexes.

    Returns:
        xr.Dataset: Merged xr.DataArray objects in one xr.Dataset
    """
    datasets = [dat.to_dataset(dim='band') if isinstance(dat, xr.DataArray)
                else dat for dat in data.values()]

    ds = xr.merge(datasets, compat='minimal')

    # Overlapping but not conflicting variables can't be merged for now
    # https://github.com/pydata/xarray/issues/835
    # In meantime, check for dropped variables
    dropped = set()
    for _ds in datasets:
        dropped.update(set(_ds.data_vars).difference(set(ds.data_vars)))

    # TODO: refactor this and repeat for coords and vars
    if dropped:
        # dropped_vars = {}
        for var in dropped:
            dims = [_ds[var].dims for _ds in datasets]
            if all([dims[0] == d for d in dims[1:]]):
                dfs = [_ds.reset_coords()[var].to_series()
                       for _ds in datasets]
                # dropped_vars[var] = (dims[0], pd.concat(dfs).sort_index())
                da = xr.DataArray(pd.concat(dfs).sort_index())
                ds[var] = da
            else:
                logger.debug("Cannot restore dropped coord {} because "
                             "dimensions are inconsistent across datasets")
        # ds = ds.assign_coords(**ds.coords.merge(dropped_vars))

    return ds
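merge_data first promotes any band-indexed DataArray to a Dataset with to_dataset(dim='band') and then merges with compat='minimal'. A small self-contained sketch of that conversion step, with hypothetical band names:

import xarray as xr

# A 'band'-indexed DataArray, e.g. one value per spectral band
# (hypothetical band names, only to show the to_dataset(dim=...) step).
da = xr.DataArray([[1, 2], [3, 4]],
                  coords={'band': ['red', 'nir'], 'x': [0, 1]},
                  dims=['band', 'x'], name='reflectance')

# Promote each band to its own variable, then merge with other datasets.
as_ds = da.to_dataset(dim='band')
other = xr.Dataset({'mask': ('x', [1, 0])}, coords={'x': [0, 1]})
merged = xr.merge([as_ds, other], compat='minimal')
print(list(merged.data_vars))   # ['red', 'nir', 'mask']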
Example #19
File: xarray.py Project: nbren12/gnl
def xr2mat(fields, sample_dims, feature_dims,
           scale=True):
    """Prepare list of data arrays for input to Machine Learning

    Parameters
    ----------
    fields: Dataset object
        input dataset
    sample_dims: tuple of string
        Dimensions which will be considered samples
    feature_dims: tuple of strings
        dimensions which be considered features
    scale: Bool
        center and scale the output. [default: True]

    Returns
    -------
    data: DataArray
    scaling: DataArray or None

    See Also
    --------
    gnl.data_matrix.DataMatrix : better version of this function
    """
    raise DeprecationWarning("Please switch to gnl.data_matrix.DataMatrix")
    normalize_dim = 'z'   # this could be an argument

    if not isinstance(fields, xr.Dataset):
        fields = xr.merge(fields)

    dat = fields.to_array()

    if scale:
        mu = dat.mean(sample_dims)
        V = np.sqrt(((dat-mu)**2).sum(normalize_dim)).mean(sample_dims)
        dat = (dat-mu)/V
    else:
        V = None

    stacked = dat.stack(features=('variable',)+tuple(feature_dims),
                        samples=sample_dims)
    return stacked.transpose('samples', 'features'), V
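The docstring above describes flattening a Dataset into a (samples, features) matrix by stacking the 'variable' dimension together with the feature dimensions. A minimal sketch of that stack-and-transpose step on a toy Dataset:

import numpy as np
import xarray as xr

# Toy dataset: two variables on (time, z); treat 'time' as the sample
# dimension and ('variable', 'z') as the feature dimensions.
ds = xr.Dataset({'u': (('time', 'z'), np.random.rand(4, 3)),
                 'v': (('time', 'z'), np.random.rand(4, 3))})

dat = ds.to_array()                      # adds a 'variable' dimension
stacked = dat.stack(features=('variable', 'z'), samples=('time',))
mat = stacked.transpose('samples', 'features')
print(mat.shape)                         # (4, 6)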
Example #20
 def test_merge_arrays(self):
     data = create_test_data()
     actual = xr.merge([data.var1, data.var2])
     expected = data[['var1', 'var2']]
     assert actual.identical(expected)
Example #21
 def test_merge_no_conflicts_preserve_attrs(self):
     data = xr.Dataset({'x': ([], 0, {'foo': 'bar'})})
     actual = xr.merge([data, data])
     assert data.identical(actual)
Example #22
 def test_merge_error(self):
     ds = xr.Dataset({'x': 0})
     with pytest.raises(xr.MergeError):
         xr.merge([ds, ds + 1])
Example #23
def run(params):
    start_time = datetime.now()

    bin_width, filter_bandwidth, theta, shift, \
        signal_field, noise_field, noise_multiplier = params

    # Get file paths
    signal_dir = '/scratch/pkittiwi/fg1p/signal_map/bin{:.2f}/' \
        'fbw{:.2f}/theta{:.1f}/shift{:d}' \
        .format(bin_width, filter_bandwidth, theta, shift)
    noise_dir = '/scratch/pkittiwi/fg1p/noise_map/bin{:.2f}/' \
        'fbw{:.2f}/theta{:.1f}/shift{:d}' \
        .format(bin_width, filter_bandwidth, theta, shift)
    output_dir = '/scratch/pkittiwi/fg1p/stats_mc/obsn{:.1f}/bin{:.2f}/' \
        'fbw{:.2f}/theta{:.1f}/shift{:d}/s{:03d}' \
        .format(noise_multiplier, bin_width, filter_bandwidth, theta,
                shift, signal_field)
    signal_file = '{:s}/signal_map_bin{:.2f}_fbw{:.2f}_' \
        'theta{:.1f}_shift{:d}_{:03d}.nc'\
        .format(signal_dir, bin_width, filter_bandwidth,
                theta, shift, signal_field)
    noise_file = '{:s}/noise_map_bin{:.2f}_fbw{:.2f}_' \
        'theta{:.1f}_shift{:d}_{:03d}.nc'\
        .format(noise_dir, bin_width, filter_bandwidth,
                theta, shift, noise_field)
    output_file = '{:s}/stats_mc_obsn{:.1f}_bin{:.2f}_fbw{:.2f}_' \
        'theta{:.1f}_shift{:d}_{:03d}_{:03d}.nc' \
        .format(output_dir, noise_multiplier, bin_width, filter_bandwidth,
                theta, shift, signal_field, noise_field)
    mask_file = '/scratch/pkittiwi/fg1p/hera331_fov_mask.nc'
    obs_dir = '/scratch/pkittiwi/fg1p/obs_map/obsn{:.1f}/bin{:.2f}/' \
        'fbw{:.2f}/theta{:.1f}/shift{:d}/s{:03d}' \
        .format(noise_multiplier, bin_width, filter_bandwidth, theta,
                shift, signal_field)
    obs_file = '{:s}/obs_map_obsn{:.1f}_bin{:.2f}_fbw{:.2f}_' \
        'theta{:.1f}_shift{:d}_{:03d}_{:03d}.nc' \
        .format(obs_dir, noise_multiplier, bin_width, filter_bandwidth,
                theta, shift, signal_field, noise_field)

    # Load data to memory and align coordinates
    with xr.open_dataarray(signal_file) as da:
        signal = da.load()
    with xr.open_dataarray(noise_file) as da:
        noise = da.load()
    with xr.open_dataarray(mask_file) as da:
        mask = da.load()
    for key, values in noise.coords.items():
        signal.coords[key] = values
        mask.coords[key] = values
    signal, noise, mask = xr.align(signal, noise, mask)

    # Make observation
    signal = signal.where(mask == 1)
    noise = noise.where(mask == 1) * noise_multiplier
    obs = signal + noise
    obs.name = 'obs'
    obs.attrs = {'signal_field': signal_field, 'noise_field': noise_field,
                 'noise_multiplier': noise_multiplier, 'bin_width': bin_width,
                 'filter_bandwidth': filter_bandwidth, 'theta': theta,
                 'shift': shift}

    # Calculate noise variance
    noise_var = noise.var(dim=['y', 'x'])
    noise_var.name = 'noise_var'
    noise_var.attrs = {
        'noise_field': noise_field, 'noise_multiplier': noise_multiplier,
        'bin_width': bin_width, 'filter_bandwidth': filter_bandwidth,
        'theta': theta, 'shift': shift
    }

    # Save observation and noise_variance
    os.makedirs(obs_dir, exist_ok=True)
    obs = xr.merge([obs, noise_var])
    obs.to_netcdf(obs_file)

    del signal
    del noise
    del mask

    # Calculate statistic
    out = get_stats(obs)
    out.attrs = {'signal_field': signal_field, 'noise_field': noise_field,
                 'noise_multiplier': noise_multiplier, 'bin_width': bin_width,
                 'filter_bandwidth': filter_bandwidth, 'theta': theta,
                 'shift': shift}

    os.makedirs(output_dir, exist_ok=True)
    out.to_netcdf(output_file)

    out.close()

    print(
        'Finish. signal_file = {:s}. noise_file = {:s}. output_file = {:s}.'
        'Time spent {:.5f} sec.'
        .format(signal_file, noise_file, output_file,
                (datetime.now() - start_time).total_seconds())
    )
Example #24
 def test_merge_alignment_error(self):
     ds = xr.Dataset(coords={'x': [1, 2]})
     other = xr.Dataset(coords={'x': [2, 3]})
     with raises_regex(ValueError, 'indexes .* not equal'):
         xr.merge([ds, other], join='exact')
Example #25
        concat_dim='ensemble',
        combine='nested',
        compat=
        'identical',  # seems to be strictest setting...not sure if necessary
        parallel="True",
        join="exact"  # another strict selection...
    )

    # Add ensemble as a dimension
    ds = ds.assign_coords({'ensemble': np.arange(1, len(mnum) + 1, 1)})

    # Merge variables to Dataset (assuming they have the same coordinates)
    if v == 0:
        dsall = ds.copy()
    else:
        dsall = xr.merge([dsall, ds])

#%% Get the DJFM and Regional cuts for EOF calculation
cuttime = time.time()

# Read in the data # [Ens x Time x Lat x Lon]
pslglo = dsall.PSL.values / 100  # divide by 100 to convert to hPa
lat = dsall.lat.values
lon = dsall.lon.values
times = dsall.time.values

# Get dimensions
nlon = len(lon)
nlat = len(lat)
ntime = len(times)
nens = len(mnum)
Example #26
File: obs3.py Project: VOMAY/georinex
def rinexobs3(fn: Union[TextIO, str, Path],
              use: Sequence[str] = None,
              tlim: Tuple[datetime, datetime] = None,
              useindicators: bool = False,
              meas: Sequence[str] = None,
              verbose: bool = False,
              *,
              interval: Union[float, int, timedelta] = None) -> xr.Dataset:
    """
    process RINEX 3 OBS data

    fn: RINEX OBS 3 filename
    use: 'G'  or ['G', 'R'] or similar

    tlim: read between these time bounds
    useindicators: SSI, LLI are output
    meas:  'L1C'  or  ['L1C', 'C1C'] or similar

    interval: allows decimating file read by time e.g. every 5 seconds.
                Useful to speed up reading of very large RINEX files
    """

    # %% Check input arguments and initialise
    interval = check_time_interval(interval)

    if isinstance(use, str):
        use = [use]

    if isinstance(meas, str):
        meas = [meas]

    if use is not None and not use[0].strip():
        use = None

    if meas is not None and not meas[0].strip():
        meas = None

    if tlim is not None and not isinstance(tlim[0], datetime):
        raise TypeError('time bounds are specified as datetime.datetime')

    last_epoch = None

    # %% Parsing loop
    with opener(fn) as f:
        try:
            # Read header into HeaderClass instance
            hdr = obsheader3(f)

            # filter signals based on selection via input arguments
            selObs, selInd = filterObs3(hdr, use, meas)

            # Get set of all signals of all constellations
            signalUnion = sorted(
                set([
                    signal for constSigList in selObs.values()
                    for signal in constSigList
                ]))

            # Allocate Main internal data buffer
            obsBuf = {}
            for signal in signalUnion:
                if useindicators:
                    obsBuf[signal] = {
                        'time': [],
                        'const': [],
                        'prn': [],
                        'val': [],
                        'ssi': [],
                        'lli': []
                    }
                else:
                    obsBuf[signal] = {
                        'time': [],
                        'const': [],
                        'prn': [],
                        'val': []
                    }

        except KeyError:
            return xr.Dataset()

        # %%    Process OBS file
        for ln in f:
            # Check for next epoch
            if not ln.startswith('>'):
                break

            # %% Process Epoch Record line
            try:
                time, in_range = _timeobs(ln, tlim, last_epoch, interval)
            except ValueError:  # garbage between header and RINEX data
                logging.debug(
                    f'garbage detected in {fn}, trying to parse at next time step'
                )
                continue

            # Number of visible satellites this epoch
            nSv = int(ln[33:35])

            # Check if epoch is in selected interval
            if in_range == -1:
                for _ in range(nSv):
                    next(f)
                continue
            if in_range == 1:
                break
            last_epoch = time

            if verbose:
                print(time, end="\r")

            # %% Process observation lines
            obsEpoch = {}
            # Read nSv lines and extract selected data
            for _, epochLine in zip(range(nSv), f):
                # Check if this line starts with an expected constellation letter
                if epochLine[0] not in hdr.obsType.keys():
                    raise KeyError(f'Unexpected line found in RINEX file')

                obsEpoch = _epoch(obsEpoch, selObs, selInd, epochLine,
                                  useindicators)

            # Store selected data of epoch in internal buffer obsBuf
            for signal in obsEpoch:
                obsBuf[signal]['time'].append(time)
                obsBuf[signal]['const'].append(obsEpoch[signal]['const'])
                obsBuf[signal]['prn'].append(obsEpoch[signal]['prn'])
                obsBuf[signal]['val'].append(obsEpoch[signal]['val'])
                if useindicators:
                    obsBuf[signal]['lli'].append(obsEpoch[signal]['lli'])
                    obsBuf[signal]['ssi'].append(obsEpoch[signal]['ssi'])

    # %% Convert internal buffer (dict) to output format (xarray.DataArray)
    # First generate one DataArray per signal, then merge them together
    data = []
    for signal in obsBuf:
        # Get all times of this signal
        signalTime = obsBuf[signal]['time']
        # Get all constellations with this signal
        signalConst = np.sort(
            np.array(
                list(
                    set([
                        const for constEpochList in obsBuf[signal]['const']
                        for const in constEpochList
                    ]))))
        # Get all satellites of these constellations with this signal
        signalPrn = np.sort(
            np.array(
                list(
                    set([
                        prn for prnEpochList in obsBuf[signal]['prn']
                        for prn in prnEpochList
                    ]))))
        # Allocate array of observations with three dimensions
        signalVal = np.empty(
            (len(signalTime), len(signalConst), len(signalPrn)))

        # Generate DataArray and append to list
        data.append(
            _gen_array(signalTime, signalConst, signalPrn,
                       obsBuf[signal]['const'], obsBuf[signal]['prn'],
                       obsBuf[signal]['val'], signalVal, signal))
        if useindicators:
            data.append(
                _gen_array(signalTime, signalConst, signalPrn,
                           obsBuf[signal]['const'], obsBuf[signal]['prn'],
                           obsBuf[signal]['lli'], signalVal, signal + '-lli'))
            data.append(
                _gen_array(signalTime, signalConst, signalPrn,
                           obsBuf[signal]['const'], obsBuf[signal]['prn'],
                           obsBuf[signal]['ssi'], signalVal, signal + '-ssi'))

    # Merge DataArray
    data = xr.merge(data)

    # Add Attributes
    data.attrs['version'] = hdr.version
    hdr.cInterval(data.time)
    data.attrs['interval'] = hdr.interval
    data.attrs['rinexType'] = hdr.rinexType
    if hasattr(hdr, 'position'):
        data.attrs['position'] = hdr.position
    if hasattr(hdr, 'positionGeodetic'):
        data.attrs['positionGeodetic'] = hdr.positionGeodetic
    data.attrs['timeSystem'] = hdr.timeSystem
    if isinstance(fn, Path):
        data.attrs['filename'] = fn.name
    if hasattr(hdr, 'tFirst'):
        data.attrs['tFirst'] = hdr.tFirst
    if hasattr(hdr, 'tLast'):
        data.attrs['tLast'] = hdr.tLast

    return data
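The final block above builds one DataArray per signal on a (time, const, prn) grid and merges them into a single Dataset. The self-contained sketch below mimics that pattern with toy values; the signal names and grid are illustrative only.

import numpy as np
import pandas as pd
import xarray as xr

# Toy (time, const, prn) grid mimicking the per-signal buffers above.
time = pd.date_range('2020-01-01', periods=2, freq='30s')
const = np.array(['G', 'R'])
prn = np.array([1, 2])

def toy_signal(name):
    # One DataArray per signal, NaN where a satellite was not observed.
    vals = np.full((time.size, const.size, prn.size), np.nan)
    vals[0, 0, 0] = 1.0
    return xr.DataArray(vals, coords=[time, const, prn],
                        dims=['time', 'const', 'prn'], name=name)

# Merge the per-signal DataArrays into one Dataset, as done above.
data = xr.merge([toy_signal('C1C'), toy_signal('L1C')])
print(data)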
Example #27
def collect_inst_model_pairs(start=None,
                             stop=None,
                             tinc=None,
                             inst=None,
                             user=None,
                             password=None,
                             model_files=None,
                             model_load_rout=None,
                             inst_lon_name=None,
                             mod_lon_name=None,
                             inst_name=[],
                             mod_name=[],
                             mod_datetime_name=None,
                             mod_time_name=None,
                             mod_units=[],
                             sel_name=None,
                             method='linear',
                             model_label='model',
                             inst_clean_rout=None,
                             comp_clean='clean'):
    """Pair instrument and model data, applying data cleaning after finding the
    times and locations where the instrument and model align

    Parameters
    ----------
    start : dt.datetime
        Starting datetime
    stop : dt.datetime
        Ending datetime
    tinc : dt.timedelta
        Time increment for model files
    inst : pysat.Instrument instance
        instrument object for which modelled data will be extracted
    user : string
        User name (needed for some data downloads)
    password : string
        Password (needed for some data downloads)
    model_files : string
        string format that will construct the desired model filename from a
        datetime object
    model_load_rout : routine
        Routine to load model data into an xarray using filename and datetime
        as input
    inst_lon_name : string
        variable name for instrument longitude
    mod_lon_name : string
        variable name for model longitude
    inst_name : list of strings
        list of names of the data series to use for determining instrument
        location
    mod_name : list of strings
        list of names of the data series to use for determining model locations
        in the same order as inst_name.  These must make up a regular grid.
    mod_datetime_name : string
        Name of the data series in the model Dataset containing datetime info
    mod_time_name : string
        Name of the time coordinate in the model Dataset
    mod_units : list of strings
        units for each of the mod_name location attributes.  Currently
        supports: rad/radian(s), deg/degree(s), h/hr(s)/hour(s), m, km, and cm
    sel_name : list of strings or NoneType
        list of names of modelled data indices to append to instrument object,
        or None to append all modelled data (default=None)
    method : string
        Interpolation method.  Supported are 'linear', 'nearest', and
        'splinef2d'.  The last is only supported for 2D data and is not
        recommended here.  (default='linear')
    model_label : string
        name of model, used to identify interpolated data values in instrument
        (default="model")
    inst_clean_rout : routine
        Routine to clean the instrument data
    comp_clean : string
        Clean level for the comparison data ('clean', 'dusty', 'dirty', 'none')
        (default='clean')

    Returns
    -------
    matched_inst : pysat.Instrument instance
        instrument object and paired modelled data

    """
    from os import path
    import pysat

    matched_inst = None

    # Test the input
    if start is None or stop is None:
        raise ValueError('Must provide start and end time for comparison')

    if inst is None:
        raise ValueError('Must provide a pysat instrument object')

    if model_files is None:
        raise ValueError('Must provide list of modelled data')

    if model_load_rout is None:
        raise ValueError('Need routine to load modelled data')

    if mod_datetime_name is None:
        raise ValueError('Need time coordinate name for model datasets')

    if mod_time_name is None:
        raise ValueError('Need time coordinate name for model datasets')

    if len(inst_name) == 0:
        estr = 'Must provide instrument location attribute names as a list'
        raise ValueError(estr)

    if len(inst_name) != len(mod_name):
        estr = 'Must provide the same number of instrument and model '
        estr += 'location attribute names as a list'
        raise ValueError(estr)

    if len(mod_name) != len(mod_units):
        raise ValueError(
            'Must provide units for each model location attribute')

    if inst_clean_rout is None:
        raise ValueError('Need routine to clean the instrument data')

    # Download the instrument data, if needed
    # Could use some improvement, for not re-downloading times that you already
    # have
    if (stop - start).days != len(inst.files[start:stop]):
        inst.download(start=start, stop=stop, user=user, password=password)

    # Cycle through the times, loading the model and instrument data as needed
    istart = start
    while start < stop:
        mod_file = start.strftime(model_files)

        if path.isfile(mod_file):
            mdata = model_load_rout(mod_file, start)
            lon_high = float(mdata.coords[mod_lon_name].max())
            lon_low = float(mdata.coords[mod_lon_name].min())
        else:
            mdata = None

        if mdata is not None:
            # Load the instrument data, if needed
            if inst.empty or inst.index[-1] < istart:
                inst.custom.add(pysat.utils.update_longitude,
                                'modify',
                                low=lon_low,
                                lon_name=inst_lon_name,
                                high=lon_high)
                inst.load(date=istart)

            if not inst.empty and inst.index[0] >= istart:
                added_names = extract_modelled_observations(
                    inst=inst, model=mdata, inst_name=inst_name,
                    mod_name=mod_name, mod_datetime_name=mod_datetime_name,
                    mod_time_name=mod_time_name, mod_units=mod_units,
                    sel_name=sel_name, method=method, model_label=model_label)

                if len(added_names) > 0:
                    # Clean the instrument data
                    inst.clean_level = comp_clean
                    inst_clean_rout(inst)

                    im = list()
                    for aname in added_names:
                        # Determine the number of good points
                        if inst.pandas_format:
                            imnew = np.where(~np.isnan(inst[aname]))
                        else:
                            imnew = np.where(~np.isnan(inst[aname].values))

                        # Some data types are higher dimensions than others,
                        # make sure we end up choosing a high dimension one
                        # so that we don't accidentally throw away paired data
                        if len(im) == 0 or len(im[0]) < len(imnew[0]):
                            im = imnew

                    # If the data is 1D, save it as a list instead of a tuple
                    if len(im) == 1:
                        im = im[0]
                    else:
                        im = {
                            kk: im[i]
                            for i, kk in enumerate(inst.data.coords.keys())
                        }

                    # Save the clean, matched data
                    if matched_inst is None:
                        matched_inst = pysat.Instrument()
                        matched_inst.meta = inst.meta
                        if inst.pandas_format:
                            matched_inst.data = inst.data.iloc[im]
                        else:
                            matched_inst.data = inst.data.isel(im)
                    else:
                        if inst.pandas_format:
                            matched_inst.data = matched_inst.data.append(
                                inst.data.iloc[im])
                        else:
                            matched_inst.data = xr.merge(
                                [matched_inst.data, inst.data.isel(im)])

                    # Reset the clean flag
                    inst.clean_level = 'none'

        # Cycle the times
        if tinc.total_seconds() <= 86400.0:
            start += tinc
            if start + tinc > istart + dt.timedelta(days=1):
                istart += dt.timedelta(days=1)
        else:
            if start + tinc >= istart + dt.timedelta(days=1):
                istart += dt.timedelta(days=1)
            if istart >= start + tinc:
                start += tinc

    # Recast as xarray and add units
    if matched_inst is not None:
        if inst.pandas_format:
            matched_inst.data = matched_inst.data.to_xarray()
        for im in inst.meta.data.units.keys():
            if im in matched_inst.data.data_vars.keys():
                matched_inst.data.data_vars[im].attrs['units'] = \
                    inst.meta.data.units[im]

    return matched_inst
Example #28
b_lat = trans_lat1[0] - lat0

# Create transects
path = '/g/data/w40/esh563/goulburn_NT/ASCAT/ASCAT_goulburn_2012-2014_12.nc'
ASCAT = xr.open_dataset(path)

proj = ASCAT.u_pert_mean * b_lon + ASCAT.v_pert_mean * b_lat
proj = proj / np.sqrt(b_lon**2 + b_lat**2)

ASCAT_tran = ta.calc_transects(proj, trans_lon0, trans_lat0, trans_lon1,
                               trans_lat1, n_points, n_trans)
ASCAT_p_value_tran = ta.calc_transects(ASCAT.p_value_mean, trans_lon0,
                                       trans_lat0, trans_lon1, trans_lat1,
                                       n_points, n_trans)

ASCAT_tran = ASCAT_tran.assign_coords(coastal_axis=coast_distances)
ASCAT_tran = ASCAT_tran.assign_coords(transect_axis=tran_distances)
ASCAT_tran = ASCAT_tran.rename('wind_proj')

ASCAT_p_value_tran = ASCAT_p_value_tran.assign_coords(
    coastal_axis=coast_distances)
ASCAT_p_value_tran = ASCAT_p_value_tran.assign_coords(
    transect_axis=tran_distances)
ASCAT_p_value_tran = ASCAT_p_value_tran.rename('p_value')

ASCAT_tran = xr.merge((ASCAT_tran, ASCAT_p_value_tran))

save_path_ASCAT = '/g/data/w40/esh563/goulburn_NT/transect_means/ASCAT_goulburn_2012-2014_12.nc'

ASCAT_tran.to_netcdf(path=save_path_ASCAT, mode='w', format='NETCDF4')
Example #29
def main():

    # Define some constants
    # Number of input ice layers
    nilyr1 = 3

    # Number of CICE5 ice layers
    nilyr2 = 7

    # Number of snow layers in each ice category
    nslyr = 1

    # Number of ice categories in CICE5
    ncat = 5

    # Missing value
    missing = 9.96920996838687e+36

    # Category boundaries
    c1 = 0.6445
    c2 = 1.3914
    c3 = 2.4702
    c4 = 4.5673
    cvals = [c1, c2, c3, c4]

    # Salinity profile constants
    saltmax = 3.2  # Maximum salinity at ice base
    nsal = 0.407  # Profile constant
    msal = 0.573  # Profile constant

    # Density values
    rhoi = 917.0  # Density of ice
    rhos = 330.0  # Density of snow
    cp_ice = 2106.0  # Specific heat of fresh ice
    cp_ocn = 4218.0  # Specific heat of sea water
    Lfresh = 3.34e5  # Latent heat of melting fresh ice

    # Define intervals for interpolation to CICE5
    rstart = 0.5 * (1 / nilyr2)
    rend = 1 - rstart
    tlevs = np.linspace(rstart, rend, nilyr2)
    nilyr = tlevs

    # Define data directory
    print("Get environment variables.")
    datadir = os.environ.get('UFSDATA_DIR')
    postdir = os.environ.get('UFSPOST_DIR')
    ora_file = os.environ.get('ORAS5_FILE')
    print("End get environment variables.")

    # Get list of files to process
    print("Begin open files.")
    flist = get_filelist(datadir, "*" + ora_file + ".nc")
    rlist = get_fileroot(flist)
    dsets = open_filelist(flist, __file__)
    print("End open files.")
    print(flist)
    print(dsets)

    # Iterate over open files
    for i, dsin in enumerate(dsets):
        print(flist[i])

        # Rename input variables for consistency
        part_size = 1. - dsin.frld.squeeze()
        h_ice = dsin.hicif.squeeze()
        h_sno = dsin.hsnif.squeeze()
        t_surf = dsin.sist.squeeze()
        t1 = dsin.tbif1.squeeze()
        t2 = dsin.tbif2.squeeze()
        t3 = dsin.tbif3.squeeze()
        #
        part_size.name = "part_size"
        h_ice.name = "h_ice"
        h_sno.name = "h_sno"
        t_surf.name = "t_surf"
        t1.name = "t1"
        t2.name = "t2"
        t3.name = "t3"

        # Get dimensions of input data
        ndims1 = part_size.shape
        nj = part_size.nj.size
        ni = part_size.ni.size

        # Initialize DataArrays
        dummy = xr.DataArray(np.zeros((nj, ni), dtype=np.double),
                             coords={
                                 'nj': part_size.nj,
                                 "ni": part_size.ni
                             },
                             dims=['nj', 'ni'],
                             name="dummy var")
        iceumask = xr.DataArray(np.full((nj, ni), 1, dtype=np.double),
                                coords={
                                    'nj': part_size.nj,
                                    "ni": part_size.ni
                                },
                                dims=['nj', 'ni'],
                                name="iceumask")
        aicen = xr.DataArray(np.zeros((ncat, nj, ni), dtype=np.double),
                             coords={
                                 'ncat': range(0, ncat),
                                 'nj': part_size.nj,
                                 "ni": part_size.ni
                             },
                             dims=['ncat', 'nj', 'ni'],
                             name="aicen")
        vicen = xr.DataArray(np.zeros((ncat, nj, ni), dtype=np.double),
                             coords={
                                 'ncat': range(0, ncat),
                                 'nj': part_size.nj,
                                 "ni": part_size.ni
                             },
                             dims=['ncat', 'nj', 'ni'],
                             name="vicen")
        vsnon = xr.DataArray(np.zeros((ncat, nj, ni), dtype=np.double),
                             coords={
                                 'ncat': range(0, ncat),
                                 'nj': part_size.nj,
                                 "ni": part_size.ni
                             },
                             dims=['ncat', 'nj', 'ni'],
                             name="vsnon")
        Tsfcn = xr.DataArray(np.zeros((ncat, nj, ni), dtype=np.double),
                             coords={
                                 'ncat': range(0, ncat),
                                 'nj': part_size.nj,
                                 "ni": part_size.ni
                             },
                             dims=['ncat', 'nj', 'ni'],
                             name="Tsfcn")
        tice = xr.DataArray(np.zeros((nilyr1, ncat, nj, ni), dtype=np.double),
                            coords={
                                'nilyr': np.linspace(0, 1, 3),
                                'ncat': range(0, ncat),
                                'nj': part_size.nj,
                                "ni": part_size.ni
                            },
                            dims=['nilyr', 'ncat', 'nj', 'ni'],
                            name="tice")
        Tin = xr.DataArray(np.zeros((nilyr2, ncat, nj, ni), dtype=np.double),
                           coords={
                               'nilyr': tlevs,
                               'ncat': range(0, ncat),
                               'nj': part_size.nj,
                               "ni": part_size.ni
                           },
                           dims=['nilyr', 'ncat', 'nj', 'ni'],
                           name="Tin")
        sice = xr.DataArray(np.zeros((nilyr2, ncat, nj, ni), dtype=np.double),
                            coords={
                                'nilyr': tlevs,
                                'ncat': range(0, ncat),
                                'nj': part_size.nj,
                                "ni": part_size.ni
                            },
                            dims=['nilyr', 'ncat', 'nj', 'ni'],
                            name="sice")
        qice = xr.DataArray(np.zeros((nilyr2, ncat, nj, ni), dtype=np.double),
                            coords={
                                'nilyr': tlevs,
                                'ncat': range(0, ncat),
                                'nj': part_size.nj,
                                "ni": part_size.ni
                            },
                            dims=['nilyr', 'ncat', 'nj', 'ni'],
                            name="qice")
        qsno = xr.DataArray(np.zeros((nslyr, ncat, nj, ni), dtype=np.double),
                            coords={
                                'nslyr': range(nslyr),
                                'ncat': range(0, ncat),
                                'nj': part_size.nj,
                                "ni": part_size.ni
                            },
                            dims=['nslyr', 'ncat', 'nj', 'ni'],
                            name="qsno")
        print(vicen.shape)
        print(h_ice.shape)

        # Set ice fraction to zero where ice thickness is missing or non-positive
        ice_frac = part_size.where(h_ice > 0., other=0)
        ice_frac.name = "ice_frac"

        # Calculate ice fraction per category
        aicen = ice_category_brdcst(ice_frac, h_ice, cvals)
        aicen.name = 'aicen'

        # Restore missing metadata
        aicen['nj'] = part_size.nj
        aicen['ni'] = part_size.ni
        aicen['ncat'] = range(0, ncat)

        # Calculate ice mask
        iceumask = iceumask.where(aicen.sum(dim='ncat') > 1e-11, other=0)

        # Calculate ice volume per category
        for k in range(ncat):
            vicen[k, :, :] = h_ice.where(h_ice > 0, other=0) * aicen[k, :, :]
        vicen.name = "vicen"

        # Calculate snow volume per category
        for k in range(ncat):
            vsnon[k, :, :] = h_sno.where(h_sno > 0, other=0) * aicen[k, :, :]
        vsnon.name = "vsnon"

        # Calculate Surface temperature per category
        # Missing value for t_surf is 0, convert to Kelvin to avoid excessive negative values later
        t_surf = t_surf.where(t_surf != 0, other=273.15)
        Tsfcn = ice_category_brdcst(t_surf - 273.15, h_ice, cvals)
        Tsfcn = Tsfcn.where(Tsfcn < 0, other=0)
        Tsfcn.name = "Tsfcn"

        # Calculate ice layer temperature per category and combine
        tice[0, :, :, :] = ice_category_brdcst(
            t1.where(t1 != 0, other=273.15) - 273.15, h_ice, cvals)
        tice[1, :, :, :] = ice_category_brdcst(
            t2.where(t2 != 0, other=273.15) - 273.15, h_ice, cvals)
        tice[2, :, :, :] = ice_category_brdcst(
            t3.where(t3 != 0, other=273.15) - 273.15, h_ice, cvals)

        # Linearly interpolate from ORAS5 layers to CICE5
        Tin = tice.interp(nilyr=tlevs)
        Tin.name = "Tin"
        Tin = Tin.where(Tin < 0, other=0)
        print(Tin)

        # Create salinity profile
        zn = np.asarray([(k + 1 - 0.5) / nilyr2 for k in range(nilyr2)])
        print((np.pi * zn**(nsal / (msal + zn))))
        salinz = 0.5 * saltmax * (1 - np.cos(np.pi * zn**(nsal / (msal + zn))))
        print(salinz)
        for k in range(nilyr2):
            sice[k, :, :, :] = salinz[k]

        # Determine freezing point depression
        Tmltz = salinz / (-18.48 + (0.01848 * salinz))

        # Calculate ice layer enthalpy
        # Don't allow ice temperature to exceed melting temperature
        for k in range(nilyr2):
            Tin[k, :, :, :] = Tin[k, :, :, :].where(Tin[k, :, :, :] < Tmltz[k],
                                                    other=Tmltz[k])
            qice[k, :, :, :] = rhoi * cp_ice * Tin[k, :, :, :] - rhoi * Lfresh
            qice[k, :, :, :] = qice[k, :, :, :].where(vicen > 0, other=0)


        # Calculate snow layer enthalpy
        qsno[0, :, :, :] = -rhos * (Lfresh - cp_ice * Tsfcn)
        qsno[0, :, :, :] = qsno[0, :, :, :].where(vsnon > 0,
                                                  other=-rhos * Lfresh)
        Trecon = (Lfresh + qsno / rhos) / cp_ice
        Trecon = Trecon.where(vsnon > 0, 0)
        Trecon.name = 'Trecon'
        Trecon.to_netcdf("test.nc")

        # Write output in expected format
        qsno.to_netcdf("qsno_test.nc")

        # Create list of variables initialized to zero
        dlist = [
            'uvel', 'vvel', 'scale_factor', 'coszen', 'swvdr', 'swvdf',
            'swidr', 'swidf', 'strocnxT', 'strocnyT', 'stressp_1', 'stressp_2',
            'stressp_3', 'stressp_4', 'stressm_1', 'stressm_2', 'stressm_3',
            'stressm_4', 'stress12_1', 'stress12_2', 'stress12_3',
            'stress12_4', 'frz_onset'
        ]
        dout = xr.merge([aicen, vicen, vsnon, Tsfcn, iceumask, Tin])
        dout['qsno001'] = qsno[0, :, :, :].squeeze()
        print(np.linspace(0, 4, 1))
        dout['ncat'] = np.linspace(0, 4, 5)
        dout['nilyr'] = tlevs
        for vname in dlist:
            dout[vname] = dummy
        for k in range(nilyr2):
            dout['qice00' + str(k + 1)] = qice[k, :, :, :]
            dout['sice00' + str(k + 1)] = sice[k, :, :, :]
        dout.fillna(missing)
        for vname in dout.data_vars:
            dout[vname].encoding['_FillValue'] = missing
        dout.to_netcdf("cice_20120101_test.qsno.nc", format='NETCDF3_CLASSIC')
Example #30
def lai_filter(filepath, out, std):

    import os
    import xarray as xr

    # import collections
    # filepath = 'P:\\nasa_above\\working\\modis_analyses\\MyTest.nc'
    fname = os.path.basename(filepath)
    outname_lai = filepath.replace(
        fname,
        fname.replace(".nc", "") + "_lai_filtered.nc")

    # Read the nc file as xarray DataSet
    ds = xr.open_dataset(filepath)

    # Read the layers from the dataset (DataArray)
    FparLai_QC = ds["FparLai_QC"]
    lai = ds["Lai_500m"]

    # ------------------ Filtering -------------------------------------------
    print("\n\nStarted filtering LAI: takes some time")

    # These values come from the LAI manual for quality control
    # returns the quality value (0,2,32 and 34) and nan for every other quality
    # https://lpdaac.usgs.gov/documents/2/mod15_user_guide.pdf
    lai_flag = FparLai_QC.where((FparLai_QC.values == 0)
                                | (FparLai_QC.values == 2)
                                | (FparLai_QC.values == 32)
                                | (FparLai_QC.values == 34))

    # Convert it Boolean
    lai_flag = lai_flag >= 0
    # Convert Boolean to zero (bad quality) and one (high quality)
    lai_flag = lai_flag.astype(int)
    lai_tmp = lai_flag * lai
    # replace all zeros with nan values
    lai_final = lai_tmp.where(lai_tmp.values != 0)
    lai_final = lai_final.rename("LAI")

    # Did the user ask for the standard deviation of LAI too?
    if std == True:
        lai_std = ds["LaiStdDev_500m"]
        print("\n\nStarted filtering LAI Standard Deviation: takes some time")
        lai_std_tmp = lai_flag * lai_std
        lai_std_final = lai_std_tmp.where(lai_std_tmp.values != 0)
        lai_std_final = lai_std_final.rename("LAI_STD")
        lai_dataset = xr.merge([lai_final, lai_std_final])
        outname_std_lai = filepath.replace(
            fname,
            fname.replace(".nc", "") + "_lai_dataset.nc")

    # -------------------- OUTPUTS ------------------------------------------
    print("\n\n Figuring out the outputs:")
    if std == True and out == True:
        print('\n   ---writing the lai and lai_std as a "dataset" on the disk')
        lai_dataset.to_netcdf(outname_std_lai)
        return lai_dataset
    elif std == False and out == True:
        print("\n   ---wirintg just lai to the disk (no STD) and return lai")
        lai_final.to_netcdf(outname_lai)
        return lai_final
    elif std == True and out == False:
        print('\n   ---return lai and lai_std as a "dataset"')
        return lai_dataset
    elif std == False and out == False:
        print(
            "\n   ---return lai (not standard deviation) and no writing on the disk"
        )
        return lai_final
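A hypothetical call to `lai_filter` for illustration (the path is made up): with `std=True` and `out=True` the function writes `<name>_lai_dataset.nc` next to the input file and returns the merged Dataset containing `LAI` and `LAI_STD`.

lai_ds = lai_filter("/path/to/MyTest.nc", out=True, std=True)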
Example #31
0
                print(str(counter) + ' changed')
            if h in [6, 12, 18, 0]:
                tmp[counter, :, :] = tp[counter, 5, :, :] - tp[counter - 1, 4, :, :]
                print(str(counter) + ' changed')

        # make a double sized array and then merge (overlay)
        data = numpy.full([tmp.shape[0], 2 * tmp.shape[1] - 1, 2 * tmp.shape[2] - 1], numpy.nan)
        # make coords in 0.5 instead of 1 step --> double size
        coords_y = numpy.arange(ds['y'].values.min(), ds['y'].values.max() + 0.5, 0.5)
        coords_x = numpy.arange(ds['x'].values.min(), ds['x'].values.max() + 0.5, 0.5)
        tmp_x2 = xr.DataArray(data, dims=('time', 'y', 'x'),
                              coords=[ds['time'].values, coords_y, coords_x])
        tmp_x2 = tmp_x2.to_dataset(name='tp')
        tmp = tmp.to_dataset(name='tp')
        tmp_merge = xr.merge([tmp_x2, tmp], compat='no_conflicts')
        tp_merge = tmp_merge['tp']

        del tmp_x2, tmp_merge, data, ds
        # for better understanding here a test example
        # Working 2D Example of filling the gaps --> important: fill nan with 0
        test = tp_merge[0, :, :]
        test = test.fillna(0)
        # indexing: start:stop:step
        test[1:test.shape[0], :] = test[0:test.shape[0] - 1, :].values + test[1:test.shape[0], :].values  # down
        test[1:test.shape[0]:2, 1:test.shape[1]:2] = (
            test[0:test.shape[0] - 1:2, 0:test.shape[1] - 1:2].values
            + test[1:test.shape[0]:2, 1:test.shape[1]:2].values
        )  # down + right
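Here is a small, self-contained sketch of the overlay step above (coordinates and values are illustrative): an all-NaN field on a half-step grid is merged with a coarse field, and `compat='no_conflicts'` keeps the coarse values where the coordinates coincide and NaN elsewhere.

import numpy as np
import xarray as xr

fine = xr.Dataset(
    {'tp': (('y', 'x'), np.full((5, 5), np.nan))},
    coords={'y': np.arange(0, 2.5, 0.5), 'x': np.arange(0, 2.5, 0.5)},
)
coarse = xr.Dataset(
    {'tp': (('y', 'x'), np.arange(9.0).reshape(3, 3))},
    coords={'y': np.arange(0, 3.0, 1.0), 'x': np.arange(0, 3.0, 1.0)},
)
overlay = xr.merge([fine, coarse], compat='no_conflicts')['tp']
# overlay keeps the 0.5-step coordinates; values sit on the 1.0-step points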
Example #32
0
    hia_5cod   = calc_hia_gemm_5cod(pm25, pop, dict_ages, dict_bm, dict_gemm)

    hia_ncdlri_list = []
    hia_5cod_list = []

    for name, array in hia_ncdlri.items():
        ds = xr.DataArray(array, dims=('lat', 'lon'), coords={'lat': lat, 'lon': lon}).to_dataset(name=name)
        hia_ncdlri_list.append(ds)


    for name, array in hia_5cod.items():
        ds = xr.DataArray(array, dims=('lat', 'lon'), coords={'lat': lat, 'lon': lon}).to_dataset(name=name)
        hia_5cod_list.append(ds)


    ds_ncdlri = xr.merge(hia_ncdlri_list)
    ds_5cod = xr.merge(hia_5cod_list)

    ds_ncdlri.to_netcdf(results_path + 'hia_ncdlri_' + sim + '.nc')
    ds_5cod.to_netcdf(results_path + 'hia_5cod_' + sim + '.nc')

    gdf = gpd.read_file(data_path + 'gadm28_adm0.shp')
    country_list = list(gdf.ID_0.values)

    df_country_hia_ncdlri = shapefile_hia(hia_ncdlri, 'ncdlri', 'country', data_path + 'gadm28_adm0.shp', results_path + '', lat, lon, region_list=country_list)
    df_country_hia_5cod   = shapefile_hia(hia_5cod, '5cod', 'country', data_path + 'gadm28_adm0.shp', results_path + '', lat, lon, region_list=country_list)

    df_country_hia_ncdlri.to_csv(results_path + 'df_country_hia_ncdlri_' + sim + '.csv')
    df_country_hia_5cod.to_csv(results_path + 'df_country_hia_5cod_' + sim + '.csv')

Example #33
0
    def test_merge_dataarray_unnamed(self):
        data = xr.DataArray([1, 2], dims='x')
        with raises_regex(
                ValueError, 'without providing an explicit name'):
            xr.merge([data])
Example #34
0
def cutout(
    od,
    varList=None,
    YRange=None,
    XRange=None,
    add_Hbdr=False,
    mask_outside=False,
    ZRange=None,
    add_Vbdr=False,
    timeRange=None,
    timeFreq=None,
    sampMethod="snapshot",
    dropAxes=False,
    transformation=False,
    centered="Atlantic",
    chunks=None,
):
    """
    Cutout the original dataset in space and time
    preserving the original grid structure.

    Parameters
    ----------
    od: OceanDataset
        oceandataset to subsample
    varList: 1D array_like, str, or None
        List of variables (strings).
    YRange: 1D array_like, scalar, or None
        Y axis limits (e.g., latitudes).
        If len(YRange)>2, max and min values are used.
    XRange: 1D array_like, scalar, or None
        X axis limits (e.g., longitudes).
        If len(XRange)>2, max and min values are used.
    add_Hbdr: bool, scalar
        If scalar, add and subtract `add_Hbdr` to the horizontal ranges.
        If True, automatically estimate add_Hbdr.
        If False, add_Hbdr is set to zero.
    mask_outside: bool
        If True, set all values in areas outside specified (Y,X)ranges to NaNs.
        (Useful for curvilinear grids).
    ZRange: 1D array_like, scalar, or None
        Z axis limits.
        If len(ZRange)>2, max and min values are used.
    add_Vbdr: bool, scalar
        If scalar, add and subtract `add_Vbdr` to the vertical range.
        If True, automatically estimate add_Vbdr.
        If False, add_Vbdr is set to zero.
    timeRange: 1D array_like, numpy.ScalarType, or None
        Time axis limits.
        If len(timeRange)>2, max and min values are used.
    timeFreq: str or None
        Time frequency.
        Available options are pandas Offset Aliases (e.g., '6H'):
        http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases
    sampMethod: {'snapshot', 'mean'}
        Downsampling method (only if timeFreq is not None).
    dropAxes: 1D array_like, str, or bool
        List of axes to remove from the Grid object
        if only one point is left in the range.
        If True, set dropAxes=od.grid_coords.
        If False, preserve the original grid.
    transformation: str, or bool
        Transformation of the llc grid into a new one in which face
        is no longer a dimension. Default is `False`. If a transformation
        is requested, use `centered` to define how the data will be centered.
    centered: str, or bool
        Default is `Atlantic`; the other option is `Pacific`. This refers
        to which ocean appears centered in the data.

    Returns
    -------
    od: OceanDataset
        Subsampled oceandataset

    Notes
    -----
    If any of the horizontal ranges is not None,
    the horizontal dimensions of the cutout will have
    len(Xp1)>len(X) and len(Yp1)>len(Y)
    even if the original oceandataset had
    len(Xp1)==len(X) or len(Yp1)==len(Y).
    """

    # Checks
    unsupported_dims = ["mooring", "particle", "station"]
    check1 = XRange is not None or YRange is not None
    if check1 and any([dim in unsupported_dims for dim in od._ds.dims]):
        _warnings.warn(
            "\nHorizontal cutout not supported" "for moorings, surveys, and particles",
            stacklevel=2,
        )
        XRange = None
        YRange = None

    _check_instance(
        {
            "od": od,
            "add_Hbdr": add_Hbdr,
            "mask_outside": mask_outside,
            "timeFreq": timeFreq,
        },
        {
            "od": "oceanspy.OceanDataset",
            "add_Hbdr": "(float, int, bool)",
            "mask_outside": "bool",
            "timeFreq": ["type(None)", "str"],
        },
    )
    varList = _check_list_of_string(varList, "varList")
    YRange = _check_range(od, YRange, "YRange")
    XRange = _check_range(od, XRange, "XRange")
    ZRange = _check_range(od, ZRange, "ZRange")
    timeRange = _check_range(od, timeRange, "timeRange")
    sampMethod_list = ["snapshot", "mean"]

    if sampMethod not in sampMethod_list:
        raise ValueError(
            "`sampMethod` [{}] is not supported."
            "\nAvailable options: {}"
            "".format(sampMethod, sampMethod_list)
        )

    if not isinstance(dropAxes, bool):
        dropAxes = _check_list_of_string(dropAxes, "dropAxes")
        axes_warn = [axis for axis in dropAxes if axis not in od.grid_coords]
        if len(axes_warn) != 0:
            _warnings.warn(
                "\n{} are not axes of the oceandataset" "".format(axes_warn),
                stacklevel=2,
            )
            dropAxes = list(set(dropAxes) - set(axes_warn))
        dropAxes = {d: od.grid_coords[d] for d in dropAxes}
    elif dropAxes is True:
        dropAxes = od.grid_coords
        if YRange is None:
            dropAxes.pop("Y", None)
        if XRange is None:
            dropAxes.pop("X", None)
        if ZRange is None:
            dropAxes.pop("Z", None)
        if timeRange is None:
            dropAxes.pop("time", None)
    else:
        dropAxes = {}

    # Message
    print("Cutting out the oceandataset.")

    # Copy
    od = _copy.copy(od)

    # list for coord variables
    co_list = [var for var in od._ds.coords if var not in od._ds.dims]
    # Drop variables
    if varList is not None:
        # Make sure it's a list
        varList = list(varList)
        varList = varList + co_list
        varList = _rename_aliased(od, varList)

        # Compute missing variables
        od = _compute._add_missing_variables(od, varList)
        # Drop useless
        nvarlist = [v for v in od._ds.data_vars if v not in varList]
        od._ds = od._ds.drop_vars(nvarlist)
    else:  # this way, if applicable, llc_transf gets applied to all vars
        varList = [var for var in od._ds.reset_coords().data_vars]

    # Unpack
    ds = od._ds
    periodic = od.grid_periodic

    # ---------------------------
    # Time CUTOUT
    # ---------------------------
    # Initialize time mask
    maskT = _xr.ones_like(ds["time"]).astype("int")

    if timeRange is not None:

        # Use arrays
        timeRange = _np.asarray([_np.min(timeRange), _np.max(timeRange)]).astype(
            ds["time"].dtype
        )

        # Get the closest
        for i, time in enumerate(timeRange):
            if _np.issubdtype(ds["time"].dtype, _np.datetime64):
                diff = _np.fabs(ds["time"].astype("float64") - time.astype("float64"))
            else:
                diff = _np.fabs(ds["time"] - time)
            timeRange[i] = ds["time"].where(diff == diff.min(), drop=True).min().values
        maskT = maskT.where(
            _np.logical_and(ds["time"] >= timeRange[0], ds["time"] <= timeRange[-1]), 0
        )

        # Find time indexes
        maskT = maskT.assign_coords(time=_np.arange(len(maskT["time"])))
        dmaskT = maskT.where(maskT, drop=True)
        dtime = dmaskT["time"].values
        iT = [min(dtime), max(dtime)]
        maskT["time"] = ds["time"]

        # Indexes
        if iT[0] == iT[1]:
            if "time" not in dropAxes:
                if iT[0] > 0:
                    iT[0] = iT[0] - 1
                else:
                    iT[1] = iT[1] + 1
        else:
            dropAxes.pop("time", None)

        # Cutout
        ds = ds.isel(time=slice(iT[0], iT[1] + 1))
        if "time_midp" in ds.dims:
            if "time" in dropAxes:
                if iT[0] == len(ds["time_midp"]):
                    iT[0] = iT[0] - 1
                    iT[1] = iT[1] - 1
                ds = ds.isel(time_midp=slice(iT[0], iT[1] + 1))
            else:
                ds = ds.isel(time_midp=slice(iT[0], iT[1]))

    # ---------------------------
    # Vertical CUTOUT
    # ---------------------------
    # Initialize vertical mask
    maskV = _xr.ones_like(ds["Zp1"])

    if ZRange is not None:
        # Use arrays
        ZRange = _np.asarray([_np.min(ZRange) - add_Vbdr, _np.max(ZRange) + add_Vbdr])
        ZRange = ZRange.astype(ds["Zp1"].dtype)

        # Get the closest
        for i, Z in enumerate(ZRange):
            diff = _np.fabs(ds["Zp1"] - Z)
            ZRange[i] = ds["Zp1"].where(diff == diff.min()).min().values
        maskV = maskV.where(
            _np.logical_and(ds["Zp1"] >= ZRange[0], ds["Zp1"] <= ZRange[-1]), 0
        )

        # Find vertical indexes
        maskV = maskV.assign_coords(Zp1=_np.arange(len(maskV["Zp1"])))
        dmaskV = maskV.where(maskV, drop=True)
        dZp1 = dmaskV["Zp1"].values
        iZ = [_np.min(dZp1), _np.max(dZp1)]
        maskV["Zp1"] = ds["Zp1"]

        # Indexes
        if iZ[0] == iZ[1]:
            if "Z" not in dropAxes:
                if iZ[0] > 0:
                    iZ[0] = iZ[0] - 1
                else:
                    iZ[1] = iZ[1] + 1
        else:
            dropAxes.pop("Z", None)

        # Cutout
        ds = ds.isel(Zp1=slice(iZ[0], iZ[1] + 1))
        if "Z" in dropAxes:
            if iZ[0] == len(ds["Z"]):
                iZ[0] = iZ[0] - 1
                iZ[1] = iZ[1] - 1
            ds = ds.isel(Z=slice(iZ[0], iZ[1] + 1))
        else:
            ds = ds.isel(Z=slice(iZ[0], iZ[1]))

        if len(ds["Zp1"]) == 1:
            if "Zu" in ds.dims and len(ds["Zu"]) > 1:
                ds = ds.sel(Zu=ds["Zp1"].values, method="nearest")
            if "Zl" in ds.dims and len(ds["Zl"]) > 1:
                ds = ds.sel(Zl=ds["Zp1"].values, method="nearest")
        else:
            if "Zu" in ds.dims and len(ds["Zu"]) > 1:
                ds = ds.isel(Zu=slice(iZ[0], iZ[1]))
            if "Zl" in ds.dims and len(ds["Zl"]) > 1:
                ds = ds.isel(Zl=slice(iZ[0], iZ[1]))

    # ---------------------------
    # Horizontal CUTOUT (part I, split into two to avoid repeated code)
    # ---------------------------
    if add_Hbdr is True:
        add_Hbdr = _np.mean(
            [
                _np.fabs(od._ds["XG"].max() - od._ds["XG"].min()),
                _np.fabs(od._ds["YG"].max() - od._ds["YG"].min()),
            ]
        )
        add_Hbdr = add_Hbdr / _np.mean([len(od._ds["X"]), len(od._ds["Y"])])
    elif add_Hbdr is False:
        add_Hbdr = 0

    if add_Vbdr is True:
        add_Vbdr = _np.fabs(od._ds["Zp1"].diff("Zp1")).max().values
    elif add_Vbdr is False:
        add_Vbdr = 0

    # Initialize horizontal mask
    if XRange is not None or YRange is not None:

        maskH, dmaskH, XRange, YRange = get_maskH(
            ds, add_Hbdr, add_Vbdr, XRange, YRange
        )

    if transformation is not False and "face" in ds.dims:
        if XRange is None and YRange is None:
            faces = "all"
        else:
            faces = list(dmaskH["face"].values)  # faces that survive the cutout
        _transf_list = ["arctic_crown"]
        if transformation in _transf_list:
            arg = {
                "ds": ds,
                "varlist": varList,  # vars and grid coords to transform
                "centered": centered,
                "faces": faces,
                "drop": True,  # required to calculate U-V grid points
                "chunks": chunks,
            }
            if transformation == "arctic_crown":
                _transformation = _llc_trans.arctic_crown
            dsnew = _transformation(**arg)
            dsnew = dsnew.set_coords(co_list)
            grid_coords = od.grid_coords
            od._ds = dsnew
            manipulate_coords = {"coordsUVfromG": True}
            new_face_connections = {"face_connections": {None: {None, None}}}
            od = od.set_face_connections(**new_face_connections)
            od = od.manipulate_coords(**manipulate_coords)
            if len(grid_coords["time"]) > 1:
                grid_coords["time"].pop("time_midp", None)
                grid_coords = {"add_midp": True, "grid_coords": grid_coords}
            od = od.set_grid_coords(**grid_coords, overwrite=True)
            od._ds.attrs["OceanSpy_description"] = "Cutout of"
            "simulation, with simple topology (face not a dimension)"
            # Unpack the new dataset without face as dimension
            ds = od._ds
            maskH, dmaskH, XRange, YRange = get_maskH(
                ds, add_Hbdr, add_Vbdr, XRange, YRange
            )
        elif transformation not in _transf_list:
            raise ValueError("transformation not supported")
    elif transformation is False and "face" in ds.dims:
        raise ValueError(
            "Must define a transformation to remove complex "
            "topology of dataset."
        )

    # ---------------------------
    # Horizontal CUTOUT part II (continuation of original code)
    # ---------------------------

    if XRange is not None or YRange is not None:
        dYp1 = dmaskH["Yp1"].values
        dXp1 = dmaskH["Xp1"].values
        iY = [_np.min(dYp1), _np.max(dYp1)]
        iX = [_np.min(dXp1), _np.max(dXp1)]
        maskH["Yp1"] = ds["Yp1"]
        maskH["Xp1"] = ds["Xp1"]

        # Original length
        lenY = len(ds["Yp1"])
        lenX = len(ds["Xp1"])

        # Indexes
        if iY[0] == iY[1]:
            if "Y" not in dropAxes:
                if iY[0] > 0:
                    iY[0] = iY[0] - 1
                else:
                    iY[1] = iY[1] + 1
        else:
            dropAxes.pop("Y", None)

        if iX[0] == iX[1]:
            if "X" not in dropAxes:
                if iX[0] > 0:
                    iX[0] = iX[0] - 1
                else:
                    iX[1] = iX[1] + 1
        else:
            dropAxes.pop("X", None)

        ds = ds.isel(Yp1=slice(iY[0], iY[1] + 1), Xp1=slice(iX[0], iX[1] + 1))

        Xcoords = od._grid.axes["X"].coords
        if "X" in dropAxes:
            if iX[0] == len(ds["X"]):
                iX[0] = iX[0] - 1
                iX[1] = iX[1] - 1
            ds = ds.isel(X=slice(iX[0], iX[1] + 1))
        elif ("outer" in Xcoords and Xcoords["outer"] == "Xp1") or (
            "left" in Xcoords and Xcoords["left"] == "Xp1"
        ):
            ds = ds.isel(X=slice(iX[0], iX[1]))
        elif "right" in Xcoords and Xcoords["right"] == "Xp1":
            ds = ds.isel(X=slice(iX[0] + 1, iX[1] + 1))

        Ycoords = od._grid.axes["Y"].coords
        if "Y" in dropAxes:
            if iY[0] == len(ds["Y"]):
                iY[0] = iY[0] - 1
                iY[1] = iY[1] - 1
            ds = ds.isel(Y=slice(iY[0], iY[1] + 1))
        elif ("outer" in Ycoords and Ycoords["outer"] == "Yp1") or (
            "left" in Ycoords and Ycoords["left"] == "Yp1"
        ):
            ds = ds.isel(Y=slice(iY[0], iY[1]))
        elif "right" in Ycoords and Ycoords["right"] == "Yp1":
            ds = ds.isel(Y=slice(iY[0] + 1, iY[1] + 1))

        # Cut axis can't be periodic
        if (len(ds["Yp1"]) < lenY or "Y" in dropAxes) and "Y" in periodic:
            periodic.remove("Y")
        if (len(ds["Xp1"]) < lenX or "X" in dropAxes) and "X" in periodic:
            periodic.remove("X")

    # ---------------------------
    # Horizontal MASK
    # ---------------------------

    if mask_outside and (YRange is not None or XRange is not None):
        if YRange is not None:
            minY = YRange[0]
            maxY = YRange[1]
        else:
            minY = ds["YG"].min().values
            maxY = ds["YG"].max().values
        if XRange is not None:
            minX = XRange[0]
            maxX = XRange[1]
        else:
            minX = ds["XG"].min().values
            maxX = ds["XG"].max().values

        maskC = _xr.where(
            _np.logical_and(
                _np.logical_and(ds["YC"] >= minY, ds["YC"] <= maxY),
                _np.logical_and(ds["XC"] >= minX, ds["XC"] <= maxX),
            ),
            1,
            0,
        ).persist()
        maskG = _xr.where(
            _np.logical_and(
                _np.logical_and(ds["YG"] >= minY, ds["YG"] <= maxY),
                _np.logical_and(ds["XG"] >= minX, ds["XG"] <= maxX),
            ),
            1,
            0,
        ).persist()
        maskU = _xr.where(
            _np.logical_and(
                _np.logical_and(ds["YU"] >= minY, ds["YU"] <= maxY),
                _np.logical_and(ds["XU"] >= minX, ds["XU"] <= maxX),
            ),
            1,
            0,
        ).persist()
        maskV = _xr.where(
            _np.logical_and(
                _np.logical_and(ds["YV"] >= minY, ds["YV"] <= maxY),
                _np.logical_and(ds["XV"] >= minX, ds["XV"] <= maxX),
            ),
            1,
            0,
        ).persist()

        for var in ds.data_vars:
            if set(["X", "Y"]).issubset(ds[var].dims):
                ds[var] = ds[var].where(maskC, drop=True)
            elif set(["Xp1", "Yp1"]).issubset(ds[var].dims):
                ds[var] = ds[var].where(maskG, drop=True)
            elif set(["Xp1", "Y"]).issubset(ds[var].dims):
                ds[var] = ds[var].where(maskU, drop=True)
            elif set(["X", "Yp1"]).issubset(ds[var].dims):
                ds[var] = ds[var].where(maskV, drop=True)

    # ---------------------------
    # TIME RESAMPLING
    # ---------------------------
    # Resample in time
    if timeFreq:

        # Infer original frequency
        inFreq = _pd.infer_freq(ds.time.values)
        if timeFreq[0].isdigit() and not inFreq[0].isdigit():
            inFreq = "1" + inFreq

        # Same frequency: Skip
        if timeFreq == inFreq:
            _warnings.warn(
                "\nInput time frequency [{}] equals output time frequency [{}]:"
                "\nSkipping time resampling."
                "".format(inFreq, timeFreq),
                stacklevel=2,
            )

        else:

            # Remove time_midp and warn
            vars2drop = [var for var in ds.variables if "time_midp" in ds[var].dims]
            if vars2drop:
                _warnings.warn(
                    "\nTime resampling drops variables"
                    " on `time_midp` dimension."
                    "\nDropped variables: {}."
                    "".format(vars2drop),
                    stacklevel=2,
                )
                ds = ds.drop_vars(vars2drop)

            # Snapshot
            if sampMethod == "snapshot":
                # Find new times
                time2sel = ds["time"].resample(time=timeFreq).first()
                newtime = ds["time"].sel(time=time2sel)

                # Use slice when possible
                inds = [
                    i for i, t in enumerate(ds["time"].values) if t in newtime.values
                ]
                inds_diff = _np.diff(inds)
                if all(inds_diff == inds_diff[0]):
                    ds = ds.isel(time=slice(inds[0], inds[-1] + 1, inds_diff[0]))
                else:
                    attrs = ds.attrs
                    ds = _xr.concat(
                        [ds.sel(time=time) for i, time in enumerate(newtime)],
                        dim="time",
                    )
                    ds.attrs = attrs

            else:
                # Mean
                # Separate time and timeless
                attrs = ds.attrs
                ds_dims = ds.drop_vars(
                    [var for var in ds.variables if var not in ds.dims]
                )
                ds_time = ds.drop_vars(
                    [var for var in ds.variables if "time" not in ds[var].dims]
                )
                ds_timeless = ds.drop_vars(
                    [var for var in ds.variables if "time" in ds[var].dims]
                )

                # Resample
                ds_time = ds_time.resample(time=timeFreq).mean("time")

                # Add all dimensions to ds, and fix attributes
                for dim in ds_time.dims:
                    if dim == "time":
                        ds_time[dim].attrs = ds_dims[dim].attrs
                    else:
                        ds_time[dim] = ds_dims[dim]

                # Merge
                ds = _xr.merge([ds_time, ds_timeless])
                ds.attrs = attrs

    # Update oceandataset
    od._ds = ds

    # Add time midp
    if timeFreq and "time" not in dropAxes:
        od = od.set_grid_coords(
            {**od.grid_coords, "time": {"time": -0.5}}, add_midp=True, overwrite=True
        )

    # Drop axes
    grid_coords = od.grid_coords
    for coord in list(grid_coords):
        if coord in dropAxes:
            grid_coords.pop(coord, None)
    od = od.set_grid_coords(grid_coords, overwrite=True)

    # Cut axis can't be periodic
    od = od.set_grid_periodic(periodic)

    return od
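A hypothetical call to `cutout`, assuming an existing OceanSpy `OceanDataset` named `od`; the variable name, ranges, and frequency below are illustrative only.

od_cut = cutout(
    od,
    varList=["Temp"],       # variables to keep (coordinates are kept automatically)
    XRange=[-72, -69],      # longitude limits
    YRange=[38, 41],        # latitude limits
    ZRange=[0, -500],       # vertical limits
    add_Hbdr=True,          # automatically pad the horizontal ranges
    timeFreq="6H",          # resample to 6-hourly ...
    sampMethod="snapshot",  # ... snapshots
)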
Example #35
0
    def test_merge_dicts_dims(self):
        actual = xr.merge([{'y': ('x', [13])}, {'x': [12]}])
        expected = xr.Dataset({'x': [12], 'y': ('x', [13])})
        assert actual.identical(expected)
Example #36
0
    :param ds_all:
    :return: date_range a string of values to use in the filename of the model
    """
    year_start = ds_all.time.data[0].astype('datetime64[Y]').astype(int) + 1970
    year_end = ds_all.time.data[len(ds_all.time.data) -
                                1].astype('datetime64[Y]').astype(int) + 1970
    month_start = ds_all.time.data[0].astype('datetime64[M]').astype(
        int) % 12 + 1
    month_end = ds_all.time.data[len(ds_all.time.data) - 1].astype(
        'datetime64[M]').astype(int) % 12 + 1
    date_range = str(year_start) + str(month_start).zfill(2) + '_' + str(
        year_end) + str(month_end).zfill(2)
    return date_range


ds_all = xr.merge([xr.open_dataset(f) for f in flist])
print('- Done')

#*******************************************************************************
#Creating output file
#*******************************************************************************
print('Creating output file')

ds_all.to_netcdf(fout)

print('- All monthly files from ' + MODEL + ' combined and saved to: ' + fout)

#*******************************************************************************
#End
#*******************************************************************************
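As a side note, the same monthly files can usually be combined lazily with `open_mfdataset`; a sketch under the assumption that the files in `flist` share compatible coordinates (`flist` and `fout` as above):

ds_all = xr.open_mfdataset(flist, combine='by_coords')
ds_all.to_netcdf(fout)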
Example #37
0
    def test_merge_arrays(self):
        data = create_test_data()
        actual = xr.merge([data.var1, data.var2])
        expected = data[['var1', 'var2']]
        assert actual.identical(expected)
Example #38
0
filepath_initial = rootdir+'experiments/'+model+'/'+experiment+'/ariane_initial.nc'
filepath_time = rootdir+'time/time_orca025_global_5d.mat'
filepath_region = rootdir+'experiments/'+model+'/quant_back_seedNAn1_t3560-sep-4217_sign27.7-28_MLrefz8delsig0.01/region_limits'
# Universal variables
spy = 365*24*60*60
yrst = 1958
yrend = 2016
ventsec = 7
lastinit = 4217

# Ariane input
ds_initial = xr.open_mfdataset(filepath_initial,combine='nested',concat_dim='ntraj')
ds_initial.init_volume.name = 'init_volume'
# Ariane output
ds = xr.open_mfdataset(filepath,combine='nested',concat_dim='ntraj')
ds = xr.merge([ds, ds_initial.init_volume])
ds['final_age'] = ds.final_age.astype('timedelta64[s]').astype('float64')/spy
ds['final_dens'] = calc_sigmantr(ds.final_temp,ds.final_salt)
# Model times
time_vals = np.append(np.array([0]),sio.loadmat(filepath_time)['time'].squeeze())
time = xr.DataArray(time_vals,dims=['nfile'],coords={'nfile':np.arange(time_vals.size)})
# Region limits
region_limits = np.loadtxt(filepath_region)

# Bins
years = np.arange(yrst,yrend+1)
ages = np.arange(-3/12,yrend-yrst+9/12)
densities = np.arange(27.7,28,0.01)
init_t_unique = np.unique(ds.init_t)
inits = np.append(init_t_unique-0.5,init_t_unique[-1]+0.5)
xs = np.arange(region_limits[0,0],region_limits[0,1])
Example #39
0
    def duplicate_and_merge(array):
        return xr.merge([array, array.rename("bar")]).to_array()
Example #40
0
    if pressure_adjust:
        ds = get_pressure_coord_fields(case,
                                       varlist,
                                       from_time,
                                       to_time,
                                       history_fld,
                                       model=model)
        return ds
    else:
        if varlist is not None:
            fl = []
            vl_lacking = []
            for var in varlist:
                fn = get_filename_ng_field(var, model, case, from_time, to_time)
                if os.path.isfile(fn):
                    fl.append(fn)
                else:
                    vl_lacking.append(var)
        else:
            fl = []
            vl_lacking = varlist

        ds = xr_import_NorESM(case, vl_lacking, from_time, to_time, path=raw_data_path,
                              model=model,
                              history_fld=history_fld,
                              comp=comp, chunks=chunks)
        ds = xr_fix(ds, model_name=model)
        if len(fl)>0:
            ds_f_file = xr.open_mfdataset(fl, combine='by_coords')
            ds = xr.merge([ds, ds_f_file])
        return ds
Example #41
0
def run(params):
    print(params)
    start_time = datetime.now()

    bin_width, filter_bandwidth, theta, shift, \
        signal_field, noise_field, noise_multiplier = params

    # Get file path
    signal_dir = '/scratch/pkittiwi/fg1p/signal_map/bin{:.2f}/' \
                 'fbw{:.2f}/theta{:.1f}/shift{:d}' \
        .format(bin_width, filter_bandwidth, theta, shift)
    noise_dir = '/scratch/pkittiwi/fg1p/noise_map/bin{:.2f}/' \
                'fbw{:.2f}/theta{:.1f}/shift{:d}' \
        .format(bin_width, filter_bandwidth, theta, shift)
    output_dir = '/scratch/pkittiwi/fg1p/obs_map/obsn{:.1f}/bin{:.2f}/' \
                 'fbw{:.2f}/theta{:.1f}/shift{:d}/s{:03d}' \
        .format(noise_multiplier, bin_width, filter_bandwidth, theta,
                shift, signal_field)
    signal_file = '{:s}/signal_map_bin{:.2f}_fbw{:.2f}_' \
                  'theta{:.1f}_shift{:d}_{:03d}.nc' \
        .format(signal_dir, bin_width, filter_bandwidth,
                theta, shift, signal_field)
    noise_file = '{:s}/noise_map_bin{:.2f}_fbw{:.2f}_' \
                 'theta{:.1f}_shift{:d}_{:03d}.nc' \
        .format(noise_dir, bin_width, filter_bandwidth,
                theta, shift, noise_field)
    output_file = '{:s}/obs_map_obsn{:.1f}_bin{:.2f}_fbw{:.2f}_' \
                  'theta{:.1f}_shift{:d}_{:03d}_{:03d}.nc' \
        .format(output_dir, noise_multiplier, bin_width, filter_bandwidth,
                theta, shift, signal_field, noise_field)

    # Load data
    with xr.open_dataarray(signal_file) as ds:
        signal = ds.load()
    with xr.open_dataarray(noise_file) as ds:
        noise = ds.load()
    with xr.open_dataarray('/scratch/pkittiwi/fg1p/hera331_fov_mask.nc') as ds:
        mask = ds.load()

    # Align coordinates - they must match for XArray broadcasting
    for key in ['x', 'y', 'f']:
        signal.coords[key] = noise.coords[key].values
        mask.coords[key] = noise.coords[key].values
    signal, noise, mask = xr.align(signal, noise, mask)

    # Make observation
    signal = signal.where(mask == 1)
    noise = noise.where(mask == 1) * noise_multiplier
    obs = signal + noise
    obs.name = 'obs'
    obs.attrs = {'signal_field': signal_field, 'noise_field': noise_field,
                 'noise_multiplier': noise_multiplier, 'bin_width': bin_width,
                 'filter_bandwidth': filter_bandwidth, 'theta': theta,
                 'shift': shift}

    # Calculate noise variance
    noise_var = noise.var(dim=['y', 'x'])
    noise_var.name = 'noise_var'
    noise_var.attrs = {
        'noise_field': noise_field, 'noise_multiplier': noise_multiplier,
        'bin_width': bin_width, 'filter_bandwidth': filter_bandwidth,
        'theta': theta, 'shift': shift
    }

    # Save output
    out = xr.merge([obs, noise_var])
    os.makedirs(output_dir, exist_ok=True)
    out.to_netcdf(output_file)

    print('Finish {:s}. Time spent: {:.5f} minutes'
          .format(output_file,
                  (datetime.now() - start_time).total_seconds() / 60))

    return 0
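A small, self-contained sketch of the align-and-mask step used in `run` above (data values are illustrative): after `xr.align`, `where(mask == 1)` turns everything outside the mask into NaN.

import numpy as np
import xarray as xr

signal = xr.DataArray(np.arange(4.0).reshape(2, 2), dims=('y', 'x'),
                      coords={'y': [0, 1], 'x': [0, 1]}, name='signal')
mask = xr.DataArray(np.array([[1, 0], [1, 1]]), dims=('y', 'x'),
                    coords={'y': [0, 1], 'x': [0, 1]}, name='mask')
signal, mask = xr.align(signal, mask)  # indexes must match before masking
masked = signal.where(mask == 1)       # outside the mask -> NaN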
Example #42
0
def resample_in_time(cube: xr.Dataset,
                     frequency: str,
                     method: Union[str, Sequence[str]],
                     offset=None,
                     base: int = 0,
                     tolerance=None,
                     interp_kind=None,
                     time_chunk_size=None,
                     var_names: Sequence[str] = None,
                     metadata: Dict[str, Any] = None,
                     cube_asserted: bool = False) -> xr.Dataset:
    """
    Resample a xcube dataset in the time dimension.

    The argument *method* may be one or a sequence of ``'all'``, ``'any'``,
    ``'argmax'``, ``'argmin'``, ``'count'``,
    ``'first'``, ``'last'``, ``'max'``, ``'min'``, ``'mean'``, ``'median'``,
    ``'percentile_<p>'``, ``'std'``, ``'sum'``, ``'var'``.

    In ``'percentile_<p>'``, the ``'<p>'`` is a placeholder that must be replaced by an
    integer percentage value, e.g. ``'percentile_90'`` is the 90%-percentile.

    *Important note:* As of xarray 0.14 and dask 2.8, the methods ``'median'`` and ``'percentile_<p>'``
    cannot be used if the variables in *cube* comprise chunked dask arrays.
    In this case, use the ``compute()`` or ``load()`` method to convert dask arrays into numpy arrays.

    :param cube: The xcube dataset.
    :param frequency: Temporal aggregation frequency. Use format "<count><offset>"
        where <offset> is one of 'H', 'D', 'W', 'M', 'Q', 'Y'.
    :param method: Resampling method or sequence of resampling methods.
    :param offset: Offset used to adjust the resampled time labels.
        Uses same syntax as *frequency*.
    :param base: For frequencies that evenly subdivide 1 day, the "origin" of the
        aggregated intervals. For example, for '24H' frequency, base could range from 0 through 23.
    :param time_chunk_size: If not None, the chunk size to be used for the "time" dimension.
    :param var_names: Variable names to include.
    :param tolerance: Time tolerance for selective upsampling methods. Defaults to *frequency*.
    :param interp_kind: Kind of interpolation if *method* is 'interpolation'.
    :param metadata: Output metadata.
    :param cube_asserted: If False, *cube* will be verified, otherwise it is expected to be a valid cube.
    :return: A new xcube dataset resampled in time.
    """
    if not cube_asserted:
        assert_cube(cube)

    if frequency == 'all':
        time_gap = np.array(cube.time[-1]) - np.array(cube.time[0])
        days = int((np.timedelta64(time_gap, 'D') / np.timedelta64(1, 'D')) +
                   1)
        frequency = f'{days}D'

    if var_names:
        cube = select_variables_subset(cube, var_names)

    resampler = cube.resample(skipna=True,
                              closed='left',
                              label='left',
                              keep_attrs=True,
                              time=frequency,
                              loffset=offset,
                              base=base)

    if isinstance(method, str):
        methods = [method]
    else:
        methods = list(method)

    percentile_prefix = 'percentile_'

    resampled_cubes = []
    for method in methods:
        method_args = []
        method_postfix = method
        if method.startswith(percentile_prefix):
            p = int(method[len(percentile_prefix):])
            q = p / 100.0
            method_args = [q]
            method_postfix = f'p{p}'
            method = 'quantile'
        resampling_method = getattr(resampler, method)
        method_kwargs = get_method_kwargs(method, frequency, interp_kind,
                                          tolerance)
        resampled_cube = resampling_method(*method_args, **method_kwargs)
        resampled_cube = resampled_cube.rename({
            var_name: f'{var_name}_{method_postfix}'
            for var_name in resampled_cube.data_vars
        })
        resampled_cubes.append(resampled_cube)

    if len(resampled_cubes) == 1:
        resampled_cube = resampled_cubes[0]
    else:
        resampled_cube = xr.merge(resampled_cubes)

    # TODO: add time_bnds to resampled_ds
    time_coverage_start = '%s' % cube.time[0]
    time_coverage_end = '%s' % cube.time[-1]

    resampled_cube.attrs.update(metadata or {})
    # TODO: add other time_coverage_ attributes
    resampled_cube.attrs.update(time_coverage_start=time_coverage_start,
                                time_coverage_end=time_coverage_end)

    schema = CubeSchema.new(cube)
    chunk_sizes = {
        schema.dims[i]: schema.chunks[i]
        for i in range(schema.ndim)
    }

    if isinstance(time_chunk_size, int) and time_chunk_size >= 0:
        chunk_sizes['time'] = time_chunk_size

    return resampled_cube.chunk(chunk_sizes)
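A hypothetical call to `resample_in_time`, assuming an xcube dataset `cube` with a `time` dimension (the frequency and methods are illustrative): the result contains one variable per input variable and method, e.g. `<var>_mean` and `<var>_p90`.

monthly = resample_in_time(cube, frequency='1M', method=['mean', 'percentile_90'])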
Example #43
0
    def test_merge_alignment_error(self):
        ds = xr.Dataset(coords={'x': [1, 2]})
        other = xr.Dataset(coords={'x': [2, 3]})
        with raises_regex(ValueError, 'indexes .* not equal'):
            xr.merge([ds, other], join='exact')
Example #44
0
We'll use :func:`harmonica.bouguer_correction` to calculate a topography-free gravity
disturbance for Earth using our sample gravity and topography data. One thing to note is
that the ETOPO1 topography is referenced to the geoid, not the ellipsoid. Since we want
to remove the masses between the surface of the Earth and ellipsoid, we need to add the
geoid height to the topography before Bouguer correction.
"""
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import xarray as xr
import harmonica as hm

# Load the global gravity, topography, and geoid grids
data = xr.merge([
    hm.datasets.fetch_gravity_earth(),
    hm.datasets.fetch_geoid_earth(),
    hm.datasets.fetch_topography_earth(),
])
print(data)

# Calculate normal gravity and the disturbance
gamma = hm.normal_gravity(data.latitude, data.height_over_ell)
disturbance = data.gravity - gamma

# Reference the topography to the ellipsoid
topography_ell = data.topography + data.geoid

# Calculate the Bouguer planar correction and the topography-free disturbance. Use the
# default densities for the crust and ocean water.
bouguer = hm.bouguer_correction(topography_ell)
disturbance_topofree = disturbance - bouguer
Example #45
0
    def test_merge_error(self):
        ds = xr.Dataset({'x': 0})
        with self.assertRaises(xr.MergeError):
            xr.merge([ds, ds + 1])
Example #46
0
def convert_ms(infile,
               outfile=None,
               ddi=None,
               compressor=None,
               chunk_shape=(100, 400, 20, 1),
               nofile=False):
    """
    Convert legacy format MS to xarray Visibility Dataset and zarr storage format

    This function requires CASA6 casatools module.

    Parameters
    ----------
    infile : str
        Input MS filename
    outfile : str
        Output zarr filename. If None, will use infile name with .vis.zarr extension
    ddi : int
        Specific ddi to convert. Leave as None to convert entire MS
    compressor : numcodecs.blosc.Blosc
        The blosc compressor to use when saving the converted data to disk using zarr.
        If None, the zstd compression algorithm is used with compression level 2.
    chunk_shape: 4-D tuple of ints
        Shape of desired chunking in the form of (time, baseline, channel, polarization), use -1 for entire axis in one chunk. Default is (100, 400, 20, 1)
        Note: chunk size is the product of the four numbers, and data is batch processed by time axis, so that will drive memory needed for conversion.
    nofile : bool
        Allows legacy MS to be directly read without file conversion. If set to true, no output file will be written and entire MS will be held in memory.
        Requires ~4x the memory of the MS size.  Default is False
    Returns
    -------
    list of xarray.core.dataset.Dataset
      List of new xarray Datasets of Visibility data contents. One element in list per DDI plus the metadata global.
    """
    import os
    from casatools import table as tb
    from numcodecs import Blosc
    import pandas as pd
    import xarray
    import numpy as np
    import time
    from itertools import cycle
    import warnings
    warnings.filterwarnings('ignore', category=FutureWarning)

    if compressor is None:
        compressor = Blosc(cname='zstd', clevel=2, shuffle=0)

    # parse filename to use
    infile = os.path.expanduser(infile)
    prefix = infile[:infile.rindex('.')]
    if outfile is None:
        outfile = prefix + '.vis.zarr'
    else:
        outfile = os.path.expanduser(outfile)

    # need to manually remove existing zarr file (if any)
    print('processing %s ' % infile)
    if not nofile:
        os.system("rm -fr " + outfile)
        os.system("mkdir " + outfile)
    start = time.time()

    # let's assume that each DATA_DESC_ID (ddi) is a fixed shape that may differ from others
    # form a list of ddis to process, each will be placed it in its own xarray dataset and partition
    ddis = [ddi]
    if ddi is None:
        MS = tb(infile)
        MS.open(infile, nomodify=True, lockoptions={'option': 'usernoread'})
        ddis = MS.taql('select distinct DATA_DESC_ID from %s' % prefix +
                       '.ms').getcol('DATA_DESC_ID')
        MS.close()

    if compressor is None:
        compressor = Blosc(cname='zstd', clevel=2, shuffle=0)

    # initialize list of xarray datasets to be returned by this function
    xds_list = []

    # helper for normalizing variable column shapes
    def apad(arr, ts):
        return np.pad(arr,
                      (0, ts[0] - arr.shape[0])) if arr.ndim == 1 else np.pad(
                          arr, ((0, ts[0] - arr.shape[0]),
                                (0, ts[1] - arr.shape[1])))

    # helper for reading time columns to datetime format
    # pandas datetimes are referenced against a 0 of 1970-01-01
    # CASA's modified julian day reference time is (of course) 1858-11-17
    # this requires a correction of 3506716800 seconds which is hardcoded to save time
    def convert_time(rawtimes):
        # correction = (pd.to_datetime(0, unit='s') - pd.to_datetime('1858-11-17', format='%Y-%m-%d')).total_seconds()
        correction = 3506716800.0
        # return rawtimes
        return pd.to_datetime(np.array(rawtimes) - correction, unit='s').values

    ############################################
    # build combined metadata xarray dataset from each table in the ms directory (other than main)
    # - we want as much as possible to be stored as data_vars with appropriate coordinates
    # - whenever possible, meaningless id fields are replaced with string names as the coordinate index
    # - some things that are too variably structured will have to go in attributes
    # - this pretty much needs to be done individually for each table, some generalization is possible but it makes things too complex
    ############################################
    mvars, mcoords, mattrs = {}, {}, {}
    tables = [
        'DATA_DESCRIPTION', 'SPECTRAL_WINDOW', 'POLARIZATION', 'SORTED_TABLE'
    ]  # initialize to things we don't want to process now
    ms_meta = tb()

    ## ANTENNA table
    tables += ['ANTENNA']
    print('processing support table %s' % tables[-1], end='\r')
    if os.path.isdir(os.path.join(infile, tables[-1])):
        ms_meta.open(os.path.join(infile, tables[-1]),
                     nomodify=True,
                     lockoptions={'option': 'usernoread'})
        mcoords['antenna'] = list(range(ms_meta.nrows()))
        for col in ms_meta.colnames():
            if not ms_meta.iscelldefined(col, 0): continue
            data = ms_meta.getcol(col).transpose()
            if data.ndim == 1:
                mvars['ANT_' + col] = xarray.DataArray(data, dims=['antenna'])
            else:
                mvars['ANT_' + col] = xarray.DataArray(
                    data, dims=['antenna', 'd' + str(data.shape[1])])
        ms_meta.close()

    ## FEED table
    tables += ['FEED']
    print('processing support table %s' % tables[-1], end='\r')
    if os.path.isdir(os.path.join(infile, tables[-1])):
        ms_meta.open(os.path.join(infile, tables[-1]),
                     nomodify=True,
                     lockoptions={'option': 'usernoread'})
        if ms_meta.nrows() > 0:
            mcoords['spw'] = np.arange(
                np.max(ms_meta.getcol('SPECTRAL_WINDOW_ID')) + 1)
            mcoords['feed'] = np.arange(np.max(ms_meta.getcol('FEED_ID')) + 1)
            mcoords['receptors'] = np.arange(
                np.max(ms_meta.getcol('NUM_RECEPTORS')) + 1)
            antidx, spwidx, feedidx = ms_meta.getcol(
                'ANTENNA_ID'), ms_meta.getcol(
                    'SPECTRAL_WINDOW_ID'), ms_meta.getcol('FEED_ID')
            if ms_meta.nrows() != (len(np.unique(antidx)) * len(
                    np.unique(spwidx)) * len(np.unique(feedidx))):
                print('WARNING: index mismatch in %s table' % tables[-1])
            for col in ms_meta.colnames():
                if not ms_meta.iscelldefined(col, 0): continue
                if col in ['SPECTRAL_WINDOW_ID', 'ANTENNA_ID', 'FEED_ID']:
                    continue
                if ms_meta.isvarcol(col):
                    tshape, tdim = (len(
                        mcoords['receptors']), ), ('receptors', )
                    if col == 'BEAM_OFFSET':
                        tshape, tdim = (2, len(
                            mcoords['receptors'])), ('d2', 'receptors')
                    elif col == 'POL_RESPONSE':
                        tshape, tdim = (len(
                            mcoords['receptors']), len(
                                mcoords['receptors'])), ('receptors',
                                                         'receptors')
                    data = ms_meta.getvarcol(col)
                    data = np.array([
                        apad(data['r' + str(kk)][..., 0], tshape)
                        for kk in np.arange(len(data)) + 1
                    ])
                    metadata = np.full(
                        (len(mcoords['spw']), len(mcoords['antenna']),
                         len(mcoords['feed'])) + tshape,
                        np.nan,
                        dtype=data.dtype)
                    metadata[spwidx, antidx, feedidx] = data
                    mvars['FEED_' + col] = xarray.DataArray(
                        metadata, dims=['spw', 'antenna', 'feed'] + list(tdim))
                else:
                    data = ms_meta.getcol(col).transpose()
                    if col == 'TIME': data = convert_time(data)
                    if data.ndim == 1:
                        metadata = np.full(
                            (len(mcoords['spw']), len(
                                mcoords['antenna']), len(mcoords['feed'])),
                            np.nan,
                            dtype=data.dtype)
                        metadata[spwidx, antidx, feedidx] = data
                        mvars['FEED_' + col] = xarray.DataArray(
                            metadata, dims=['spw', 'antenna', 'feed'])
                    else:  # only POSITION should trigger this
                        metadata = np.full(
                            (len(mcoords['spw']), len(mcoords['antenna']),
                             len(mcoords['feed']), data.shape[1]),
                            np.nan,
                            dtype=data.dtype)
                        metadata[spwidx, antidx, feedidx] = data
                        mvars['FEED_' + col] = xarray.DataArray(
                            metadata,
                            dims=[
                                'spw', 'antenna', 'feed',
                                'd' + str(data.shape[1])
                            ])
        ms_meta.close()

    ## FIELD table
    tables += ['FIELD']
    print('processing support table %s' % tables[-1], end='\r')
    if os.path.isdir(os.path.join(infile, tables[-1])):
        ms_meta.open(os.path.join(infile, tables[-1]),
                     nomodify=True,
                     lockoptions={'option': 'usernoread'})
        if ms_meta.nrows() > 0:
            funique, fidx, fcount = np.unique(ms_meta.getcol('NAME'),
                                              return_inverse=True,
                                              return_counts=True)
            mcoords['field'] = [
                funique[ii] if fcount[ii] == 1 else funique[ii] +
                ' (%s)' % str(nn) for nn, ii in enumerate(fidx)
            ]
            max_poly = np.max(
                ms_meta.taql('select distinct NUM_POLY from %s' % os.path.join(
                    infile, tables[-1])).getcol('NUM_POLY')) + 1
            tshape = (2, max_poly)
            for col in ms_meta.colnames():
                if col in ['NAME']: continue
                if not ms_meta.iscelldefined(col, 0): continue
                if ms_meta.isvarcol(col):
                    data = ms_meta.getvarcol(col)
                    data = np.array([
                        apad(data['r' + str(kk)][..., 0], tshape)
                        for kk in np.arange(len(data)) + 1
                    ])
                    mvars['FIELD_' + col] = xarray.DataArray(
                        data, dims=['field', 'd2', 'd' + str(max_poly)])
                else:
                    data = ms_meta.getcol(col).transpose()
                    if col == 'TIME': data = convert_time(data)
                    mvars['FIELD_' + col] = xarray.DataArray(data,
                                                             dims=['field'])
        ms_meta.close()

    ## FLAG_CMD table
    tables += ['FLAG_CMD']
    print('processing support table %s' % tables[-1], end='\r')
    if os.path.isdir(os.path.join(infile, tables[-1])):
        ms_meta.open(os.path.join(infile, tables[-1]),
                     nomodify=True,
                     lockoptions={'option': 'usernoread'})
        if ms_meta.nrows() > 0:
            mcoords['time_fcmd'], timeidx = np.unique(convert_time(
                ms_meta.getcol('TIME')),
                                                      return_inverse=True)
            for col in ms_meta.colnames():
                if not ms_meta.iscelldefined(col, 0): continue
                if col in ['TIME']: continue
                data = ms_meta.getcol(col).transpose()
                metadata = np.full((len(mcoords['time_fcmd'])),
                                   np.nan,
                                   dtype=data.dtype)
                metadata[timeidx] = data
                mvars['FCMD_' + col] = xarray.DataArray(metadata,
                                                        dims=['time_fcmd'])
        ms_meta.close()

    ## HISTORY table
    tables += ['HISTORY']
    print('processing support table %s' % tables[-1], end='\r')
    if os.path.isdir(os.path.join(infile, tables[-1])):
        ms_meta.open(os.path.join(infile, tables[-1]),
                     nomodify=True,
                     lockoptions={'option': 'usernoread'})
        if ms_meta.nrows() > 0:
            mcoords['time_hist'], timeidx = np.unique(convert_time(
                ms_meta.getcol('TIME')),
                                                      return_inverse=True)
            for col in ms_meta.colnames():
                if not ms_meta.iscelldefined(col, 0): continue
                if col in ['TIME', 'CLI_COMMAND', 'APP_PARAMS']:
                    continue  # cli_command and app_params are var cols that wont work
                data = ms_meta.getcol(col).transpose()
                metadata = np.full((len(mcoords['time_hist'])),
                                   np.nan,
                                   dtype=data.dtype)
                metadata[timeidx] = data
                mvars['HIST_' + col] = xarray.DataArray(metadata,
                                                        dims=['time_hist'])
        ms_meta.close()

    ## OBSERVATION table
    tables += ['OBSERVATION']
    print('processing support table %s' % tables[-1], end='\r')
    if os.path.isdir(os.path.join(infile, tables[-1])):
        ms_meta.open(os.path.join(infile, tables[-1]),
                     nomodify=True,
                     lockoptions={'option': 'usernoread'})
        if ms_meta.nrows() > 0:
            funique, fidx, fcount = np.unique(ms_meta.getcol('PROJECT'),
                                              return_inverse=True,
                                              return_counts=True)
            mcoords['observation'] = [
                funique[ii] if fcount[ii] == 1 else funique[ii] +
                ' (%s)' % str(nn) for nn, ii in enumerate(fidx)
            ]
            for col in ms_meta.colnames():
                if not ms_meta.iscelldefined(col, 0): continue
                if col in ['PROJECT', 'LOG', 'SCHEDULE']:
                    continue  # log and schedule are var cols that wont work
                data = ms_meta.getcol(col).transpose()
                if col == 'TIME_RANGE':
                    data = np.hstack((convert_time(data[:, 0])[:, None],
                                      convert_time(data[:, 1])[:, None]))
                    mvars['OBS_' + col] = xarray.DataArray(
                        data, dims=['observation', 'd2'])
                else:
                    mvars['OBS_' + col] = xarray.DataArray(
                        data, dims=['observation'])
        ms_meta.close()

    ## POINTING table
    tables += ['POINTING']
    print('processing support table %s' % tables[-1], end='\r')
    if os.path.isdir(os.path.join(infile, tables[-1])):
        ms_meta.open(os.path.join(infile, tables[-1]),
                     nomodify=True,
                     lockoptions={'option': 'usernoread'})
        if ms_meta.nrows() > 0:
            mcoords['time_point'], timeidx = np.unique(convert_time(
                ms_meta.getcol('TIME')),
                                                       return_inverse=True)
            antidx = ms_meta.getcol('ANTENNA_ID')
            for col in ms_meta.colnames():
                if col in ['TIME', 'ANTENNA_ID']: continue
                if not ms_meta.iscelldefined(col, 0): continue
                try:  # can't use getvarcol as it dies on large tables like this
                    data = ms_meta.getcol(col).transpose()
                    if data.ndim == 1:
                        metadata = np.full((len(
                            mcoords['time_point']), len(mcoords['antenna'])),
                                           np.nan,
                                           dtype=data.dtype)
                        metadata[timeidx, antidx] = data
                        mvars['POINT_' + col] = xarray.DataArray(
                            metadata, dims=['time_point', 'antenna'])
                    if data.ndim > 1:
                        metadata = np.full(
                            (len(mcoords['time_point']), len(
                                mcoords['antenna'])) + data.shape[1:],
                            np.nan,
                            dtype=data.dtype)
                        metadata[timeidx, antidx] = data
                        mvars['POINT_' + col] = xarray.DataArray(
                            metadata,
                            dims=['time_point', 'antenna'] +
                            ['d' + str(ii) for ii in data.shape[1:]])
                except Exception:
                    print('WARNING : unable to process col %s of table %s' %
                          (col, tables[-1]))
        ms_meta.close()

    ## PROCESSOR table
    tables += ['PROCESSOR']
    print('processing support table %s' % tables[-1], end='\r')
    if os.path.isdir(os.path.join(infile, tables[-1])):
        ms_meta.open(os.path.join(infile, tables[-1]),
                     nomodify=True,
                     lockoptions={'option': 'usernoread'})
        if ms_meta.nrows() > 0:
            funique, fidx, fcount = np.unique(ms_meta.getcol('TYPE'),
                                              return_inverse=True,
                                              return_counts=True)
            mcoords['processor'] = [
                funique[ii] if fcount[ii] == 1 else funique[ii] +
                ' (%s)' % str(nn) for nn, ii in enumerate(fidx)
            ]
            for col in ms_meta.colnames():
                if not ms_meta.iscelldefined(col, 0): continue
                if col in ['TYPE']: continue
                if not ms_meta.isvarcol(col):
                    data = ms_meta.getcol(col).transpose()
                    mvars['PROC_' + col] = xarray.DataArray(data,
                                                            dims=['processor'])
        ms_meta.close()

    ## SOURCE table
    tables += ['SOURCE']
    print('processing support table %s' % tables[-1], end='\r')
    if os.path.isdir(os.path.join(infile, tables[-1])):
        ms_meta.open(os.path.join(infile, tables[-1]),
                     nomodify=True,
                     lockoptions={'option': 'usernoread'})
        if ms_meta.nrows() > 0:
            mcoords['source'] = np.unique(ms_meta.getcol('SOURCE_ID'))
            max_lines = np.max(
                ms_meta.taql(
                    'select distinct NUM_LINES from %s' %
                    os.path.join(infile, tables[-1])).getcol('NUM_LINES'))
            srcidx, spwidx = ms_meta.getcol('SOURCE_ID'), ms_meta.getcol(
                'SPECTRAL_WINDOW_ID')
            tshape = (2, max_lines)
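            # variable-shaped per-line columns are padded out to tshape and scattered onto a
            # (spw, source) grid using the SPECTRAL_WINDOW_ID / SOURCE_ID row indices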
            for col in ms_meta.colnames():
                try:
                    if col in ['SOURCE_ID', 'SPECTRAL_WINDOW_ID']: continue
                    if not ms_meta.iscelldefined(col, 0): continue
                    if ms_meta.isvarcol(col) and (tshape[1] > 0) and (
                            col
                            not in ['POSITION', 'SOURCE_MODEL', 'PULSAR_ID']):
                        data = ms_meta.getvarcol(col)
                        data = np.array([
                            apad(data['r' + str(kk)][..., 0], tshape)
                            for kk in np.arange(len(data)) + 1
                        ])
                        metadata = np.full(
                            (len(mcoords['spw']), len(mcoords['source'])) +
                            tshape,
                            np.nan,
                            dtype=data.dtype)
                        metadata[spwidx, srcidx] = data
                        mvars['SRC_' + col] = xarray.DataArray(
                            metadata,
                            dims=['spw', 'source'] +
                            ['d' + str(ii) for ii in tshape])
                    else:
                        data = ms_meta.getcol(col).transpose()
                        if col == 'TIME': data = convert_time(data)
                        if data.ndim == 1:
                            metadata = np.full(
                                (len(mcoords['spw']), len(mcoords['source'])),
                                np.nan,
                                dtype=data.dtype)
                            metadata[spwidx, srcidx] = data
                            mvars['SRC_' + col] = xarray.DataArray(
                                metadata, dims=['spw', 'source'])
                        else:
                            metadata = np.full(
                                (len(mcoords['spw']), len(
                                    mcoords['source']), data.shape[1]),
                                np.nan,
                                dtype=data.dtype)
                            metadata[spwidx, srcidx] = data
                            mvars['SRC_' + col] = xarray.DataArray(
                                metadata,
                                dims=[
                                    'spw', 'source', 'd' + str(data.shape[1])
                                ])
                except Exception:
                    print('WARNING : unable to process col %s of table %s' %
                          (col, tables[-1]))
        ms_meta.close()

    ## STATE table
    tables += ['STATE']
    print('processing support table %s' % tables[-1], end='\r')
    if os.path.isdir(os.path.join(infile, tables[-1])):
        ms_meta.open(os.path.join(infile, tables[-1]),
                     nomodify=True,
                     lockoptions={'option': 'usernoread'})
        if ms_meta.nrows() > 0:
            funique, fidx, fcount = np.unique(ms_meta.getcol('OBS_MODE'),
                                              return_inverse=True,
                                              return_counts=True)
            mcoords['state'] = [
                funique[ii] if fcount[ii] == 1 else funique[ii] +
                ' (%s)' % str(nn) for nn, ii in enumerate(fidx)
            ]
            for col in ms_meta.colnames():
                if not ms_meta.iscelldefined(col, 0): continue
                if col in ['OBS_MODE']: continue
                if not ms_meta.isvarcol(col):
                    data = ms_meta.getcol(col).transpose()
                    mvars['STATE_' + col] = xarray.DataArray(data,
                                                             dims=['state'])
        ms_meta.close()

    # remaining support tables are flattened into the attributes section
    other_tables = [
        tt for tt in os.listdir(infile)
        if os.path.isdir(os.path.join(infile, tt)) and tt not in tables
    ]
    other_tables = dict([(tt, tt[:4] + '_') for tt in other_tables])
    for ii, tt in enumerate(other_tables.keys()):
        print('processing support table %s of %s : %s' %
              (str(ii), str(len(other_tables.keys())), tt),
              end='\r')
        ms_meta.open(os.path.join(infile, tt),
                     nomodify=True,
                     lockoptions={'option': 'usernoread'})
        if ms_meta.nrows() == 0: continue
        for col in ms_meta.colnames():
            if not ms_meta.iscelldefined(col, 0): continue
            if ms_meta.isvarcol(col):
                data = ms_meta.getvarcol(col)
                data = [
                    data['r' + str(kk)].tolist()
                    if not isinstance(data['r' + str(kk)], bool) else []
                    for kk in np.arange(len(data)) + 1
                ]
                mattrs[other_tables[tt] + col] = data
            else:
                data = ms_meta.getcol(col).transpose()
                mattrs[other_tables[tt] + col] = data.tolist()
        ms_meta.close()

    # write the global meta data to a separate global partition in the zarr output directory
    mxds = xarray.Dataset(mvars, coords=mcoords, attrs=mattrs)
    if not nofile:
        print('writing global partition')
        mxds.to_zarr(outfile + '/global', mode='w')

    xds_list += [mxds]  # first item returned is always the global metadata
    print('meta data processing time ', time.time() - start)

    ####################################################################
    # process each selected DDI from the input MS, assume a fixed shape within the ddi (should always be true)
    # each DDI is written to its own subdirectory under the parent folder
    for ddi in ddis:
        print('**********************************')
        print('Processing ddi', ddi)
        start_ddi = time.time()

        # Open the measurement set (ms), select this ddi, and sort the main table by TIME, ANTENNA1, ANTENNA2
        tb_tool = tb()
        tb_tool.open(infile,
                     nomodify=True,
                     lockoptions={'option':
                                  'usernoread'})  # allow concurrent reads
        ms_ddi = tb_tool.taql(
            'select * from %s where DATA_DESC_ID = %s ORDERBY TIME,ANTENNA1,ANTENNA2'
            % (infile, str(ddi)))
        print('Selecting and sorting time ', time.time() - start_ddi)
        start_ddi = time.time()

        tdata = ms_ddi.getcol('TIME')
        times = convert_time(tdata)
        unique_times, time_changes, time_idxs = np.unique(times,
                                                          return_index=True,
                                                          return_inverse=True)
        n_time = unique_times.shape[0]

        ant1_col = np.array(ms_ddi.getcol('ANTENNA1'))
        ant2_col = np.array(ms_ddi.getcol('ANTENNA2'))
        ant1_ant2 = np.hstack((ant1_col[:, np.newaxis], ant2_col[:,
                                                                 np.newaxis]))
        unique_baselines, baseline_idxs = np.unique(ant1_ant2,
                                                    axis=0,
                                                    return_inverse=True)
        n_baseline = unique_baselines.shape[0]

        # look up spw and pol ids as starting point
        tb_tool_meta = tb()
        tb_tool_meta.open(infile + "/DATA_DESCRIPTION",
                          nomodify=True,
                          lockoptions={'option': 'usernoread'})
        spw_id = tb_tool_meta.getcol("SPECTRAL_WINDOW_ID")[ddi]
        pol_id = tb_tool_meta.getcol("POLARIZATION_ID")[ddi]
        tb_tool_meta.close()

        ###################
        # build metadata structure from remaining spw-specific table fields
        aux_coords = {
            'time': unique_times,
            'spw': np.array([spw_id]),
            'antennas': (['baseline', 'pair'], unique_baselines)
        }
        meta_attrs = {
            'DDI': ddi,
            'AUTO_CORRELATIONS': int(np.any(ant1_col == ant2_col))
        }
        tb_tool_meta.open(os.path.join(infile, 'SPECTRAL_WINDOW'),
                          nomodify=True,
                          lockoptions={'option': 'usernoread'})
        for col in tb_tool_meta.colnames():
            try:
                if not tb_tool_meta.iscelldefined(col, spw_id): continue
                if col in ['FLAG_ROW']: continue
                if col in [
                        'CHAN_FREQ', 'CHAN_WIDTH', 'EFFECTIVE_BW', 'RESOLUTION'
                ]:
                    aux_coords[col.lower()] = ('chan',
                                               tb_tool_meta.getcol(
                                                   col, spw_id, 1)[:, 0])
                else:
                    meta_attrs[col] = tb_tool_meta.getcol(col, spw_id,
                                                          1).transpose()[0]
            except Exception:
                print('WARNING : unable to process col %s of table %s' %
                      (col, 'SPECTRAL_WINDOW'))
        tb_tool_meta.close()

        tb_tool_meta.open(os.path.join(infile, 'POLARIZATION'),
                          nomodify=True,
                          lockoptions={'option': 'usernoread'})
        for col in tb_tool_meta.colnames():
            if col == 'CORR_TYPE':
                aux_coords[col.lower()] = ('pol',
                                           tb_tool_meta.getcol(col, pol_id,
                                                               1)[:, 0])
            elif col == 'CORR_PRODUCT':
                aux_coords[col.lower()] = (['receptor', 'pol'],
                                           tb_tool_meta.getcol(col, pol_id,
                                                               1)[:, :, 0])
        tb_tool_meta.close()

        n_chan = len(aux_coords['chan_freq'][1])
        n_pol = len(aux_coords['corr_type'][1])

        # replace non-positive chunk sizes (e.g. -1) with the full length of that axis
        ddi_chunk_shape = [
            cs if cs > 0 else [n_time, n_baseline, n_chan, n_pol][ci]
            for ci, cs in enumerate(chunk_shape)
        ]
        # if not writing to file, entire main table will be read in to memory at once
        batchsize = n_time if nofile else ddi_chunk_shape[0]

        print('n_time:', n_time, '  n_baseline:', n_baseline, '  n_chan:',
              n_chan, '  n_pol:', n_pol, ' chunking: ', ddi_chunk_shape,
              ' batchsize: ', batchsize)

        coords = {
            'time': unique_times,
            'baseline': np.arange(n_baseline),
            'chan': aux_coords.pop('chan_freq')[1],
            'pol': aux_coords.pop('corr_type')[1],
            'uvw_index': np.array(['uu', 'vv', 'ww'])
        }

        ###################
        # main table loop over each batch
        for cc, start_row_indx in enumerate(range(0, n_time, batchsize)):
            rtestimate = ', remaining time est %s s' % str(
                int(((time.time() - start_ddi) / cc) *
                    (n_time / batchsize - cc))) if cc > 0 else ''
            print('processing chunk %s of %s' %
                  (str(cc), str(n_time // batchsize)) + rtestimate,
                  end='\r')
            chunk = np.arange(min(batchsize,
                                  n_time - start_row_indx)) + start_row_indx
            chunk_time_changes = time_changes[chunk] - time_changes[chunk[
                0]]  # indices in this chunk of data where time value changes
            end_idx = time_changes[
                chunk[-1] +
                1] if chunk[-1] + 1 < len(time_changes) else len(time_idxs)
            idx_range = np.arange(
                time_changes[chunk[0]],
                end_idx)  # indices (rows) in main table to be read
            coords.update({'time': unique_times[chunk]})

            chunkdata = {}
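            # each column read for this row range is scattered from its row-based layout onto the
            # dense (time, baseline, ...) grid; missing combinations stay NaN (or 1 for FLAG columns)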
            for col in ms_ddi.colnames():
                if col in ['DATA_DESC_ID', 'TIME', 'ANTENNA1', 'ANTENNA2']:
                    continue
                if not ms_ddi.iscelldefined(col, idx_range[0]): continue

                data = ms_ddi.getcol(col, idx_range[0],
                                     len(idx_range)).transpose()

                if col == 'UVW':  # n_row x 3 -> n_time x n_baseline x 3
                    fulldata = np.full((len(chunk), n_baseline, data.shape[1]),
                                       np.nan,
                                       dtype=data.dtype)
                    fulldata[time_idxs[idx_range] - chunk[0],
                             baseline_idxs[idx_range], :] = data
                    chunkdata[col] = xarray.DataArray(
                        fulldata, dims=['time', 'baseline', 'uvw_index'])

                elif data.ndim == 1:  # n_row -> n_time x n_baseline
                    if col == 'FIELD_ID' and 'field' in mxds.coords:
                        coords['field'] = ('time', mxds.coords['field'].values[
                            data[chunk_time_changes]])
                    elif col == 'SCAN_NUMBER':
                        coords['scan'] = ('time', data[chunk_time_changes])
                    elif col == 'INTERVAL':
                        coords['interval'] = ('time', data[chunk_time_changes])
                    elif col == 'PROCESSOR_ID' and 'processor' in mxds.coords:
                        coords['processor'] = ('time',
                                               mxds.coords['processor'].values[
                                                   data[chunk_time_changes]])
                    elif col == 'OBSERVATION_ID' and 'observation' in mxds.coords:
                        coords['observation'] = (
                            'time', mxds.coords['observation'].values[
                                data[chunk_time_changes]])
                    elif col == 'STATE_ID' and 'state' in mxds.coords:
                        coords['state'] = ('time', mxds.coords['state'].values[
                            data[chunk_time_changes]])
                    else:
                        fulldata = np.full((len(chunk), n_baseline),
                                           np.nan,
                                           dtype=data.dtype)
                        if col == 'FLAG_ROW':
                            fulldata = np.ones((len(chunk), n_baseline),
                                               dtype=data.dtype)
                        fulldata[time_idxs[idx_range] - chunk[0],
                                 baseline_idxs[idx_range]] = data
                        chunkdata[col] = xarray.DataArray(
                            fulldata, dims=['time', 'baseline'])

                elif (data.ndim == 2) and (data.shape[1] == n_pol):
                    fulldata = np.full((len(chunk), n_baseline, n_pol),
                                       np.nan,
                                       dtype=data.dtype)
                    fulldata[time_idxs[idx_range] - chunk[0],
                             baseline_idxs[idx_range], :] = data
                    chunkdata[col] = xarray.DataArray(
                        fulldata, dims=['time', 'baseline', 'pol'])

                elif (data.ndim == 2) and (data.shape[1] == n_chan):
                    fulldata = np.full((len(chunk), n_baseline, n_chan),
                                       np.nan,
                                       dtype=data.dtype)
                    fulldata[time_idxs[idx_range] - chunk[0],
                             baseline_idxs[idx_range], :] = data
                    chunkdata[col] = xarray.DataArray(
                        fulldata, dims=['time', 'baseline', 'chan'])

                elif data.ndim == 3:
                    assert (data.shape[1] == n_chan) & (
                        data.shape[2]
                        == n_pol), 'Column dimensions not correct'
                    if col == "FLAG":
                        fulldata = np.ones(
                            (len(chunk), n_baseline, n_chan, n_pol),
                            dtype=data.dtype)
                    else:
                        fulldata = np.full(
                            (len(chunk), n_baseline, n_chan, n_pol),
                            np.nan,
                            dtype=data.dtype)
                    fulldata[time_idxs[idx_range] - chunk[0],
                             baseline_idxs[idx_range], :, :] = data
                    chunkdata[col] = xarray.DataArray(
                        fulldata, dims=['time', 'baseline', 'chan', 'pol'])

            x_dataset = xarray.Dataset(chunkdata, coords=coords).chunk({
                'time':
                ddi_chunk_shape[0],
                'baseline':
                ddi_chunk_shape[1],
                'chan':
                ddi_chunk_shape[2],
                'pol':
                ddi_chunk_shape[3],
                'uvw_index':
                None
            })

            if (not nofile) and (cc == 0):
                encoding = dict(
                    zip(list(x_dataset.data_vars),
                        cycle([{
                            'compressor': compressor
                        }])))
                x_dataset.to_zarr(outfile + '/' + str(ddi),
                                  mode='w',
                                  encoding=encoding,
                                  consolidated=True)
            elif not nofile:
                x_dataset.to_zarr(outfile + '/' + str(ddi),
                                  mode='a',
                                  append_dim='time',
                                  compute=True,
                                  consolidated=True)

        # Add non dimensional auxiliary coordinates and attributes
        aux_coords.update({'time': unique_times})
        aux_dataset = xarray.Dataset(coords=aux_coords, attrs=meta_attrs)
        if nofile:
            x_dataset = xarray.merge([x_dataset, aux_dataset]).assign_attrs(
                meta_attrs)  # merge seems to drop attrs
        else:
            aux_dataset.to_zarr(outfile + '/' + str(ddi),
                                mode='a',
                                compute=True,
                                consolidated=True)
            x_dataset = xarray.open_zarr(outfile + '/' + str(ddi))

        xds_list += [x_dataset]
        ms_ddi.close()
        print('Completed ddi', ddi, ' process time ', time.time() - start_ddi)
        print('**********************************')

    return xds_list
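For orientation, here is a minimal sketch of reading back the partitions written above. It assumes the converter was called with nofile=False; the output path and DDI number are hypothetical stand-ins.

import xarray

outfile = 'mydata.vis.zarr'  # hypothetical path passed as outfile to the converter
global_xds = xarray.open_zarr(outfile + '/global')  # global metadata partition
ddi_xds = xarray.open_zarr(outfile + '/0')  # main-table partition for DDI 0
print(global_xds)
print(ddi_xds.dims)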
示例#47
0
 def test_merge_dicts_dims(self):
     actual = xr.merge([{'y': ('x', [13])}, {'x': [12]}])
     expected = xr.Dataset({'x': [12], 'y': ('x', [13])})
     assert actual.identical(expected)
示例#48
0
def rinexnav3(fn: Union[TextIO, str, Path],
              use: Sequence[str] = None,
              tlim: Sequence[datetime] = None) -> xarray.Dataset:
    """
    Reads RINEX 3.x NAV files
    Michael Hirsch, Ph.D.
    SciVision, Inc.
    http://www.gage.es/sites/default/files/gLAB/HTML/SBAS_Navigation_Rinex_v3.01.html

    The "eof" stuff is over detection of files that may or may not have a trailing newline at EOF.
    """
    if isinstance(fn, (str, Path)):
        fn = Path(fn).expanduser()

    svs = []
    raws = []
    svtypes: List[str] = []
    fields: Dict[str, List[str]] = {}
    times: List[datetime] = []

    with opener(fn) as f:
        header = navheader3(f)
        # %% read data
        for line in f:
            if line.startswith("\n"):  # EOF
                break

            try:
                time = _time(line)
            except ValueError:  # blank or garbage line
                continue

            if tlim is not None:
                if time < tlim[0] or time > tlim[1]:
                    _skip(f, Nl[line[0]])
                    continue
                # not break due to non-monotonic NAV files

            sv = line[:3]
            if use is not None and sv[0] not in use:
                _skip(f, Nl[sv[0]])
                continue

            times.append(time)
            # %% SV types
            field = _newnav(line, sv)

            if len(svtypes) == 0:
                svtypes.append(sv[0])
            elif sv[0] != svtypes[-1]:
                svtypes.append(sv[0])

            if sv[0] not in fields:
                fields[svtypes[-1]] = field

            svs.append(sv)
            # %% get the data as one big long string per SV, unknown # of lines per SV
            raw = line[23:80]  # NOTE: slice to column 80; some files put data in the last column

            for _, ln in zip(range(Nl[sv[0]]), f):
                raw += ln[STARTCOL3:80]
            # one line per SV
            raws.append(raw.replace("D", "E").replace("\n", ""))

    # %% parse
    # NOTE: must be 'ns' or .to_netcdf will fail!
    t = np.array([np.datetime64(t, "ns") for t in times])
    nav = xarray.Dataset({}, coords={"time": [], "sv": []})
    svu = sorted(set(svs))

    for sv in svu:
        svi = np.array([i for i, s in enumerate(svs) if s == sv])

        check = np.array([True] * t[svi].size)
        duplicate = True
        sv_copies = 0
        while duplicate:  # process until there are no more duplicate times
            tu, iu = np.unique(t[svi][check], return_index=True)
            duplicate = tu.size != t[svi][check].size

            cf = _sparefields(fields[sv[0]], sv[0], raws[svi[0]])
            gi = [
                i for i, c in enumerate(cf)
                if not c.startswith(("spare", "FitIntvl"))
            ]
            darr = np.empty((svi.size, len(gi)))

            for j, i in enumerate(svi):
                # darr[j, :] = np.genfromtxt(io.BytesIO(raws[i].encode('ascii')), delimiter=Lf)
                try:
                    darr[j, :] = [
                        float(raws[i][Lf * k:Lf * (k + 1)]) for k in gi
                    ]
                except ValueError:
                    logging.info(f"malformed line for {sv}")
                    darr[j, :] = np.nan
            # %% discard duplicated times

            darr = darr[check, :][iu, :]

            dsf = {}
            for (i, d) in zip(gi, darr.T):
                if sv[0] in ("R", "S") and cf[i] in (
                        "X",
                        "dX",
                        "dX2",
                        "Y",
                        "dY",
                        "dY2",
                        "Z",
                        "dZ",
                        "dZ2",
                ):
                    d *= 1000  # km => m

                dsf[cf[i]] = (("time", "sv"), d[:, None])

            svv = sv if not sv_copies else sv + f"_{sv_copies}"
            if len(nav) == 0:
                nav = xarray.Dataset(dsf, coords={"time": tu, "sv": [svv]})
            else:
                nav = xarray.merge(
                    (nav, xarray.Dataset(dsf, coords={
                        "time": tu,
                        "sv": [svv]
                    })))

            sv_copies += 1
            check[np.arange(check.size)[check][iu]] = False

    # %% patch SV names in case of "G 7" => "G07"
    nav = nav.assign_coords(
        sv=[s.replace(" ", "0") for s in nav.sv.values.tolist()])
    # %% other attributes

    # Add ionospheric correction coefficients if exist.
    if "IONOSPHERIC CORR" in header:
        corr = header["IONOSPHERIC CORR"]
        if "GPSA" in corr and "GPSB" in corr:
            nav.attrs["ionospheric_corr_GPS"] = np.hstack(
                (corr["GPSA"], corr["GPSB"]))
        if "GAL" in corr:
            nav.attrs["ionospheric_corr_GAL"] = corr["GAL"]
        if "QZSA" in corr and "QZSB" in corr:
            nav.attrs["ionospheric_corr_QZS"] = np.hstack(
                (corr["QZSA"], corr["QZSB"]))
        if "BDSA" in corr and "BDSB" in corr:
            nav.attrs["ionospheric_corr_BDS"] = np.hstack(
                (corr["BDSA"], corr["BDSB"]))
        if "IRNA" in corr and "IRNB" in corr:
            nav.attrs["ionospheric_corr_IRN"] = np.hstack(
                (corr["IRNA"], corr["IRNB"]))

    nav.attrs["version"] = header["version"]
    nav.attrs["svtype"] = svtypes
    nav.attrs["rinextype"] = "nav"
    if isinstance(fn, Path):
        nav.attrs["filename"] = fn.name

    return nav
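A brief usage sketch; the filename below is hypothetical, and only GPS records within a six-hour window are kept:

from datetime import datetime

nav = rinexnav3('BRDC00IGS_R_20210750000_01D_MN.rnx',
                use=('G',),
                tlim=(datetime(2021, 3, 16, 0), datetime(2021, 3, 16, 6)))
print(nav.sv.values)         # GPS satellites found in the file
print(nav.attrs['version'])  # RINEX version from the header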
示例#49
0
 def test_merge_dicts_simple(self):
     actual = xr.merge([{'foo': 0}, {'bar': 'one'}, {'baz': 3.5}])
     expected = xr.Dataset({'foo': 0, 'bar': 'one', 'baz': 3.5})
     assert actual.identical(expected)
示例#50
0
                       or np.isnan(zdnp).sum() != 0):
                        print("missing value exists in tc2np")
                    else:
                        lonnp = lonin.reshape(1, -1)
                        unp, vnp = librotate.xyzd2uv(xdnp, ydnp, zdnp, lonnp)
                        #missing value
                        if (np.isnan(unp).sum() != 0
                                or np.isnan(vnp).sum() != 0):
                            print("missing value exists in xyzd2uv")
                        else:
                            udata = xr.DataArray(unp.reshape(1,1,nlat,nlon),\
                                    [('time',pd.date_range(date,periods=1)),\
                                     ('level',np.array([lev[l]])),\
                                     ('latitude',latin),('longitude',lonin)],\
                                                 attrs=attrs_u,name=uname)
                            vdata = xr.DataArray(vnp.reshape(1,1,nlat,nlon),\
                                    [('time',pd.date_range(date,periods=1)),\
                                     ('level',np.array([lev[l]])),\
                                     ('latitude',latin),('longitude',lonin)],\
                                                 attrs=attrs_v,name=vname)
                            print(udata)
                            print(vdata)
                            daout.append(udata)
                            daout.append(vdata)

        da_np.append(xr.merge(daout))
data_np = xr.merge(da_np)
print(data_np)

data_np.to_netcdf(outdir / outnc, 'w')
示例#52
0
 def load(self):  # load raw data, i.e. the array for each image
     # note: this merge result is discarded; open_mfdataset below already merges the files
     xr.merge([xr.open_dataset(f) for f in glob.glob(self.path + '/*.nc')])
     ds = xr.open_mfdataset(self.path + '/*.nc')  # open all files from the given path as one dataset
     self.raw_data = np.array(ds.variables[self.data_key])
     self.dataset = [self.raw_data]
示例#53
0
 def test_merge_error(self):
     ds = xr.Dataset({'x': 0})
     with pytest.raises(xr.MergeError):
         xr.merge([ds, ds + 1])
示例#54
0
grid_all_tiles = ecco.load_all_tiles_from_netcdf(data_dir,
                                                 var,
                                                 var_type,
                                                 less_output=True)

# Load all tiles of SSH
data_dir = ECCO_dir + '/nctiles_monthly/SSH/'
var = 'SSH'
var_type = 'c'
ssh_all_tiles = ecco.load_all_tiles_from_netcdf(data_dir,
                                                var,
                                                var_type,
                                                less_output=True)

# minimize the metadata (optional)
data = xr.merge([ssh_all_tiles, grid_all_tiles])

# ### Saving a `Dataset`
#
# Now that we've loaded *ssh_all_tiles*, let's save it in the *SSH* file directory.

# In[2]:

new_filename = data_dir + 'data_all_tiles.nc'
print('saving to', new_filename)

data.to_netcdf(path=new_filename)
print('finished saving')

# Now let's create a new `Dataset` that includes only *SSH* and some grid parameter variables that are on the same 'c' grid points as *SSH*.
#
示例#55
0
 def test_merge_dataarray_unnamed(self):
     data = xr.DataArray([1, 2], dims='x')
     with raises_regex(
             ValueError, 'without providing an explicit name'):
         xr.merge([data])
示例#56
0
    # these two downscaling results are used below in the merge
    mod_downscale_anoms = downscale_analog_anoms_noscale(ds_data, win_masks)
    mod_downscale = downscale_analog_noscale(ds_data, win_masks)

    sgrid_mod_downscale_anoms = downscale_analog_anoms_sgrid(
        ds_data, win_masks)
    sgrid_mod_downscale = downscale_analog_nonzero_scale(ds_data, win_masks)

    da_obsc = xr.open_dataset(fpath_prcp_obsc).PCP.load()
    da_obsc['time'] = pd.to_datetime(da_obsc.time.to_pandas().astype(np.str),
                                     format='%Y%m%d',
                                     errors='coerce')
    mod_downscale_cg = da_obsc.loc[downscale_start_year:downscale_end_year]

    mod_downscale_anoms.name = 'mod_d_anoms'
    mod_downscale.name = 'mod_d'
    sgrid_mod_downscale_anoms.name = 'mod_d_anoms_sgrid'
    sgrid_mod_downscale.name = 'mod_d_sgrid'
    mod_downscale_cg.name = 'mod_d_cg'

    obs = ds_data.obs.loc[downscale_start_year:downscale_end_year]

    ds_d = xr.merge([
        obs, mod_downscale_cg, mod_downscale, mod_downscale_anoms,
        sgrid_mod_downscale, sgrid_mod_downscale_anoms
    ])
    ds_d = ds_d.drop('month')
    ds_d.to_netcdf(
        os.path.join(
            esd.cfg.data_root, 'downscaling', 'downscale_tests_%s_%s.nc' %
            (downscale_start_year, downscale_end_year)))
示例#57
0
    def _handler(self, request, response):
        write_log(self, "Processing started", process_step="start")

        # --- Process inputs ---
        geo_url = request.inputs["resource"][0].url
        index_dim = request.inputs['index_dim'][0].data
        feat_dim = request.inputs['feat_dim'][0].data
        squeeze = request.inputs['squeeze'][0].data
        grid_map = "longitude_latitude"  # request.inputs['grid_mapping'][0].data

        # Open URL
        gdf = gpd.read_file(geo_url)
        write_log(self, "Geoseries downloaded", process_step="downloaded")

        # Try casting all columns
        gdf = _maybe_cast(gdf)

        # Set index and convert to xr
        ds = gdf.set_index(index_dim).to_xarray()

        # Reshape geometries
        ds = cfgeo.reshape_unique_geometries(ds)

        # Convert shapely objects to CF-style
        coords = cfgeo.shapely_to_cf(ds.geometry, grid_mapping=grid_map)
        if coords.features.size == coords.node.size:
            # Then it's only single points, we can drop 'node'
            coords = coords.drop_dims('node')

        ds = xr.merge([ds.drop_vars('geometry'), coords])

        # feat dim
        if feat_dim:
            feat = _maybe_squeeze(ds[feat_dim], index_dim)
            if feat.ndim == 2:
                raise ValueError(
                    "'feat_dim' was given but it cannot be squeezed along the index dimension."
                )
            ds[feat_dim] = feat
            ds = ds.drop_vars('features').rename(
                features=feat_dim).set_coords(feat_dim)

        # squeeze
        if squeeze:
            for dim in [feat_dim, index_dim]:
                for name, var in ds.data_vars.items():
                    if dim in var.dims:
                        try:
                            ds[name] = _maybe_squeeze(var, dim)
                        except KeyError:
                            print(ds)
                            raise

        write_log(self,
                  "Geoseries converted to CF dataset",
                  process_step="converted")

        host = urlparse(geo_url).netloc
        if not host:  # urlparse gives '' (not None) when the URL has no network location
            source = "geospatial series data"
        else:
            source = f"data downloaded from {host}"
        ds.attrs["history"] = update_history(
            f"Converted {source} to a CF-compliant Dataset.", ds)

        # Write to disk
        filename = valid_filename(
            single_input_or_none(request.inputs, "output_name") or "geoseries")
        output_file = Path(self.workdir) / f"{filename}.nc"
        dataset_to_netcdf(ds, output_file)

        # Fill response
        response.outputs["output"].file = str(output_file)
        response.outputs["output_log"].file = str(log_file_path(self))
示例#58
0
def merge_data_arrays(*DataArrays):
    das = [da.name for da in DataArrays]
    print(f"Merging data: {das}")
    ds = xr.merge([*DataArrays])
    return ds
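A small usage sketch with two named DataArrays (the names and values are illustrative):

import xarray as xr

a = xr.DataArray([1, 2], dims='x', name='a')
b = xr.DataArray([3.0, 4.0], dims='x', name='b')
ds = merge_data_arrays(a, b)  # Dataset with data variables 'a' and 'b'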
示例#59
0
 def duplicate_and_merge(array):
     return xr.merge([array, array.rename('bar')]).to_array()
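Note: this assumes `array` is a named DataArray; merging an unnamed DataArray raises the ValueError shown in the earlier `test_merge_dataarray_unnamed` example.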
示例#60
0
    def test_train(self, tmp_path, use_pred_months, experiment, monthly_agg):

        import xgboost as xgb

        x, _, _ = _make_dataset(size=(5, 5), const=True)
        x_static, _, _ = _make_dataset(size=(5, 5), add_times=False)
        y = x.isel(time=[-1])

        x_add1, _, _ = _make_dataset(size=(5, 5), const=True, variable_name="precip")
        x_add2, _, _ = _make_dataset(size=(5, 5), const=True, variable_name="temp")
        x = xr.merge([x, x_add1, x_add2])

        norm_dict = {
            "VHI": {"mean": 0, "std": 1},
            "precip": {"mean": 0, "std": 1},
            "temp": {"mean": 0, "std": 1},
        }

        static_norm_dict = {"VHI": {"mean": 0.0, "std": 1.0}}

        test_features = tmp_path / f"features/{experiment}/train/1980_1"
        test_features.mkdir(parents=True)
        pred_features = tmp_path / f"features/{experiment}/test/1980_1"
        pred_features.mkdir(parents=True)
        static_features = tmp_path / f"features/static"
        static_features.mkdir(parents=True)

        with (tmp_path / f"features/{experiment}/normalizing_dict.pkl").open("wb") as f:
            pickle.dump(norm_dict, f)

        with (tmp_path / f"features/static/normalizing_dict.pkl").open("wb") as f:
            pickle.dump(static_norm_dict, f)

        x.to_netcdf(test_features / "x.nc")
        x.to_netcdf(pred_features / "x.nc")
        y.to_netcdf(test_features / "y.nc")
        y.to_netcdf(pred_features / "y.nc")
        x_static.to_netcdf(static_features / "data.nc")

        model = GBDT(
            tmp_path,
            include_pred_month=use_pred_months,
            experiment=experiment,
            include_monthly_aggs=monthly_agg,
            normalize_y=False,
        )
        model.train()

        assert (
            type(model.model) == xgb.XGBRegressor
        ), f"Model attribute not a gradient boosted regressor!"

        test_arrays_dict, preds_dict = model.predict()
        assert (
            test_arrays_dict["1980_1"]["y"].size == preds_dict["hello"].shape[0]
        ), "Expected length of test arrays to be the same as the predictions"

        # test saving the model outputs
        model.evaluate(save_preds=True)

        save_path = model.data_path / "models" / experiment / "gbdt"
        assert (save_path / "preds_1980_1.nc").exists()
        assert (save_path / "results.json").exists()

        pred_ds = xr.open_dataset(save_path / "preds_1980_1.nc")
        assert np.isin(["lat", "lon", "time"], [c for c in pred_ds.coords]).all()
        assert y.time == pred_ds.time