def test_concat(self):
    # TODO: simplify and split this test case

    # drop the third dimension to keep things relatively understandable
    data = create_test_data()
    for k in list(data):
        if 'dim3' in data[k].dims:
            del data[k]

    split_data = [data.isel(dim1=slice(3)),
                  data.isel(dim1=slice(3, None))]
    self.assertDatasetIdentical(data, concat(split_data, 'dim1'))

    def rectify_dim_order(dataset):
        # return a new dataset with all variable dimensions transposed into
        # the order in which they are found in `data`
        return Dataset(dict((k, v.transpose(*data[k].dims))
                            for k, v in iteritems(dataset.data_vars)),
                       dataset.coords, attrs=dataset.attrs)

    for dim in ['dim1', 'dim2']:
        datasets = [g for _, g in data.groupby(dim, squeeze=False)]
        self.assertDatasetIdentical(data, concat(datasets, dim))

    dim = 'dim2'
    self.assertDatasetIdentical(
        data, concat(datasets, data[dim]))
    self.assertDatasetIdentical(
        data, concat(datasets, data[dim], coords='minimal'))

    datasets = [g for _, g in data.groupby(dim, squeeze=True)]
    concat_over = [k for k, v in iteritems(data.coords)
                   if dim in v.dims and k != dim]
    actual = concat(datasets, data[dim], coords=concat_over)
    self.assertDatasetIdentical(data, rectify_dim_order(actual))

    actual = concat(datasets, data[dim], coords='different')
    self.assertDatasetIdentical(data, rectify_dim_order(actual))

    # make sure the coords argument behaves as expected
    data.coords['extra'] = ('dim4', np.arange(3))
    for dim in ['dim1', 'dim2']:
        datasets = [g for _, g in data.groupby(dim, squeeze=True)]
        actual = concat(datasets, data[dim], coords='all')
        expected = np.array([data['extra'].values
                             for _ in range(data.dims[dim])])
        self.assertArrayEqual(actual['extra'].values, expected)

        actual = concat(datasets, data[dim], coords='different')
        self.assertDataArrayEqual(data['extra'], actual['extra'])
        actual = concat(datasets, data[dim], coords='minimal')
        self.assertDataArrayEqual(data['extra'], actual['extra'])

    # verify that the dim argument takes precedence over
    # concatenating dataset variables of the same name
    dim = (2 * data['dim1']).rename('dim1')
    datasets = [g for _, g in data.groupby('dim1', squeeze=False)]
    expected = data.copy()
    expected['dim1'] = dim
    self.assertDatasetIdentical(expected, concat(datasets, dim))
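# The round trip at the heart of the test above, as a standalone sketch with a
# hand-made dataset (the variable name and sizes are illustrative, not from the
# original test fixtures): splitting a dataset along an indexed dimension and
# re-concatenating the pieces reproduces the original object.
import numpy as np
import xarray as xr

data = xr.Dataset({'a': ('dim1', np.arange(8.0))},
                  coords={'dim1': np.arange(8)})
halves = [data.isel(dim1=slice(4)), data.isel(dim1=slice(4, None))]
assert data.identical(xr.concat(halves, dim='dim1'))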
def test_concat_size0(self):
    data = create_test_data()
    split_data = [data.isel(dim1=slice(0, 0)), data]
    actual = concat(split_data, 'dim1')
    self.assertDatasetIdentical(data, actual)

    actual = concat(split_data[::-1], 'dim1')
    self.assertDatasetIdentical(data, actual)
def multi_concat(results, dims):
    """Concatenate a nested list of xarray objects along several dimensions."""
    if len(dims) == 1:
        return xr.concat(results, dim=dims[0])
    else:
        return xr.concat([multi_concat(sub_results, dims[1:])
                          for sub_results in results], dim=dims[0])
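# Usage sketch for multi_concat (all names below are illustrative): a 2x2
# nested list of blocks is concatenated along 'y' within each inner list and
# then along 'x' across the results, giving one 4x6 dataset.
import numpy as np
import xarray as xr

def _make_block(x0, y0):
    # hypothetical helper: a small block with explicit x/y coordinates
    return xr.Dataset({'foo': (('x', 'y'), np.random.rand(2, 3))},
                      coords={'x': [x0, x0 + 1], 'y': [y0, y0 + 1, y0 + 2]})

nested = [[_make_block(0, 0), _make_block(0, 3)],
          [_make_block(2, 0), _make_block(2, 3)]]
combined = multi_concat(nested, dims=['x', 'y'])
assert combined.sizes['x'] == 4 and combined.sizes['y'] == 6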
def test_concat_encoding(self):
    # Regression test for GH1297
    ds = Dataset({'foo': (['x', 'y'], np.random.random((2, 3))),
                  'bar': (['x', 'y'], np.random.random((2, 3)))},
                 {'x': [0, 1]})
    foo = ds['foo']
    foo.encoding = {"complevel": 5}
    ds.encoding = {"unlimited_dims": 'x'}
    assert concat([foo, foo], dim="x").encoding == foo.encoding
    assert concat([ds, ds], dim="x").encoding == ds.encoding
def test_concat_do_not_promote(self):
    # GH438
    objs = [Dataset({"y": ("t", [1])}, {"x": 1}),
            Dataset({"y": ("t", [2])}, {"x": 1})]
    expected = Dataset({"y": ("t", [1, 2])}, {"x": 1, "t": [0, 0]})
    actual = concat(objs, "t")
    self.assertDatasetIdentical(expected, actual)

    objs = [Dataset({"y": ("t", [1])}, {"x": 1}),
            Dataset({"y": ("t", [2])}, {"x": 2})]
    with self.assertRaises(ValueError):
        concat(objs, "t", coords="minimal")
def test_concat_coords(self):
    data = Dataset({"foo": ("x", np.random.randn(10))})
    expected = data.assign_coords(c=("x", [0] * 5 + [1] * 5))
    objs = [data.isel(x=slice(5)).assign_coords(c=0),
            data.isel(x=slice(5, None)).assign_coords(c=1)]
    for coords in ["different", "all", ["c"]]:
        actual = concat(objs, dim="x", coords=coords)
        self.assertDatasetIdentical(expected, actual)
    for coords in ["minimal", []]:
        with self.assertRaisesRegexp(ValueError, "not equal across"):
            concat(objs, dim="x", coords=coords)
def test_concat_constant_index(self):
    # GH425
    ds1 = Dataset({"foo": 1.5}, {"y": 1})
    ds2 = Dataset({"foo": 2.5}, {"y": 1})
    expected = Dataset({"foo": ("y", [1.5, 2.5]), "y": [1, 1]})
    for mode in ["different", "all", ["foo"]]:
        actual = concat([ds1, ds2], "y", data_vars=mode)
        self.assertDatasetIdentical(expected, actual)
    with self.assertRaisesRegexp(ValueError, "not equal across datasets"):
        concat([ds1, ds2], "y", data_vars="minimal")
def test_concat(self): ds = Dataset({"foo": (["x", "y"], np.random.random((10, 20))), "bar": (["x", "y"], np.random.random((10, 20)))}) foo = ds["foo"] bar = ds["bar"] # from dataset array: expected = DataArray(np.array([foo.values, bar.values]), dims=["w", "x", "y"]) actual = concat([foo, bar], "w") self.assertDataArrayEqual(expected, actual) # from iteration: grouped = [g for _, g in foo.groupby("x")] stacked = concat(grouped, ds["x"]) self.assertDataArrayIdentical(foo, stacked) # with an index as the 'dim' argument stacked = concat(grouped, ds.indexes["x"]) self.assertDataArrayIdentical(foo, stacked) actual = concat([foo[0], foo[1]], pd.Index([0, 1])).reset_coords(drop=True) expected = foo[:2].rename({"x": "concat_dim"}) self.assertDataArrayIdentical(expected, actual) actual = concat([foo[0], foo[1]], [0, 1]).reset_coords(drop=True) expected = foo[:2].rename({"x": "concat_dim"}) self.assertDataArrayIdentical(expected, actual) with self.assertRaisesRegexp(ValueError, "not identical"): concat([foo, bar], dim="w", compat="identical") with self.assertRaisesRegexp(ValueError, "not a valid argument"): concat([foo, bar], dim="w", data_vars="minimal")
def test_concat(self):
    ds = Dataset({'foo': (['x', 'y'], np.random.random((2, 3))),
                  'bar': (['x', 'y'], np.random.random((2, 3)))},
                 {'x': [0, 1]})
    foo = ds['foo']
    bar = ds['bar']

    # from dataset array:
    expected = DataArray(np.array([foo.values, bar.values]),
                         dims=['w', 'x', 'y'], coords={'x': [0, 1]})
    actual = concat([foo, bar], 'w')
    self.assertDataArrayEqual(expected, actual)

    # from iteration:
    grouped = [g for _, g in foo.groupby('x')]
    stacked = concat(grouped, ds['x'])
    self.assertDataArrayIdentical(foo, stacked)

    # with an index as the 'dim' argument
    stacked = concat(grouped, ds.indexes['x'])
    self.assertDataArrayIdentical(foo, stacked)

    actual = concat([foo[0], foo[1]], pd.Index([0, 1])).reset_coords(drop=True)
    expected = foo[:2].rename({'x': 'concat_dim'})
    self.assertDataArrayIdentical(expected, actual)

    actual = concat([foo[0], foo[1]], [0, 1]).reset_coords(drop=True)
    expected = foo[:2].rename({'x': 'concat_dim'})
    self.assertDataArrayIdentical(expected, actual)

    with self.assertRaisesRegexp(ValueError, 'not identical'):
        concat([foo, bar], dim='w', compat='identical')

    with self.assertRaisesRegexp(ValueError, 'not a valid argument'):
        concat([foo, bar], dim='w', data_vars='minimal')
def add_cyclic(varin, dim='nlon'):
    '''Add a cyclic point to CESM data. Preserve datatype: xarray'''
    dimdict = {}
    dimdict[dim] = 0
    if dim == 'nlon':
        return xr.concat([varin, varin.isel(nlon=0)], dim='nlon')
    elif dim == 'nlat':
        return xr.concat([varin, varin.isel(nlat=0)], dim='nlat')
    elif dim == 'dim_0':
        return xr.concat([varin, varin.isel(dim_0=0)], dim='dim_0')
    elif dim == 'dim_1':
        return xr.concat([varin, varin.isel(dim_1=0)], dim='dim_1')
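# The four branches above differ only in the dimension name; as a sketch (not
# from the original codebase), the same behaviour can be written once with a
# dict-style isel. `add_cyclic_generic` is a hypothetical name.
import xarray as xr

def add_cyclic_generic(varin, dim='nlon'):
    """Append the first slice along `dim` to the end, preserving the xarray type."""
    return xr.concat([varin, varin.isel({dim: 0})], dim=dim)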
def test_concat_twice(self, create_combined_ids, concat_dim):
    shape = (2, 3)
    combined_ids = create_combined_ids(shape)
    result = _combine_nd(combined_ids, concat_dims=['dim1', concat_dim])

    ds = create_test_data
    partway1 = concat([ds(0), ds(3)], dim='dim1')
    partway2 = concat([ds(1), ds(4)], dim='dim1')
    partway3 = concat([ds(2), ds(5)], dim='dim1')
    expected = concat([partway1, partway2, partway3], dim=concat_dim)

    assert_equal(result, expected)
def test_auto_combine_2d(self):
    ds = create_test_data

    partway1 = concat([ds(0), ds(3)], dim='dim1')
    partway2 = concat([ds(1), ds(4)], dim='dim1')
    partway3 = concat([ds(2), ds(5)], dim='dim1')
    expected = concat([partway1, partway2, partway3], dim='dim2')

    datasets = [[ds(0), ds(1), ds(2)], [ds(3), ds(4), ds(5)]]
    result = auto_combine(datasets, concat_dim=['dim1', 'dim2'])

    assert_equal(result, expected)
def test_concat_do_not_promote(self):
    # GH438
    objs = [Dataset({'y': ('t', [1])}, {'x': 1, 't': [0]}),
            Dataset({'y': ('t', [2])}, {'x': 1, 't': [0]})]
    expected = Dataset({'y': ('t', [1, 2])}, {'x': 1, 't': [0, 0]})
    actual = concat(objs, 't')
    self.assertDatasetIdentical(expected, actual)

    objs = [Dataset({'y': ('t', [1])}, {'x': 1, 't': [0]}),
            Dataset({'y': ('t', [2])}, {'x': 2, 't': [0]})]
    with self.assertRaises(ValueError):
        concat(objs, 't', coords='minimal')
def file_loop(passit):

    ds = xr.open_dataset(passit)
    dataset = passit[6:9]

    ds['tir'].values = ds['tir']

    bloblist = []
    tirlist = []
    lat = ds.lat
    lon = ds.lon

    for ids, day in enumerate(ds['tir']):

        print('id', ids)
        date = day.time
        day.values = day / 100

        if np.sum(day.values) == 0:
            continue

        img, nogood, t_thresh_size, t_thresh_cut, pix_nb = powerBlob_utils.filter_img(day.values, 5)

        power = util.waveletT(img, dataset='METEOSAT5K_vera')
        power_out = powerBlob_utils.find_scales_dominant(power, nogood, dataset=dataset)

        if power_out is None:
            continue

        new_savet = (day.values * 100).astype(np.int16)

        bloblist.append(xr.DataArray(power_out.astype(np.int16),
                                     coords={'time': date, 'lat': lat, 'lon': lon},
                                     dims=['lat', 'lon']))  # [np.newaxis, :]
        tirlist.append(xr.DataArray(new_savet,
                                    coords={'time': date, 'lat': lat, 'lon': lon},
                                    dims=['lat', 'lon']))

    ds_mfg = xr.Dataset()
    ds_mfg['blobs'] = xr.concat(bloblist, 'time')
    ds_mfg['tir'] = xr.concat(tirlist, 'time')
    # .sel returns a new object, so the subset must be assigned back
    ds_mfg = ds_mfg.sel(lat=slice(5, 12), lon=slice(-13, 13))

    savefile = passit.replace('cores_', 'coresPower_')
    try:
        os.remove(savefile)
    except OSError:
        pass

    comp = dict(zlib=True, complevel=5)
    enc = {var: comp for var in ds_mfg.data_vars}
    ds_mfg.to_netcdf(path=savefile, mode='w', encoding=enc, format='NETCDF4')

    print('Saved ' + savefile)
def add_to_slice(da, dim, sl, value):
    # split array into before, middle and after (if slice is the
    # beginning or end, before or after will be empty)
    before = da[{dim: slice(0, sl)}]
    middle = da[{dim: sl}]
    after = da[{dim: slice(sl + 1, None)}]
    if sl < -1:
        raise RuntimeError('slice can not be smaller value than -1')
    elif sl == -1:
        da_new = xr.concat([before, middle + value], dim=dim)
    else:
        da_new = xr.concat([before, middle + value, after], dim=dim)
    # then add 'value' to middle and concatenate again
    return da_new
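# Quick usage sketch (the array and values below are made up): bump the element
# at position 2 along 'x' by 10 and leave the rest untouched.
import numpy as np
import xarray as xr

da = xr.DataArray(np.zeros(5), dims='x')
bumped = add_to_slice(da, dim='x', sl=2, value=10.0)
# bumped.values -> array([ 0.,  0., 10.,  0.,  0.])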
def pp_tile(config, timestamp, coordinate_templates, drop_list, tile): """ Post-process a rectangular tile of cells. **Arguments:** * config A `~scmtiles.config.SCMTilesConfig` instance describing the run being post-processed. * timestamp A string timestamp used as part of the filename for the cell output files. * coordiate_templates A dictionary mapping coordinate names to xarray coordinate objects, as returned from `load_coorindate_templates`. This is used to lookup the latitude and longitude of the cells from their indices. * tile A `~scmtiles.grid_manager.RectangularTile` instance describing the tile to process. **Returns:** * (tile_ds, filepaths) An `xarray.Dataset` representing the tile, and a list of paths to the files that were loaded to form the tile. """ grid_rows = OrderedDict() filepaths = [] for cell in tile.cells(): cell_ds, cell_filepath = pp_cell(cell, timestamp, coordinate_templates, drop_list, config) try: grid_rows[cell.y_global].append(cell_ds) except KeyError: grid_rows[cell.y_global] = [cell_ds] filepaths.append(cell_filepath) for key, row in grid_rows.items(): grid_rows[key] = xr.concat(row, dim=config.xname) if len(grid_rows) > 1: tile_ds = xr.concat(grid_rows.values(), dim=config.yname) else: tile_ds, = grid_rows.values() logger = logging.getLogger('PP') logger.info('processing of tile #{} completed'.format(tile.id)) return tile_ds, filepaths
def open_mfxr(files, dim='TIME', transform_func=None):
    """
    Load multiple MAR files into a single xarray object, performing some
    aggregation first to make this computationally feasible, e.g. selecting
    a single dataarray to examine.

    # you might also use indexing operations like .sel to subset datasets
    comb = open_mfxr('MAR*.nc', dim='TIME',
                     transform_func=lambda ds: ds.AL)

    Based on http://xray.readthedocs.io/en/v0.7.1/io.html#combining-multiple-files
    See also http://xray.readthedocs.io/en/v0.7.1/dask.html
    """

    def process_one_path(path):
        ds = open_xr(path, chunks={'TIME': 366})
        # transform_func should do some sort of selection or
        # aggregation
        if transform_func is not None:
            ds = transform_func(ds)
        # load all data from the transformed dataset, to ensure we can
        # use it after closing each original file
        return ds

    paths = sorted(glob(files))
    datasets = [process_one_path(p) for p in paths]
    combined = xr.concat(datasets, dim)
    return combined
def test_update_add_data_to_coords(self):
    test_array = self.array.copy()
    test_array['time'] = [datetime.datetime.utcnow(), ]
    concatenated_array = xr.concat([self.array, test_array], dim='time')
    self.array.pp.grid = self.grid
    updated_array = self.array.pp.update(test_array)
    np.testing.assert_equal(updated_array.values, concatenated_array.values)
def test_concatenate():
    """Make sure we can easily concatenate TimeSeriesX objects - test it with
    a rec array as one of the coords. This fails for xarray > 0.7. See
    https://github.com/pydata/xarray/issues/1434 for details.
    """
    p1 = np.array([('John', 180), ('Stacy', 150), ('Dick', 200)],
                  dtype=[('name', '|S256'), ('height', int)])
    p2 = np.array([('Bernie', 170), ('Donald', 250), ('Hillary', 150)],
                  dtype=[('name', '|S256'), ('height', int)])
    data = np.arange(50, 80, 1, dtype=float)  # np.float was removed in recent numpy
    dims = ['measurement', 'participant']

    ts1 = TimeSeriesX.create(data.reshape(10, 3), None, dims=dims,
                             coords={
                                 'measurement': np.arange(10),
                                 'participant': p1,
                                 'samplerate': 1
                             })

    ts2 = TimeSeriesX.create(data.reshape(10, 3) * 2, None, dims=dims,
                             coords={
                                 'measurement': np.arange(10),
                                 'participant': p2,
                                 'samplerate': 1
                             })

    combined = xr.concat((ts1, ts2), dim='participant')

    assert isinstance(combined, TimeSeriesX)
    assert (combined.participant.data['height'] ==
            np.array([180, 150, 200, 170, 250, 150])).all()
    assert (combined.participant.data['name'] ==
            np.array(['John', 'Stacy', 'Dick', 'Bernie', 'Donald', 'Hillary'])).all()
def month_count(): years = list(range(1983, 2018)) msg_folder = cnst.GRIDSAT fname = 'aggs/gridsat_WA_-70_monthly_count.nc' #65_monthly_count_-40base_15-21UTC_1000km2.nc' if not os.path.isfile(msg_folder + fname): da = None for y in years: y = str(y) da1 = xr.open_dataset(cnst.GRIDSAT + 'gridsat_WA_' + y + '.nc') # _-40_1000km2_15-21UTC print('Doing ' + y) da1['tir'].values = da1['tir'].values/100 da1['tir'] = da1['tir'].where((da1['tir'] <= -70) & (da1['tir'] >= -108)) #-65 da1['tir'].values[da1['tir'].values < -70] = 1 da1 = da1.resample(time='m').sum('time') try: da = xr.concat([da, da1], 'time') except TypeError: da = da1.copy() enc = {'tir': {'complevel': 5, 'zlib': True}} da.to_netcdf(msg_folder + fname, encoding=enc)
def get_seasonal_clim(self, start_year=2002, end_year=2010, season_to_months:dict=None, vname:str= "sst", skip_months_of_edge_winters=False): result = OrderedDict() for sname, months in season_to_months.items(): data_arrays = [] for y in range(start_year, end_year + 1): # -- for month in months: # skip December of the last winter and (Jan, Feb) of the first winter if skip_months_of_edge_winters: if sname.lower() in ["winter", "djf"]: if month == 12 and y == end_year: continue if (month == 1 or month == 2) and y == start_year: continue for day in range(1, calendar.monthrange(y, month)[1] + 1): file_url = self.get_url_for(year=y, month=month, day=day) print("Opening {}".format(file_url)) with xr.open_dataset(file_url) as ds: data_arrays.append(ds[vname][0, 0, :, :].load()) result[sname] = xr.concat(data_arrays, dim="time").mean(dim="time") return result
def loop(y): out = cnst.local_data + 'GRIDSAT/MCS18/' infolder = cnst.local_data + 'GRIDSAT/www.ncei.noaa.gov/data/geostationary-ir-channel-brightness-temperature-gridsat-b1/access/' filename = 'gridsat_WA_-40_1000km2_15-21UTC' + str(y) + '.nc' da = None if os.path.isfile(out + filename): return files = glob.glob(infolder + str(y) + '/GRIDSAT-AFRICA_CP*.nc') files.sort() for f in files: print('Doing ' + f) df = xr.open_dataset(f) if (df['time.hour']<15) | (df['time.hour']>21): continue df.rename({'irwin_cdr': 'tir'}, inplace=True) df['tir'].values = df['tir'].values-273.15 labels, goodinds = ua.blob_define(df['tir'].values, -40, minmax_area=[16, 25000], max_area=None) # 7.7x7.7km = 64km2 per pix in gridsat? df['tir'].values[labels == 0] = 0 df['tir'].values[df['tir'].values < -110] = 0 df['tir'].values = (np.round(df['tir'].values, decimals=2)*100).astype(np.int16) try: da = xr.concat([da, df], dim='time') except TypeError: da = df.copy() enc = {'tir': {'complevel': 5, 'shuffle': True, 'zlib': True}} da.to_netcdf(out + filename, encoding=enc) da.close()
def fetch_full_san_data(stream_key, time_range, location_metadata=None): """ Given a time range and stream key. Genereate all data in the inverval using data from the SAN. :param stream_key: :param time_range: :return: """ if location_metadata is None: location_metadata = get_san_location_metadata(stream_key, time_range) # get which bins we can gather data from ref_des_dir, dir_string = get_SAN_directories(stream_key, split=True) if not os.path.exists(ref_des_dir): log.warning("Reference Designator does not exist in offloaded DataSAN") return None data = [] next_index = 0 for time_bin in location_metadata.bin_list: direct = dir_string.format(time_bin) if os.path.exists(direct): # get data from all of the deployments deployments = os.listdir(direct) for deployment in deployments: full_path = os.path.join(direct, deployment) if os.path.isdir(full_path): new_data = get_deployment_data(full_path, stream_key.stream_name, -1, time_range, index_start=next_index) if new_data is not None: data.append(new_data) # Keep track of indexes so they are unique in the final dataset next_index += len(new_data['index']) if not data: return None return xr.concat(data, dim='index')
def month_mean_hov(): years = list(range(1983,2018)) msg_folder = cnst.GRIDSAT fname = 'aggs/gridsat_WA_-50_monthly_mean.nc' hov_box = None for y in years: y = str(y) da1 = xr.open_dataset(cnst.GRIDSAT + 'gridsat_WA_-50_' + y + '.nc') print('Doing ' + y) da1['tir'] = da1['tir'].where((da1['tir'] <= -50) & (da1['tir'] >= -108) ) WA_box = [-10,10,4.5,20] SA_box = [25,33,-28,-10] hov_boxed = da1['tir'].sel(lat=slice(SA_box[2],SA_box[3]), lon=slice(SA_box[0],SA_box[1])).resample(time='m').mean(['lon','time']) out = xr.DataArray(hov_boxed.values, coords={'month':hov_boxed['time.month'].values, 'lat':hov_boxed.lat}, dims=['month', 'lat']) try: hov_box = xr.concat([hov_box, out], 'year') except TypeError: hov_box = out.copy() hov_box.year.values = hov_box.year.values+years[0] hov_box.to_netcdf(msg_folder + 'aggs/SAbox_meanT-50_hov_5000km2.nc')
def _get_Laskar_data(verbose=True):
    longorbit = {}
    sources = {}
    pandas_kwargs = {'delim_whitespace': True,
                     'header': None,
                     'index_col': 0,
                     'names': ['kyear', 'ecc', 'obliquity', 'long_peri'], }
    for time in filenames:
        local_path = os.path.join(os.path.dirname(__file__), "data", filenames[time])
        remote_path = base_url + filenames[time]
        if time == 'future':
            pandas_kwargs['skiprows'] = 1  # first row is kyear=0, redundant
        longorbit[time], path = load_data_source(local_path=local_path,
                                                 remote_source_list=[remote_path],
                                                 open_method=pd.read_csv,
                                                 open_method_kwargs=pandas_kwargs,
                                                 verbose=verbose)
        sources[time] = path
    xlongorbit = {}
    for time in ['past', 'future']:
        # Cannot convert to float until we replace the D notation with E for
        # floating point numbers
        longorbit[time].replace(to_replace='D', value='E', regex=True, inplace=True)
        xlongorbit[time] = xr.Dataset()
        xlongorbit[time]['ecc'] = xr.DataArray(pd.to_numeric(longorbit[time]['ecc']))
        for field in ['obliquity', 'long_peri']:
            xlongorbit[time][field] = xr.DataArray(
                np.rad2deg(pd.to_numeric(longorbit[time][field])))
    longorbit = xr.concat([xlongorbit['past'], xlongorbit['future']], dim='kyear')
    # add 180 degrees to long_peri (see lambda definition, Berger 1978 Appendix)
    longorbit['long_peri'] += 180.
    longorbit['precession'] = longorbit.ecc * np.sin(np.deg2rad(longorbit.long_peri))
    longorbit.attrs['Description'] = 'The Laskar et al. (2004) orbital data table'
    longorbit.attrs['Citation'] = 'https://doi.org/10.1051/0004-6361:20041335'
    longorbit.attrs['Source'] = [sources[time] for time in sources]
    longorbit.attrs['Note'] = ('Longitude of perihelion is defined to be 0 degrees '
                               'at Northern Vernal Equinox. This differs by 180 '
                               'degrees from the source files.')
    return longorbit
def load_nc_file_cell(nc_file, start_year, end_year, lat, lon):
    ''' Loads in nc files for all years.

    Parameters
    ----------
    nc_file: <str>
        netCDF file to load, with {} to be substituted by YYYY
    start_year: <int>
        Start year
    end_year: <int>
        End year
    lat: <float>
        lat of grid cell to extract
    lon: <float>
        lon of grid cell to extract

    Returns
    ----------
    ds_all_years: <xr.Dataset>
        Dataset of all years
    '''

    list_ds = []
    for year in range(start_year, end_year + 1):
        # Load data
        fname = nc_file.format(year)
        ds = xr.open_dataset(fname).sel(lat=lat, lon=lon)
        list_ds.append(ds)
    # Concat all years
    ds_all_years = xr.concat(list_ds, dim='time')

    return ds_all_years
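# Runnable usage sketch with synthetic yearly files written to a temp directory
# (file pattern, years, variable name and coordinates are all made up; a netCDF
# backend such as netCDF4 is assumed to be installed).
import os
import tempfile

import numpy as np
import pandas as pd
import xarray as xr

tmpdir = tempfile.mkdtemp()
for year in (2000, 2001):
    times = pd.date_range(f'{year}-01-01', periods=3)
    ds = xr.Dataset({'prec': (('time', 'lat', 'lon'), np.random.rand(3, 2, 2))},
                    coords={'time': times,
                            'lat': [45.25, 45.75],
                            'lon': [-120.75, -120.25]})
    ds.to_netcdf(os.path.join(tmpdir, f'forcing.{year}.nc'))

cell = load_nc_file_cell(os.path.join(tmpdir, 'forcing.{}.nc'), 2000, 2001,
                         lat=45.25, lon=-120.75)
assert cell.sizes['time'] == 6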
def filter(self): """ Chops session into chunks corresponding to events :return: timeSeriesX object with chopped session """ chop_on_start_offsets_flag = bool(len(self.start_offsets)) if chop_on_start_offsets_flag: start_offsets = self.start_offsets chopping_axis_name = 'start_offsets' chopping_axis_data = start_offsets else: evs = self.events[self.events.eegfile == self.session_data.attrs['dataroot']] start_offsets = evs.eegoffset chopping_axis_name = 'events' chopping_axis_data = evs # samplerate = self.session_data.attrs['samplerate'] samplerate = float(self.session_data['samplerate']) offset_time_array = self.session_data['offsets'] event_chunk_size, start_point_shift = self.get_event_chunk_size_and_start_point_shift( eegoffset=start_offsets[0], samplerate=samplerate, offset_time_array=offset_time_array) event_time_axis = np.arange(event_chunk_size)*(1.0/samplerate)+(self.start_time-self.buffer_time) data_list = [] for i, eegoffset in enumerate(start_offsets): start_chop_pos = np.where(offset_time_array >= eegoffset)[0][0] start_chop_pos += start_point_shift selector_array = np.arange(start=start_chop_pos, stop=start_chop_pos + event_chunk_size) chopped_data_array = self.session_data.isel(time=selector_array) chopped_data_array['time'] = event_time_axis chopped_data_array['start_offsets'] = [i] data_list.append(chopped_data_array) ev_concat_data = xr.concat(data_list, dim='start_offsets') ev_concat_data = ev_concat_data.rename({'start_offsets':chopping_axis_name}) ev_concat_data[chopping_axis_name] = chopping_axis_data attrs = { "start_time": self.start_time, "end_time": self.end_time, "buffer_time": self.buffer_time } ev_concat_data['samplerate'] = samplerate return TimeSeriesX.create(ev_concat_data, samplerate, attrs=attrs)
def test_concat_multiindex(self):
    x = pd.MultiIndex.from_product([[1, 2, 3], ['a', 'b']])
    expected = Dataset({'x': x})
    actual = concat([expected.isel(x=slice(2)),
                     expected.isel(x=slice(2, None))], 'x')
    assert expected.equals(actual)
    assert isinstance(actual.x.to_index(), pd.MultiIndex)
def summary( data, var_names=None, fmt="wide", round_to=2, include_circ=None, stat_funcs=None, extend=True, credible_interval=0.94, batches=None, ): """Create a data frame with summary statistics. Parameters ---------- data : obj Any object that can be converted to an az.InferenceData object Refer to documentation of az.convert_to_dataset for details var_names : list Names of variables to include in summary include_circ : bool Whether to include circular statistics fmt : {'wide', 'long', 'xarray'} Return format is either pandas.DataFrame {'wide', 'long'} or xarray.Dataset {'xarray'}. round_to : int Number of decimals used to round results. Defaults to 2. stat_funcs : None or list A list of functions used to calculate statistics. By default, the mean, standard deviation, simulation standard error, and highest posterior density intervals are included. The functions will be given one argument, the samples for a variable as an nD array, The functions should be in the style of a ufunc and return a single number. For example, `np.sin`, or `scipy.stats.var` would both work. extend : boolean If True, use the statistics returned by `stat_funcs` in addition to, rather than in place of, the default statistics. This is only meaningful when `stat_funcs` is not None. credible_interval : float, optional Credible interval to plot. Defaults to 0.94. This is only meaningful when `stat_funcs` is None. batches : None or int Batch size for calculating standard deviation for non-independent samples. Defaults to the smaller of 100 or the number of samples. This is only meaningful when `stat_funcs` is None. Returns ------- pandas.DataFrame With summary statistics for each variable. Defaults statistics are: `mean`, `sd`, `hpd_3%`, `hpd_97%`, `mc_error`, `eff_n` and `r_hat`. `eff_n` and `r_hat` are only computed for traces with 2 or more chains. Examples -------- .. code:: ipython >>> az.summary(trace, ['mu']) mean sd mc_error hpd_3 hpd_97 eff_n r_hat mu[0] 0.10 0.06 0.00 -0.02 0.23 487.0 1.00 mu[1] -0.04 0.06 0.00 -0.17 0.08 379.0 1.00 Other statistics can be calculated by passing a list of functions. .. code:: ipython >>> import pandas as pd >>> def trace_sd(x): ... return pd.Series(np.std(x, 0), name='sd') ... >>> def trace_quantiles(x): ... return pd.DataFrame(pd.quantiles(x, [5, 50, 95])) ... >>> az.summary(trace, ['mu'], stat_funcs=[trace_sd, trace_quantiles]) sd 5 50 95 mu[0] 0.06 0.00 0.10 0.21 mu[1] 0.07 -0.16 -0.04 0.06 """ posterior = convert_to_dataset(data, group="posterior") posterior = posterior if var_names is None else posterior[var_names] if batches is None: batches = min([100, posterior.draw.size]) fmt_group = ("wide", "long", "xarray") if not isinstance(fmt, str) or (fmt.lower() not in fmt_group): raise TypeError( "Invalid format: '{}'! 
Formatting options are: {}".format( fmt, fmt_group)) alpha = 1 - credible_interval metrics = [] metric_names = [] if stat_funcs is not None: for stat_func in stat_funcs: metrics.append( xr.apply_ufunc(_make_ufunc(stat_func), posterior, input_core_dims=(("chain", "draw")))) metric_names.append(stat_func.__name__) if extend: metrics.append(posterior.mean(dim=("chain", "draw"))) metric_names.append("mean") metrics.append(posterior.std(dim=("chain", "draw"))) metric_names.append("sd") metrics.append( xr.apply_ufunc(_make_ufunc(_mc_error), posterior, input_core_dims=(("chain", "draw"), ))) metric_names.append("mc error") metrics.append( xr.apply_ufunc( _make_ufunc(hpd, index=0, credible_interval=credible_interval), posterior, input_core_dims=(("chain", "draw"), ), )) metric_names.append("hpd {:g}%".format(100 * alpha / 2)) metrics.append( xr.apply_ufunc( _make_ufunc(hpd, index=1, credible_interval=credible_interval), posterior, input_core_dims=(("chain", "draw"), ), )) metric_names.append("hpd {:g}%".format(100 * (1 - alpha / 2))) if include_circ: metrics.append( xr.apply_ufunc( _make_ufunc(st.circmean, high=np.pi, low=-np.pi), posterior, input_core_dims=(("chain", "draw"), ), )) metric_names.append("circular mean") metrics.append( xr.apply_ufunc( _make_ufunc(st.circstd, high=np.pi, low=-np.pi), posterior, input_core_dims=(("chain", "draw"), ), )) metric_names.append("circular standard deviation") metrics.append( xr.apply_ufunc( _make_ufunc(_mc_error, circular=True), posterior, input_core_dims=(("chain", "draw"), ), )) metric_names.append("circular mc error") metrics.append( xr.apply_ufunc( _make_ufunc(hpd, index=0, credible_interval=credible_interval, circular=True), posterior, input_core_dims=(("chain", "draw"), ), )) metric_names.append("circular hpd {:.2%}".format(alpha / 2)) metrics.append( xr.apply_ufunc( _make_ufunc(hpd, index=1, credible_interval=credible_interval, circular=True), posterior, input_core_dims=(("chain", "draw"), ), )) metric_names.append("circular hpd {:.2%}".format(1 - alpha / 2)) if len(posterior.chain) > 1: metrics.append(effective_n(posterior, var_names=var_names)) metric_names.append("eff_n") metrics.append(gelman_rubin(posterior, var_names=var_names)) metric_names.append("r_hat") joined = xr.concat(metrics, dim="metric").assign_coords(metric=metric_names) if fmt.lower() == "wide": dfs = [] for var_name, values in joined.data_vars.items(): if len(values.shape[1:]): metric = list(values.metric.values) data_dict = {} for idx in np.ndindex(values.shape[1:]): ser = pd.Series(values[(Ellipsis, *idx)].values, index=metric) key = "{}[{}]".format(var_name, ",".join(map(str, idx))) data_dict[key] = ser df = pd.DataFrame.from_dict(data_dict, orient="index") else: df = values.to_dataframe() df.index = list(df.index) df = df.T dfs.append(df) summary_df = pd.concat(dfs) elif fmt.lower() == "long": df = joined.to_dataframe().reset_index().set_index("metric") df.index = list(df.index) summary_df = df else: summary_df = joined return summary_df.round(round_to)
def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True): """Concatenate InferenceData objects. Concatenates over `group`, `chain` or `draw`. By default concatenates over unique groups. To concatenate over `chain` or `draw` function needs identical groups and variables. The `variables` in the `data` -group are merged if `dim` are not found. Parameters ---------- *args : InferenceData Variable length InferenceData list or Sequence of InferenceData. dim : str, optional If defined, concatenated over the defined dimension. Dimension which is concatenated. If None, concatenates over unique groups. copy : bool If True, groups are copied to the new InferenceData object. Used only if `dim` is None. inplace : bool If True, merge args to first object. reset_dim : bool Valid only if dim is not None. Returns ------- InferenceData A new InferenceData object by default. When `inplace==True` merge args to first arg and return `None` Examples -------- Use ``concat`` method to concatenate InferenceData objects. This will concatenates over unique groups by default. We first create an ``InferenceData`` object: .. ipython:: In [1]: import arviz as az ...: import numpy as np ...: data = { ...: "a": np.random.normal(size=(4, 100, 3)), ...: "b": np.random.normal(size=(4, 100)), ...: } ...: coords = {"a_dim": ["x", "y", "z"]} ...: dataA = az.from_dict(data, coords=coords, dims={"a": ["a_dim"]}) ...: dataA We have created an ``InferenceData`` object with default group 'posterior'. Now, we will create another ``InferenceData`` object: .. ipython:: In [1]: dataB = az.from_dict(prior=data, coords=coords, dims={"a": ["a_dim"]}) ...: dataB We have created another ``InferenceData`` object with group 'prior'. Now, we will concatenate these two ``InferenceData`` objects: .. ipython:: In [1]: az.concat(dataA, dataB) Now, we will concatenate over chain (or draw). It requires identical groups and variables. Here we are concatenating two identical ``InferenceData`` objects over dimension chain: .. ipython:: In [1]: az.concat(dataA, dataA, dim="chain") It will create an ``InferenceData`` with the original group 'posterior'. In similar way, we can also concatenate over draws. """ # pylint: disable=undefined-loop-variable, too-many-nested-blocks if len(args) == 0: if inplace: return return InferenceData() if len(args) == 1 and isinstance(args[0], Sequence): args = args[0] # assert that all args are InferenceData for i, arg in enumerate(args): if not isinstance(arg, InferenceData): raise TypeError( "Concatenating is supported only" "between InferenceData objects. Input arg {} is {}".format(i, type(arg)) ) if dim is not None and dim.lower() not in {"group", "chain", "draw"}: msg = "Invalid `dim`: {}. Valid `dim` are {}".format(dim, '{"group", "chain", "draw"}') raise TypeError(msg) dim = dim.lower() if dim is not None else dim if len(args) == 1 and isinstance(args[0], InferenceData): if inplace: return None else: if copy: return deepcopy(args[0]) else: return args[0] current_time = str(datetime.now()) if not inplace: # Keep order for python 3.5 inference_data_dict = OrderedDict() if dim is None: arg0 = args[0] arg0_groups = ccopy(arg0._groups) args_groups = dict() # check if groups are independent # Concat over unique groups for arg in args[1:]: for group in arg._groups: if group in args_groups or group in arg0_groups: msg = ( "Concatenating overlapping groups is not supported unless `dim` is defined." ) msg += " Valid dimensions are `chain` and `draw`." 
raise TypeError(msg) group_data = getattr(arg, group) args_groups[group] = deepcopy(group_data) if copy else group_data # add arg0 to args_groups if inplace is False if not inplace: for group in arg0_groups: group_data = getattr(arg0, group) args_groups[group] = deepcopy(group_data) if copy else group_data other_groups = [group for group in args_groups if group not in SUPPORTED_GROUPS] for group in SUPPORTED_GROUPS + other_groups: if group not in args_groups: continue if inplace: arg0._groups.append(group) setattr(arg0, group, args_groups[group]) else: inference_data_dict[group] = args_groups[group] if inplace: other_groups = [ group for group in arg0_groups if group not in SUPPORTED_GROUPS ] + other_groups sorted_groups = [ group for group in SUPPORTED_GROUPS + other_groups if group in arg0._groups ] setattr(arg0, "_groups", sorted_groups) else: arg0 = args[0] arg0_groups = arg0._groups for arg in args[1:]: for group0 in arg0_groups: if group0 not in arg._groups: if group0 == "observed_data": continue msg = "Mismatch between the groups." raise TypeError(msg) for group in arg._groups: if group != "observed_data": # assert that groups are equal if group not in arg0_groups: msg = "Mismatch between the groups." raise TypeError(msg) # assert that variables are equal group_data = getattr(arg, group) group_vars = group_data.data_vars if not inplace and group in inference_data_dict: group0_data = inference_data_dict[group] else: group0_data = getattr(arg0, group) group0_vars = group0_data.data_vars for var in group0_vars: if var not in group_vars: msg = "Mismatch between the variables." raise TypeError(msg) for var in group_vars: if var not in group0_vars: msg = "Mismatch between the variables." raise TypeError(msg) var_dims = getattr(group_data, var).dims var0_dims = getattr(group0_data, var).dims if var_dims != var0_dims: msg = "Mismatch between the dimensions." 
raise TypeError(msg) if dim not in var_dims or dim not in var0_dims: msg = "Dimension {} missing.".format(dim) raise TypeError(msg) # xr.concat concatenated_group = xr.concat((group_data, group0_data), dim=dim) if reset_dim: concatenated_group[dim] = range(concatenated_group[dim].size) # handle attrs if hasattr(group0_data, "attrs"): group0_attrs = deepcopy(getattr(group0_data, "attrs")) else: group0_attrs = OrderedDict() if hasattr(group_data, "attrs"): group_attrs = getattr(group_data, "attrs") else: group_attrs = dict() # gather attrs results to group0_attrs for attr_key, attr_values in group_attrs.items(): group0_attr_values = group0_attrs.get(attr_key, None) equality = attr_values == group0_attr_values if hasattr(equality, "__iter__"): equality = np.all(equality) if equality: continue # handle special cases: if attr_key in ("created_at", "previous_created_at"): # check the defaults if not hasattr(group0_attrs, "previous_created_at"): group0_attrs["previous_created_at"] = [] if group0_attr_values is not None: group0_attrs["previous_created_at"].append(group0_attr_values) # check previous values if attr_key == "previous_created_at": if not isinstance(attr_values, list): attr_values = [attr_values] group0_attrs["previous_created_at"].extend(attr_values) continue # update "created_at" if group0_attr_values != current_time: group0_attrs[attr_key] = current_time group0_attrs["previous_created_at"].append(attr_values) elif attr_key in group0_attrs: combined_key = "combined_{}".format(attr_key) if combined_key not in group0_attrs: group0_attrs[combined_key] = [group0_attr_values] group0_attrs[combined_key].append(attr_values) else: group0_attrs[attr_key] = attr_values # update attrs setattr(concatenated_group, "attrs", group0_attrs) if inplace: setattr(arg0, group, concatenated_group) else: inference_data_dict[group] = concatenated_group else: # observed_data if group not in arg0_groups: setattr(arg0, group, deepcopy(group_data) if copy else group_data) arg0._groups.append(group) continue # assert that variables are equal group_data = getattr(arg, group) group_vars = group_data.data_vars group0_data = getattr(arg0, group) if not inplace: group0_data = deepcopy(group0_data) group0_vars = group0_data.data_vars for var in group_vars: if var not in group0_vars: var_data = getattr(group_data, var) arg0.observed_data[var] = var_data else: var_data = getattr(group_data, var) var0_data = getattr(group0_data, var) if dim in var_data.dims and dim in var0_data.dims: concatenated_var = xr.concat((group_data, group0_data), dim=dim) group0_data[var] = concatenated_var # handle attrs if hasattr(group0_data, "attrs"): group0_attrs = getattr(group0_data, "attrs") else: group0_attrs = OrderedDict() if hasattr(group_data, "attrs"): group_attrs = getattr(group_data, "attrs") else: group_attrs = dict() # gather attrs results to group0_attrs for attr_key, attr_values in group_attrs.items(): group0_attr_values = group0_attrs.get(attr_key, None) equality = attr_values == group0_attr_values if hasattr(equality, "__iter__"): equality = np.all(equality) if equality: continue # handle special cases: if attr_key in ("created_at", "previous_created_at"): # check the defaults if not hasattr(group0_attrs, "previous_created_at"): group0_attrs["previous_created_at"] = [] if group0_attr_values is not None: group0_attrs["previous_created_at"].append(group0_attr_values) # check previous values if attr_key == "previous_created_at": if not isinstance(attr_values, list): attr_values = [attr_values] 
group0_attrs["previous_created_at"].extend(attr_values) continue # update "created_at" if group0_attr_values != current_time: group0_attrs[attr_key] = current_time group0_attrs["previous_created_at"].append(attr_values) elif attr_key in group0_attrs: combined_key = "combined_{}".format(attr_key) if combined_key not in group0_attrs: group0_attrs[combined_key] = [group0_attr_values] group0_attrs[combined_key].append(attr_values) else: group0_attrs[attr_key] = attr_values # update attrs setattr(group0_data, "attrs", group0_attrs) if inplace: setattr(arg0, group, group0_data) else: inference_data_dict[group] = group0_data return None if inplace else InferenceData(**inference_data_dict)
phases = [1, 4, 8]
region_list = []
lags = [-10, -5, 0, 5, 10]
var = 'omega'
for region in regions:
    phaselist = []
    for phase in phases:
        laglist = []
        for lag in lags:
            file = glob(
                datapath +
                f'/mjo_streamfunction/mjohadley_{var}{region}_phase{phase}_lag{lag}.nc')
            ds = xr.open_dataset(
                file[0]).assign_coords(lag=lag).expand_dims('lag')
            laglist.append(ds)
        ds_lags = xr.concat(laglist, dim='lag')
        phaselist.append(
            ds_lags.assign_coords(phase=phase).expand_dims('phase'))
    ds_phase = xr.concat(phaselist, dim='phase')
    region_list.append(
        ds_phase.assign_coords(region=region).expand_dims('region'))
ds = xr.concat(region_list, dim='region')

da = ds['w']

da.sel(region='atl', phase=8).plot(x='latitude',
                                   col='lag',
                                   col_wrap=2,
                                   robust=True)
plt.show()
def __init__(self, Xds, Yds, X_var_dict, Y_var, norm=True, batch_size=32, shuffle=True, mean=None, std=None, load=True, Y_flatten=False, Y_dropna=False): """ defines a Tensorflow Data Generator from xarrays datasets Parameters ---------- Xds : The xarray Dataset with the Feature maps (predictors) if several variables and / or levels, they will be concatenated along a 'level' dimension and transposed to have the 'level' as the last dimension (will correspond to the 'channel' if multiple variables / levels) Yds : xarray dataset The Xarray Dataset with the target variable (instance, lat, lon) X_var_dict : dict Dictionary of the form {'var': level} for building the inputs. Use None for level if data is of single level Y_var : str The name of the variable to extract in Yds norm : bool, optional Whether or not to perform field normalisation of the inputs, by default True batch_size : int, optional The batch size, by default 32 shuffle : bool, optional if True, data is shuffled, by default True mean : xarray dataarray, optional if None, computes the field mean, by default None std : xarray dataarray, optional if None, compute the field std, by default None load : bool, optional if True, dataset is loaded in memory, by default True """ self.Xds = Xds self.Yds = Yds self.X_var_dict = X_var_dict self.norm = norm self.batch_size = batch_size self.shuffle = shuffle # sanity checks and renaming of dimensions rename_dic = { 'time': 'instance', 'latitude': 'lat', 'longitude': 'lon' } for k in rename_dic.keys(): try: self.Xds = self.Xds.rename({k: rename_dic[k]}) self.Yds = self.Yds.rename({k: rename_dic[k]}) except: pass # build X data (the features maps dataset) Xdata = [] generic_level = xr.DataArray([1], coords={'level': [1]}, dims=['level']) for var, levels in X_var_dict.items(): try: Xdata.append(Xds[var].sel(level=levels)) except ValueError: Xdata.append(Xds[var].expand_dims({'level': generic_level}, 1)) # build Ydata (the target dataset) if not 'level' in Yds.dims: Yds = Yds[Y_var].expand_dims({'level': generic_level}, 1) if not 'instance' in Yds.dims: Yds = Tds.rename({'time': 'instance'}) if Y_flatten: Yds = Yds.stack(z=('lat', 'lon')) if Y_dropna: Yds = Yds.dropna(dim='z') self.Ydata = Yds.transpose('instance', 'z', 'level') else: self.Ydata = Yds.transpose('instance', 'lat', 'lon', 'level') self.Xdata = xr.concat(Xdata, 'level').transpose('instance', 'lat', 'lon', 'level') # calculates the mean and std (field mean and std) self.mean = self.Xdata.mean( ('instance', 'lat', 'lon')).compute() if mean is None else mean self.std = self.Xdata.std( ('instance', 'lat', 'lon')).compute() if std is None else std # Normalize if self.norm: self.Xdata = (self.Xdata - self.mean) / self.std # number of instances in the whole dataset self.n_samples = self.Xdata.shape[0] self.on_epoch_end() # loading (optional) if load: print("Loading data into RAM") self.Xdata.load() self.Ydata.load()
def Calculate_Quantiles(data, dimension, quantiles=[0.00, 0.25, 0.75, 1.00]):
    quantile_array = []
    for i, quantile in enumerate(quantiles):
        quantile_array.append(data.quantile(q=quantile, dim=dimension))
    return xr.concat(quantile_array, dim="quantile")
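# Minimal usage sketch (random data, names assumed): the result gains a
# "quantile" dimension with one entry per requested quantile.
import numpy as np
import xarray as xr

da = xr.DataArray(np.random.rand(100, 4), dims=('time', 'site'))
q = Calculate_Quantiles(da, dimension='time')
assert q.sizes['quantile'] == 4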
def _tseries_gen(varname, component, ensemble, entries, cluster_in): """ generate a tseries for a particular ensemble member, return a Dataset object """ print_timestamp(f"varname={varname}") varname_resolved = _varname_resolved(varname, component) fnames = entries.loc[entries["ensemble"] == ensemble].files.tolist() print(fnames) with open(var_specs_fname, mode="r") as fptr: var_specs_all = yaml.safe_load(fptr) if varname in var_specs_all[component]["vars"]: var_spec = var_specs_all[component]["vars"][varname] else: var_spec = {} # use var specific reduce_dims if it exists, otherwise use reduce_dims for component if "reduce_dims" in var_spec: reduce_dims = var_spec["reduce_dims"] else: reduce_dims = var_specs_all[component]["reduce_dims"] # get rank of varname from first file, used to set time chunksize # approximate number of time levels, assuming all files have same number # save time encoding from first file, to restore it in the multi-file case # https://github.com/pydata/xarray/issues/2921 with xr.open_dataset(fnames[0]) as ds0: vardims = ds0[varname_resolved].dims rank = len(vardims) vertlen = ds0.dims[vardims[1]] if rank > 3 else 0 time_chunksize = 10 * 12 if rank < 4 else 6 ds0.chunk(chunks={time_name: time_chunksize}) time_encoding = ds0[time_name].encoding var_encoding = ds0[varname_resolved].encoding ds0_attrs = ds0.attrs ds0_encoding = ds0.encoding drop_var_names_loc = drop_var_names(component, ds0, varname_resolved) # instantiate cluster, if not provided via argument # ignore dashboard warnings when instantiating if cluster_in is None: if "ncar_jobqueue" in sys.modules: with warnings.catch_warnings(): warnings.filterwarnings(action="ignore", module=".*dashboard") cluster = ncar_jobqueue.NCARCluster() else: raise ValueError( "cluster_in not provided and ncar_jobqueue did not load successfully" ) else: cluster = cluster_in workers = 12 if vertlen >= 20: workers *= 2 if vertlen >= 60: workers *= 2 workers = 2 * round(workers / 2) # round to nearest multiple of 2 print_timestamp(f"calling cluster.scale({workers})") cluster.scale(workers) print_timestamp(f"dashboard_link={cluster.dashboard_link}") # create dask distributed client, connecting to workers with dask.distributed.Client(cluster) as client: print_timestamp("client instantiated") # tool to help track down file inconsistencies that trigger errors in open_mfdataset # test_open_mfdataset(fnames, time_chunksize, varname) # data_vars = "minimal", to avoid introducing time dimension to time-invariant fields when there are multiple files # only chunk in time, because if you chunk over spatial dims, then sum results depend on chunksize # https://github.com/pydata/xarray/issues/2902 with xr.open_mfdataset( fnames, data_vars="minimal", coords="minimal", compat="override", combine="by_coords", chunks={time_name: time_chunksize}, drop_variables=drop_var_names_loc, ) as ds_in: print_timestamp("open_mfdataset returned") # restore encoding for time from first file ds_in[time_name].encoding = time_encoding da_in_full = ds_in[varname_resolved] da_in_full.encoding = var_encoding var_units = clean_units(da_in_full.attrs["units"]) if "unit_conv" in var_spec: var_units = f"({var_spec['unit_conv']})({var_units})" # construct averaging/integrating weight weight = get_weight(ds_in, component, reduce_dims) weight_attrs = weight.attrs weight = get_rmask(ds_in, component) * weight weight.attrs = weight_attrs print_timestamp("weight constructed") # compute regional sum of weights da_in_t0 = da_in_full.isel({time_name: 0}).drop(time_name) 
ones_masked_t0 = xr.ones_like(da_in_t0).where(da_in_t0.notnull()) weight_sum = (ones_masked_t0 * weight).sum(dim=reduce_dims) weight_sum.name = f"weight_sum_{varname}" weight_sum.attrs = weight.attrs weight_sum.attrs[ "long_name" ] = f"sum of weights used in tseries generation for {varname}" tlen = da_in_full.sizes[time_name] print_timestamp(f"tlen={tlen}") # use var specific tseries_op if it exists, otherwise use tseries_op for component if "tseries_op" in var_spec: tseries_op = var_spec["tseries_op"] else: tseries_op = var_specs_all[component]["tseries_op"] ds_out_list = [] time_step_nominal = min(2 * workers * time_chunksize, tlen) time_step = math.ceil(tlen / (tlen // time_step_nominal)) print_timestamp(f"time_step={time_step}") for time_ind0 in range(0, tlen, time_step): print_timestamp(f"time_ind={time_ind0}, {time_ind0 + time_step}") da_in = da_in_full.isel( {time_name: slice(time_ind0, time_ind0 + time_step)} ) if tseries_op == "integrate": da_out = (da_in * weight).sum(dim=reduce_dims) da_out.name = varname da_out.attrs["long_name"] = "Integrated " + da_in.attrs["long_name"] da_out.attrs["units"] = cf_units.Unit( f"({weight.attrs['units']})({var_units})" ).format() elif tseries_op == "average": da_out = (da_in * weight).sum(dim=reduce_dims) ones_masked = xr.ones_like(da_in).where(da_in.notnull()) denom = (ones_masked * weight).sum(dim=reduce_dims) da_out /= denom da_out.name = varname da_out.attrs["long_name"] = "Averaged " + da_in.attrs["long_name"] da_out.attrs["units"] = cf_units.Unit(var_units).format() else: msg = f"tseries_op={tseries_op} not implemented" raise NotImplementedError(msg) print_timestamp("da_out computation setup") # propagate some settings from da_in to da_out da_out.encoding["dtype"] = da_in.encoding["dtype"] copy_fill_settings(da_in, da_out) ds_out = da_out.to_dataset() print_timestamp("ds_out generated") # copy particular variables from ds_in copy_var_list = [time_name] if "bounds" in ds_in[time_name].attrs: copy_var_list.append(ds_in[time_name].attrs["bounds"]) copy_var_list.extend(copy_var_names(component)) ds_out = xr.merge( [ ds_out, ds_in[copy_var_list].isel( {time_name: slice(time_ind0, time_ind0 + time_step)} ), ] ) print_timestamp("copy_var_names added") # force computation of ds_out, while resources of client are still available print_timestamp("calling ds_out.load") ds_out_list.append(ds_out.load()) print_timestamp("returned from ds_out.load") print_timestamp("concatenating ds_out_list datasets") ds_out = xr.concat( ds_out_list, dim=time_name, data_vars=[varname], coords="minimal", compat="override", ) # set ds_out.time to mid-interval values ds_out = time_set_mid(ds_out, time_name) print_timestamp("time_set_mid returned") # copy file attributes ds_out.attrs = ds0_attrs datestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z") msg = f"{datestamp}: created by {__file__}" if "history" in ds_out.attrs: ds_out.attrs["history"] = "\n".join([msg, ds_out.attrs["history"]]) else: ds_out.attrs["history"] = msg ds_out.attrs["input_file_list"] = " ".join(fnames) for key in ["unlimited_dims"]: if key in ds0_encoding: ds_out.encoding[key] = ds0_encoding[key] # restore encoding for time from first file ds_out[time_name].encoding = time_encoding # change output units, if specified in var_spec units_key = ( "integral_display_units" if tseries_op == "integrate" else "display_units" ) if units_key in var_spec: ds_out[varname] = conv_units(ds_out[varname], var_spec[units_key]) print_timestamp("units converted") # add regional sum of weights 
ds_out[weight_sum.name] = weight_sum print_timestamp("ds_in and client closed") # if cluster was instantiated here, close it if cluster_in is None: cluster.close() return ds_out
def mergeFiles(config, name, start_date, end_date, execution_days, unbeaching): """ This code merges previous runs. It takes into account the 'extra' date that is repeated for each splitted file :param config: :param start_date: :param end_date: :param execution_days: :return: """ exec_next = True if MPI: if MPI.COMM_WORLD.Get_rank() != 0: exec_next = False # The next code we want that only a single CPU executes it if exec_next: cur_end_date = min(start_date + timedelta(days=execution_days), end_date) part_n = 0 while (cur_end_date < end_date): input_file = get_file_name(name, start_date, cur_end_date, part_n) restart_file = join( config[GlobalModel.output_folder], F"{input_file}{config[GlobalModel.output_file]}") print(F"Reading restart file: {restart_file}") if part_n == 0: merged_data = xr.open_dataset(restart_file) timevar = merged_data['time'].copy() trajectory = merged_data['trajectory'].copy() lat = merged_data['lat'].copy() lon = merged_data['lon'].copy() z = merged_data['z'].copy() if (unbeaching): beached = merged_data['beached'].copy() beached_count = merged_data['beached_count'].copy() else: temp_data = xr.open_dataset(restart_file) timevar = xr.concat([timevar, temp_data['time'][:, 1:]], dim='obs') trajectory = xr.concat( [trajectory, temp_data['trajectory'][:, 1:]], dim='obs') lat = xr.concat([lat, temp_data['lat'][:, 1:]], dim='obs') lon = xr.concat([lon, temp_data['lon'][:, 1:]], dim='obs') z = xr.concat([z, temp_data['z'][:, 1:]], dim='obs') if (unbeaching): beached = xr.concat([beached, temp_data['beached'][:, 1:]], dim='obs') beached_count = xr.concat( [beached_count, temp_data['beached_count'][:, 1:]], dim='obs') start_date = cur_end_date # We need to add one or we will repeat a day cur_end_date = min(start_date + timedelta(days=execution_days), end_date) part_n += 1 print("Done adding this file!") # Last call print("Adding last file and merging all....") input_file = get_file_name(name, start_date, cur_end_date, part_n) restart_file = join(config[GlobalModel.output_folder], F"{input_file}{config[GlobalModel.output_file]}") temp_data = xr.open_dataset(restart_file) # The first location is already saved on the previous file timevar = xr.concat([timevar, temp_data['time'][:, 1:]], dim='obs') trajectory = xr.concat([trajectory, temp_data['trajectory'][:, 1:]], dim='obs') lat = xr.concat([lat, temp_data['lat'][:, 1:]], dim='obs') lon = xr.concat([lon, temp_data['lon'][:, 1:]], dim='obs') z = xr.concat([z, temp_data['z'][:, 1:]], dim='obs') if (unbeaching): beached = xr.concat([beached, temp_data['beached'][:, 1:]], dim='obs') beached_count = xr.concat( [beached_count, temp_data['beached_count'][:, 1:]], dim='obs') # Here we have all the variables merged, we need to create a new Dataset and save it if (unbeaching): ds = xr.Dataset({ "time": (("traj", "obs"), timevar), "trajectory": (("traj", "obs"), trajectory), "lat": (("traj", "obs"), lat), "lon": (("traj", "obs"), lon), "z": (("traj", "obs"), z), "beached": (("traj", "obs"), beached), "beached_count": (("traj", "obs"), beached_count), }) else: ds = xr.Dataset({ "time": (("traj", "obs"), timevar), "trajectory": (("traj", "obs"), trajectory), "lat": (("traj", "obs"), lat), "lon": (("traj", "obs"), lon), "z": (("traj", "obs"), z), }) ds.attrs = temp_data.attrs output_file = join(config[GlobalModel.output_folder], F"{name}{config[GlobalModel.output_file]}") ds.to_netcdf(output_file) print("REAL DONE DONE DONE!")
def regrid_spec(dset, freq=None, dir=None, maintain_m0=True): """Regrid spectra onto new spectral basis. Args: - dset (Dataset, DataArray): Spectra to interpolate. - freq (DataArray, 1darray): Frequencies of interpolated spectra (Hz). - dir (DataArray, 1darray): Directions of interpolated spectra (deg). - maintain_m0 (bool): Ensure variance is conserved in interpolated spectra. Returns: - dsi (Dataset, DataArray): Regridded spectra. Note: - All freq below lowest freq are interpolated assuming :math:`E_d(f=0)=0`. - :math:`Ed(f)` is set to zero for new freq above the highest freq in dset. - Only the 'linear' method is currently supported. - Duplicate wrapped directions (e.g., 0 and 360) are removed when regridding directions because indices must be unique to intepolate. """ dsout = dset.copy() if dir is not None: dsout = dsout.assign_coords({attrs.DIRNAME: dsout[attrs.DIRNAME] % 360}) # Remove any duplicate direction index dsout = unique_indices(dsout, attrs.DIRNAME) # Interpolate heading dsout = dsout.sortby('dir') to_concat = [dsout] # Repeat the first and last direction with 360 deg offset when required if dir.min() < dsout.dir.min(): highest = dsout.isel(dir=-1) highest['dir'] = highest.dir - 360 to_concat = [highest, dsout] if dir.max() > dsout.dir.max(): lowest = dsout.isel(dir=0) lowest['dir'] = lowest.dir + 360 to_concat.append(lowest) if len(to_concat) > 1: dsout = xr.concat(to_concat, dim='dir') # Interpolate directions dsout = dsout.interp(dir=dir, assume_sorted=True) if freq is not None: # If needed, add a new frequency at f=0 with zero energy if freq.min() < dsout.freq.min(): fzero = 0 * dsout.isel(freq=0) fzero['freq'] = 0 dsout = xr.concat([fzero, dsout], dim='freq') # Interpolate frequencies dsout = dsout.interp(freq=freq, assume_sorted=False, kwargs={'fill_value': 0}) if maintain_m0: scale = dset.spec.hs()**2 / dsout.spec.hs()**2 dsout = dsout * scale return dsout
def merge_member(ds_list: list) -> xr.Dataset:
    return xr.concat(ds_list, dim='member', coords='minimal')
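# Usage sketch (synthetic single-member datasets; the "member" labels and the
# variable name are arbitrary): each input carries a scalar "member" coordinate
# promoted to a length-1 dimension, so concatenation yields a 2-member ensemble.
import numpy as np
import xarray as xr

members = []
for m in (1, 2):
    ds = xr.Dataset({'t2m': ('time', np.random.rand(3))},
                    coords={'time': np.arange(3)})
    members.append(ds.assign_coords(member=m).expand_dims('member'))

ens = merge_member(members)
assert ens.sizes['member'] == 2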
@pytest.mark.parametrize("dataset", [dset_a, dset_b, dset_c["Tair"]]) @pytest.mark.parametrize("freq", ["day", "month", "year", "season"]) def test_anomaly_setup(dataset, freq): computed_dset = geocat.comp.anomaly(dataset, freq) assert type(dataset) == type(computed_dset) ds1 = get_fake_dataset(start_month="2000-01", nmonths=12, nlats=1, nlons=1) # Create another dataset for the year 2001. ds2 = get_fake_dataset(start_month="2001-01", nmonths=12, nlats=1, nlons=1) # Create a dataset that combines the two previous datasets, for two # years of data. ds3 = xr.concat([ds1, ds2], dim="time") # Create a dataset with the wrong number of months. partial_year_dataset = get_fake_dataset(start_month="2000-01", nmonths=13, nlats=1, nlons=1) # Create a dataset with a custom time coordinate. custom_time_dataset = get_fake_dataset(start_month="2000-01", nmonths=12, nlats=1, nlons=1) custom_time_dataset = custom_time_dataset.rename({"time": "my_time"}) # Create a more complex dataset just to verify that get_fake_dataset()
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

import eofdata

ds = xr.open_dataset('./data/PV_monthly.nc')
ds.coords['year'] = np.arange(1981, 2017)
# combine the 'year' and 'number' coordinates to categorize events
ds = ds.stack(realiz=['year', 'number']).compute()
print(ds)
# compute seasonal means: ASO and SON
PV_aso = ds.sel(month=slice(8, 10)).mean(dim='month')
PV_son = ds.sel(month=slice(9, 11)).mean(dim='month')
PV_seasonal = xr.concat([PV_aso, PV_son], dim='season')
# select upper and lower quartiles
PV_aso_lower = PV_aso.quantile(0.25, dim='realiz', interpolation='linear')
PV_aso_upper = PV_aso.quantile(0.75, dim='realiz', interpolation='linear')
PV_son_lower = PV_son.quantile(0.25, dim='realiz', interpolation='linear')
PV_son_upper = PV_son.quantile(0.75, dim='realiz', interpolation='linear')
PV_monthly_lower = ds.quantile(0.25, dim='realiz', interpolation='linear')
PV_monthly_upper = ds.quantile(0.75, dim='realiz', interpolation='linear')
ds.reset_index('realiz').to_netcdf('./data/PV_monthly2.nc')
[lamb, v, PC] = eofdata.eofdata(ds.z.values[0:4, :], 3)
print(v[:, 0])
PV_monthly_index = PC[0, :]
def get_GCM_outputs(provider='CDS', GCM='ECMWF', var_name='T2M', period='hindcasts', rpath=None, domain=[90, 300, -65, 50], step=None, verbose=False, flatten=True, nmembers=None): """ Get the GCM outputs Parameters ---------- - provider: in ['CDS','IRI','JMA'] - GCM: name of the GCM - var_name: in ['T2M', 'PRECIP'] - period: in ['hindcasts', 'forecasts'] - rpath (root path, pathlib.Path object, see `set_root_dir` in the utils module) - domain [lon_min, lon_max, lat_min, lat_max] - step ( in [3, 4, 5] ) - verbose: Boolean, whether to print names of files successfully opened - flatten: Boolean, whether of not to stack the dataset over the spatial (+ member if present) dimension to get 2D fields Return ------ - dset: xarray.Dataset concatenated along the time dimension """ import pathlib import xarray as xr ipath = rpath / 'GCMs' / 'processed' / period / provider / GCM / var_name lfiles_gcm = list( ipath.glob(f"{GCM}_{var_name}_seasonal_anomalies_interp_????_??.nc")) if (period == 'hindcasts') and (len(lfiles_gcm)) < 200: print( f"Something wrong with the number of files in the list for the {period} period, the length is {len(lfiles_gcm)}" ) if (period == 'forecasts') and (len(lfiles_gcm)) < 20: print( f"Something wrong with the number of files in the list for the {period} period, the length is {len(lfiles_gcm)}" ) lfiles_gcm.sort() print(f"first file is {str(lfiles_gcm[0])}") print(f"last file is {str(lfiles_gcm[-1])}") dset_l = [] for fname in lfiles_gcm: dset = xr.open_dataset(fname)[[var_name.lower()]] if 'surface' in dset.dims: dset = dset.drop('surface') # select the domain if domain is not None: dset = dset.sel(lon=slice(domain[0], domain[1]), lat=slice(domain[2], domain[3])) if step is not None: dset = dset.sel(step=step) if verbose: print(f"successfully opened and extracted {fname}") dset_l.append(dset) dset = xr.concat(dset_l, dim='time', coords='minimal', compat='override') # now get the coordinates, will be returned along with the dataset itself, # regarding of whether the dataset is flattened #dims_tuple = (dset.dims, dset[var_name.lower()].dims) if nmembers is not None: dset = dset.isel(member=slice(0, nmembers)) coords = dset.coords if flatten: if 'member' in dset.dims: dset = dset.stack(z=('member', 'lat', 'lon')) else: dset = dset.stack(z=('lat', 'lon')) return dset, coords
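# Sketch of the concat call used in get_GCM_outputs, assuming all monthly files
# share the same lat/lon grid: coords='minimal' plus compat='override' takes the
# non-concatenated coordinates from the first dataset instead of comparing them
# file by file. The file contents here are synthetic.
import numpy as np
import xarray as xr

monthly = [xr.Dataset({"t2m": (("time", "lat", "lon"), np.random.rand(1, 2, 3))},
                      coords={"time": [t], "lat": [0.0, 1.0],
                              "lon": [0.0, 1.0, 2.0]}) for t in range(3)]
dset = xr.concat(monthly, dim="time", coords="minimal", compat="override")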
def run_job(metadata, transformation_name, variable, transformation, rcp, pername, years, model, read_acct, baseline_model, seasons, unit, agglev, aggwt, weights=None): logger.debug('Beginning job\nkwargs:\t{}'.format( pprint.pformat(metadata, indent=2))) # Add to job metadata metadata.update(dict(time_horizon='{}-{}'.format(years[0], years[-1]))) baseline_file = BASELINE_FILE.format(**metadata) pattern_file = BCSD_pattern_files.format(**metadata) write_file = WRITE_PATH.format(**metadata) # do not duplicate if os.path.isfile(write_file): return del metadata['read_acct'] # Get transformed data total = None seasonal_baselines = {} for season in seasons: basef = baseline_file.format(season=season) logger.debug('attempting to load baseline file: {}'.format(basef)) seasonal_baselines[season] = load_baseline(basef, variable) season_month_start = {'DJF': 12, 'MAM': 3, 'JJA': 6, 'SON': 9} for year in years: seasonal = [] for s, season in enumerate(seasons): pattf = pattern_file.format(year=year, season=season) logger.debug('attempting to load pattern file: {}'.format(pattf)) patt = load_bcsd(pattf, variable, broadcast_dims=('day', )) logger.debug('{} {} {} - reindexing coords day --> time'.format( model, year, season)) patt = (patt.assign_coords( time=xr.DataArray(pd.period_range('{}-{}-1'.format( year - int(season == 'DJF'), season_month_start[season]), periods=len(patt.day), freq='D'), coords={'day': patt.day})).swap_dims({ 'day': 'time' }).drop('day')) logger.debug( '{} {} {} - adding pattern residuals to baseline'.format( model, year, season)) seasonal.append(patt + seasonal_baselines[season]) logger.debug(('{} {} - concatenating seasonal data and ' + 'applying transform').format(model, year)) annual = xr.Dataset( {variable: xr.concat(seasonal, dim='time').pipe(transformation)}) if total is None: total = annual else: total += annual ds = total / len(years) # Reshape to regions logger.debug('{} reshaping to regions'.format(model)) if not agglev.startswith('grid'): ds = weighted_aggregate_grid_to_regions(ds, variable, aggwt, agglev, weights=weights) # Update netCDF metadata logger.debug('{} udpate metadata'.format(model)) ds.attrs.update(**metadata) # Write output logger.debug('attempting to write to file: {}'.format(write_file)) if not os.path.isdir(os.path.dirname(write_file)): os.makedirs(os.path.dirname(write_file)) ds.to_netcdf(write_file)
def _compute_horizontal_transport_mpas(ds, dsMesh, outFileName): ''' compute the horizontal transport through edges on the native MPAS grid. ''' if file_complete(ds, outFileName): return nVertLevels = dsMesh.sizes['nVertLevels'] cellsOnEdge = dsMesh.cellsOnEdge - 1 maxLevelCell = dsMesh.maxLevelCell - 1 cell0 = cellsOnEdge[:, 0] cell1 = cellsOnEdge[:, 1] internalEdgeIndices = xarray.DataArray(numpy.nonzero( numpy.logical_and(cell0.values >= 0, cell1.values >= 0))[0], dims=('nInternalEdges', )) cell0 = cell0[internalEdgeIndices] cell1 = cell1[internalEdgeIndices] bottomDepth = dsMesh.bottomDepth maxLevelEdgeTop = maxLevelCell[cell0] mask = numpy.logical_or(cell0 == -1, maxLevelCell[cell1] < maxLevelEdgeTop) maxLevelEdgeTop[mask] = maxLevelCell[cell1][mask] nVertLevels = dsMesh.sizes['nVertLevels'] vertIndex = \ xarray.DataArray.from_dict({'dims': ('nVertLevels',), 'data': numpy.arange(nVertLevels)}) ds = ds.chunk({'Time': 1}) chunks = {'nInternalEdges': 1024} maxLevelEdgeTop = maxLevelEdgeTop.chunk(chunks) dvEdge = dsMesh.dvEdge[internalEdgeIndices].chunk(chunks) bottomDepthEdge = 0.5 * (bottomDepth[cell0] + bottomDepth[cell1]).chunk(chunks) chunks = {'Time': 1, 'nInternalEdges': 1024} normalVelocity = ds.timeMonthly_avg_normalVelocity.isel( nEdges=internalEdgeIndices).chunk(chunks) layerThickness = ds.timeMonthly_avg_layerThickness.chunk() layerThicknessEdge = 0.5 * (layerThickness.isel( nCells=cell0) + layerThickness.isel(nCells=cell1)).chunk(chunks) layerThicknessEdge = layerThicknessEdge.where(vertIndex <= maxLevelEdgeTop, other=0.) thicknessSum = layerThicknessEdge.sum(dim='nVertLevels') thicknessCumSum = layerThicknessEdge.cumsum(dim='nVertLevels') zSurface = thicknessSum - bottomDepthEdge zInterfaceEdge = -thicknessCumSum + zSurface zInterfaceEdge = xarray.concat([ zSurface.expand_dims(dim='nVertLevelsP1', axis=2), zInterfaceEdge.rename({'nVertLevels': 'nVertLevelsP1'}) ], dim='nVertLevelsP1') transportPerDepth = dvEdge * normalVelocity dsOut = xarray.Dataset() dsOut['xtime_startMonthly'] = ds.xtime_startMonthly dsOut['xtime_endMonthly'] = ds.xtime_endMonthly dsOut['zInterfaceEdge'] = zInterfaceEdge dsOut['layerThicknessEdge'] = layerThicknessEdge dsOut['transportPerDepth'] = transportPerDepth dsOut['transportVertSum'] = \ (transportPerDepth*layerThicknessEdge).sum(dim='nVertLevels') dsOut = dsOut.transpose('Time', 'nInternalEdges', 'nVertLevels', 'nVertLevelsP1') print('compute and caching transport on MPAS grid:') write_netcdf(dsOut, outFileName, progress=True)
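# Sketch of the interface-depth construction above, with synthetic thicknesses:
# a surface slice is prepended (via expand_dims and concat) to the cumulative
# layer thickness so the result has one more vertical level than the input.
import numpy as np
import xarray as xr

layerThickness = xr.DataArray(np.full((4, 3), 10.0),
                              dims=("nCells", "nVertLevels"))
zSurface = xr.zeros_like(layerThickness.isel(nVertLevels=0, drop=True))
zInterface = zSurface - layerThickness.cumsum(dim="nVertLevels").rename(
    {"nVertLevels": "nVertLevelsP1"})
zInterface = xr.concat([zSurface.expand_dims(dim="nVertLevelsP1", axis=1),
                        zInterface], dim="nVertLevelsP1")  # shape (4, 4)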
def dem_(source=None, lon_min=-180, lon_max=180, lat_min=-90, lat_max=90, **kwargs): ncores = kwargs.get('ncores', NCORES) xr_kwargs = kwargs.get('dem_xr_kwargs', {}) #--------------------------------------------------------------------- logger.info('extracting dem from {}\n'.format(source)) #--------------------------------------------------------------------- data = xr.open_dataset(source, **xr_kwargs) #rename vars,coords var = [keys for keys in data.data_vars] coords = [keys for keys in data.coords] lat = [x for x in coords if 'lat' in x] lon = [x for x in coords if 'lon' in x] data = data.rename({ var[0]: 'elevation', lat[0]: 'latitude', lon[0]: 'longitude' }) #recenter the window lon0 = lon_min + 360. if lon_min < -180 else lon_min lon1 = lon_max + 360. if lon_max < -180 else lon_max lon0 = lon0 - 360. if lon0 > 180 else lon0 lon1 = lon1 - 360. if lon1 > 180 else lon1 # TODO check this for regional files if (lon_min < data.longitude.min()) or (lon_max > data.longitude.max()): logger.warning('Lon must be within {} and {}'.format( data.longitude.min().values, data.longitude.max().values)) logger.warning('compensating if global dataset available') # sys.exit() if (lat_min < data.latitude.min()) or (lat_max > data.latitude.max()): logger.warning('Lat must be within {} and {}'.format( data.latitude.min().values, data.latitude.max().values)) logger.warning('compensating if global dataset available') # sys.exit() #get idx i0 = np.abs(data.longitude.data - lon0).argmin() i1 = np.abs(data.longitude.data - lon1).argmin() j0 = np.abs(data.latitude.data - lat_min).argmin() j1 = np.abs(data.latitude.data - lat_max).argmin() # expand the window a little bit lon_0 = max(0, i0 - 2) lon_1 = min(data.longitude.size, i1 + 2) lat_0 = max(0, j0 - 2) lat_1 = min(data.latitude.size, j1 + 2) # descenting lats if j0 > j1: j0, j1 = j1, j0 lat_0 = max(0, j0 - 1) lat_1 = min(data.latitude.size, j1 + 3) if i0 > i1: p1 = (data.elevation.isel(longitude=slice(lon_0, data.longitude.size), latitude=slice(lat_0, lat_1))) p1 = p1.assign_coords({'longitude': p1.longitude.values - 360.}) p2 = (data.elevation.isel(longitude=slice(0, lon_1), latitude=slice(lat_0, lat_1))) dem = xr.concat([p1, p2], dim='longitude') else: dem = (data.elevation.isel(longitude=slice(lon_0, lon_1), latitude=slice(lat_0, lat_1))) if np.abs(np.mean(dem.longitude) - np.mean([lon_min, lon_max])) > 170.: c = np.sign(np.mean([lon_min, lon_max])) dem['longitude'] = dem['longitude'] + c * 360. if 'grid_x' in kwargs.keys(): #--------------------------------------------------------------------- logger.info('.. interpolating on grid ..\n') #--------------------------------------------------------------------- grid_x = kwargs.get('grid_x', None) grid_y = kwargs.get('grid_y', None) # resample on the given grid xx, yy = np.meshgrid(dem.longitude, dem.latitude) #original grid # Translate for pyresample if xx.mean() < 0 and xx.min() < -180.: xx = xx + 180. gx = grid_x + 180. elif xx.mean() > 0 and xx.max() > 180.: xx = xx - 180. gx = grid_x - 180. 
else: gx = grid_x orig = pyresample.geometry.SwathDefinition(lons=xx, lats=yy) # original points targ = pyresample.geometry.SwathDefinition(lons=gx, lats=grid_y) # target grid wet = kwargs.get('wet_only', False) if wet: #mask positive bathymetry vals = np.ma.masked_array(dem, dem.values > 0) else: vals = dem.values # with nearest using only the water values itopo = pyresample.kd_tree.resample_nearest( orig, dem.values, targ, radius_of_influence=100000, fill_value=np.nan) #,nprocs=ncores) if len(grid_x.shape) > 1: idem = xr.Dataset({ 'ival': (['k', 'l'], itopo), 'ilons': (['k', 'l'], grid_x), 'ilats': (['k', 'l'], grid_y) }) #, # coords={'ilon': ('ilon', grid_x[0,:]), # 'ilat': ('ilat', grid_y[:,0])}) elif len(grid_x.shape) == 1: idem = xr.Dataset({ 'ival': (['k'], itopo), 'ilons': (['k'], grid_x), 'ilats': (['k'], grid_y) }) #--------------------------------------------------------------------- logger.info('dem done\n') #--------------------------------------------------------------------- return xr.merge([dem, idem]) else: return xr.merge([dem])
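# Minimal sketch of the seam-crossing branch in dem_ above: when the requested
# window wraps past the end of the longitude axis, the two pieces are
# concatenated after shifting the western piece by -360 degrees. The elevation
# values here are synthetic.
import numpy as np
import xarray as xr

elev = xr.DataArray(np.random.rand(3, 360),
                    coords={"latitude": [-1.0, 0.0, 1.0],
                            "longitude": np.arange(0.0, 360.0)},
                    dims=("latitude", "longitude"))
west = elev.sel(longitude=slice(340, 359))
west = west.assign_coords(longitude=west.longitude - 360.0)
east = elev.sel(longitude=slice(0, 20))
window = xr.concat([west, east], dim="longitude")  # longitudes -20 .. 20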
def pixel_drill(task_id=None): parameters = parse_parameters_from_task(task_id=task_id) validate_parameters(parameters, task_id=task_id) task = FractionalCoverTask.objects.get(pk=task_id) if task.status == "ERROR": return None dc = DataAccessApi(config=task.config_path) single_pixel = dc.get_stacked_datasets_by_extent(**parameters) clear_mask = task.satellite.get_clean_mask_func()(single_pixel.isel( latitude=0, longitude=0)) single_pixel = single_pixel.where( single_pixel != task.satellite.no_data_value) dates = single_pixel.time.values if len(dates) < 2: task.update_status( "ERROR", "There is only a single acquisition for your parameter set.") return None def _apply_band_math(ds, idx): # mask out water manually. Necessary for frac. cover. wofs = wofs_classify(ds, clean_mask=clear_mask[idx], mosaic=True) clear_mask[ idx] = False if wofs.wofs.values[0] == 1 else clear_mask[idx] fractional_cover = frac_coverage_classify( ds, clean_mask=clear_mask[idx], no_data=task.satellite.no_data_value) return fractional_cover fractional_cover = xr.concat([ _apply_band_math(single_pixel.isel(time=data_point, drop=True), data_point) for data_point in range(len(dates)) ], dim='time') fractional_cover = fractional_cover.where( fractional_cover != task.satellite.no_data_value).isel(latitude=0, longitude=0) exclusion_list = [] plot_measurements = [ band for band in fractional_cover.data_vars if band not in exclusion_list ] datasets = [ fractional_cover[band].values.transpose() for band in plot_measurements ] + [clear_mask] data_labels = [ stringcase.titlecase("%{}".format(band)) for band in plot_measurements ] + ["Clear"] titles = [ 'Bare Soil Percentage', 'Photosynthetic Vegetation Percentage', 'Non-Photosynthetic Vegetation Percentage', 'Clear Mask' ] style = ['r-o', 'g-o', 'b-o', '.'] task.plot_path = os.path.join(task.get_result_path(), "plot_path.png") create_2d_plot(task.plot_path, dates=dates, datasets=datasets, data_labels=data_labels, titles=titles, style=style) task.complete = True task.update_status("OK", "Done processing pixel drill.")
def fetch(self, grouped: VirtualDatasetBox, **load_settings: Dict[str, Any]) -> xarray.Dataset: def is_from(source_index): def result(_, value): self._assert('collate' in value, "malformed dataset box in collate") return value['collate'][0] == source_index return result def strip_source(_, value): return value['collate'][1] def fetch_child(child, source_index, r): if any([x == 0 for x in r.box.shape]): # empty raster return None else: result = child.fetch(r, **load_settings) name = self.get('index_measurement_name') if name is None: return result # implication for dask? measurement = Measurement(name=name, dtype='int8', nodata=-1, units='1') shape = select_unique( [result[band].shape for band in result.data_vars]) array = numpy.full(shape, source_index, dtype=measurement.dtype) first = result[list(result.data_vars)[0]] result[name] = xarray.DataArray(array, dims=first.dims, coords=first.coords, name=name).assign_attrs( units=measurement.units, nodata=measurement.nodata) return result groups = [ fetch_child( child, source_index, grouped.filter(is_from(source_index)).map(strip_source)) for source_index, child in enumerate(self._children) ] non_empty = [g for g in groups if g is not None] dim = self.get('dim', 'time') result = xarray.concat(non_empty, dim=dim).sortby(dim).assign_attrs( **select_unique([g.attrs for g in non_empty])) # concat and sortby mess up chunking if 'dask_chunks' not in load_settings or dim not in load_settings[ 'dask_chunks']: return result return result.apply( lambda x: x.chunk({dim: load_settings['dask_chunks'][dim]}), keep_attrs=True)
"targets") temp.attrs["sources"], temp.attrs["targets"] = np.triu_indices( temp.shape[0], 1) for f in range(temp.shape[-2]): for t in range(temp.shape[-1]): np.fill_diagonal(temp[..., f, t].values, 1) MC += [temp] # From matrix to stream representation MCg = [] for mat in tqdm(MC): out = convert_to_stream(mat, mat.attrs["sources"], mat.attrs["targets"]) MCg += [out] # Average over repeated meta-edges across sessions MC_avg = xr.concat(MCg, dim="roi").groupby("roi").mean("roi") # Convert back to matrix x_s, x_t = _extract_roi(MC_avg.roi.values, "~") # Get rois unique_rois = np.unique(np.stack((x_s, x_t))) index_rois = dict(zip(unique_rois, range(len(unique_rois)))) temp = np.zeros((len(unique_rois), len(unique_rois), 10, 5)) for i, (s, t) in enumerate(zip(x_s, x_t)): i_s, i_t = index_rois[s], index_rois[t] temp[i_s, i_t, ...] = temp[i_t, i_s, ...] = MC_avg[i] temp = xr.DataArray( temp, dims=("sources", "targets", "freqs", "times"),
### Lessons learned: groupby('time.hour') requires a time coordinate with
### regular values; check it first:
print(T2_WRF_May.coords['time'][0:24])
print(T2_WRF_JJA.coords['time'][0:24])
### The time values are regular here, so groupby can be used directly
WRF_May = T2_WRF_May.groupby('time.hour').mean()
WRF_JJA = T2_WRF_JJA.groupby('time.hour').mean()

### ARM SGP obs: ARMBE2DGRID from Qi Tang
temp05 = ds_ARMBE2D_05['temp']
temp06 = ds_ARMBE2D_06['temp']
temp07 = ds_ARMBE2D_07['temp']
temp08 = ds_ARMBE2D_08['temp']
temp_0678 = xr.concat([temp06, temp07, temp08], dim='time')

temp_May = temp05.sel(lat=slice(lat_1, lat_2),
                      lon=slice(lon_1, lon_2)).mean(dim='lat').mean(dim='lon')
temp_JJA = temp_0678.sel(lat=slice(lat_1, lat_2),
                         lon=slice(lon_1, lon_2)).mean(dim='lat').mean(dim='lon')

### The time coordinate here is irregular, so groupby('time.hour').mean() cannot
### be used for the diurnal cycle; loop over the hours instead
print(temp_JJA['time'][0:24])
ARM_May = np.zeros(24)
ARM_JJA = np.zeros(24)
for i in np.arange(0, 31, 1):
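# Hedged illustration of the notes above, with synthetic data: a regular hourly
# time axis lets groupby('time.hour') compute the diurnal cycle in one call,
# which is exactly what fails when the time values are irregular.
import numpy as np
import pandas as pd
import xarray as xr

time = pd.date_range("2000-06-01", periods=24 * 30, freq="h")
t2 = xr.DataArray(np.random.rand(time.size), coords={"time": time}, dims="time")
diurnal = t2.groupby("time.hour").mean()   # 24 values, one per hour of the day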
def load_data(sc='mms1', instr='fgm', mode='srvy', level='l2', optdesc=None, start_date=None, end_date=None, offline=False, record_dim='Epoch', team_site=False, data_type='science', **kwargs): """ Load MMS data. Empty files are silently skipped. NoVariablesInFileError is raised only if all files in time interval are empty. Parameters ---------- sc : str Spacecraft ID: ('mms1', 'mms2', 'mms3', 'mms4') instr : str Instrument ID mode : str Instrument mode: ('slow', 'fast', 'srvy', 'brst'). optdesc : str Optional descriptor for dataset start_date, end_date : `datetime.datetime` Start and end of the data interval. offline : bool If True, search only for local files record_dim : str Name of the record varying dimension. This is the dimension along which the data from different files will be concatenated. If *None*, the name of the leading dimension of the first data variable will be used. team_site : bool If True, search the password-protected team site data_type : str Type of data to download. ('science', 'hk', 'ancillary') \*\*kwargs : dict Keywords passed to *cdf_to_ds* Returns ------- data : `xarray.DataArray` or list The requested data. If data from all files can be concatenated successfully, a Dataset is returned. If not, a list of Datasets is returned, where each dataset is the data from a single file. """ if start_date is None: start_date = np.datetime64('2015-10-16T13:06:04') if end_date is None: end_date = np.datetime64('2015-10-16T13:07:20') site = 'public' if team_site: site = 'private' # Download the data sdc = api.MrMMS_SDC_API(sc, instr, mode, level, optdesc=optdesc, start_date=start_date, end_date=end_date, data_type=data_type, offline=offline) # The data level parameter will automatically set the site keyword. # If the user specifies the site, set it after instantiation. sdc.site = site files = sdc.download_files() try: files = api.sort_files(files)[0] except IndexError: raise IndexError('No files found: {0}'.format(sdc)) # Read all of the data files. Skip empty files unless all files are empty data = [] for file in files: try: data.append(cdf_to_ds(file, **kwargs)) except NoVariablesInFileError: pass if len(data) == 0: raise NoVariablesInFileError('All {0} files were empty.'.format( len(files))) # Determine the name of the record varying dimension. This should be the # value of the DEPEND_0 attribute of a data variable. if record_dim is None: varnames = [name for name in data[0].data_vars] rec_vname = data[0][varnames[0]].dims[0] else: rec_vname = record_dim # Notes: # 1. Concatenation can fail if, e.g., a variable does not have a # coordinate assigned along a given dimension. Instead of crashing, # return the list of datasets so that they can be corrected and # concatenated externally. # # 2. If data variables in the dataset do not have the dimension # identified by rec_vname, a new dimension is added. If the dataset is # large, this can cause xarray/python to use all available ram and # crash. A fix would be to 1) find all DEPEND_0 variables, 2) use the # data_vars='minimal' option to concat for each one, 3) combine the # resulting datasets back together. # # 3. If there is only one dataset in the list and that dataset is empty # then xr.concat will return the dataset even if the dim=rec_vname is # not present. try: data = xr.concat(data, dim=rec_vname) except (MemoryError, Exception) as E: return data # cdf_to_df loads all of the data from the file. 
    # Now we need to trim the data to the time interval of interest
    try:
        data = data.sel(indexers={rec_vname: slice(start_date, end_date)})
    except KeyError:
        warnings.warn('{0} is unordered; cannot slice.'.format(rec_vname))

    # Keep information about the data
    data.attrs['sc'] = sc
    data.attrs['instr'] = instr
    data.attrs['mode'] = mode
    data.attrs['level'] = level
    data.attrs['optdesc'] = optdesc
    data.attrs['files'] = files

    return data
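# Minimal sketch of the "concatenate if possible, otherwise hand back the
# pieces" behaviour described in the notes of load_data above; `concat_or_list`
# is a hypothetical helper, not part of the module.
import xarray as xr


def concat_or_list(datasets, dim):
    try:
        return xr.concat(datasets, dim=dim)
    except Exception:
        # e.g. a variable lacks a coordinate along `dim`; let the caller
        # repair the individual datasets and retry.
        return datasets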
def concat_updated_states(i, EnKF_result_basedir, meas_times, output_concat_states_dir, global_template): ''' A wrap function that concatenates EnKF updated SM states. Parameters ---------- i: <int> Index of ensemble member (starting from 0) EnKF_result_basedir: <str> EnKF output result base directory meas_times: <list or pandas.tseries.index.DatetimeIndex> Measurement time points; the same as state updating time points output_concat_states_dir: <str> Directory for outputing concatenated SM state files global_template: <str> VIC global file path Returns ---------- da_state_all_times: <xr.DataArray> Concatenated SM states for this ensemble member ''' # --- Load states at measurement times --- # list_da_state = [] list_da_swe = [] for t in meas_times: state_nc = os.path.join( EnKF_result_basedir, 'states', 'updated.{}_{:05d}'.format(t.strftime('%Y%m%d'), t.hour * 3600 + t.second), 'state.ens{}.nc'.format(i + 1)) da_state = xr.open_dataset(state_nc)['STATE_SOIL_MOISTURE'] da_swe = xr.open_dataset(state_nc)['STATE_SNOW_WATER_EQUIVALENT'] list_da_state.append(da_state) list_da_swe.append(da_swe) # --- Concatenate states of all time together --- # da_state_all_times = xr.concat(list_da_state, dim='time') da_state_all_times['time'] = meas_times da_swe_all_times = xr.concat(list_da_swe, dim='time') da_swe_all_times['time'] = meas_times # # --- Save concatenated states to netCDF file --- # # ds_state_all_times = xr.Dataset( # {'STATE_SOIL_MOISTURE': da_state_all_times, # 'STATE_SNOW_WATER_EQUIVALENT': da_swe_all_times}) # out_nc = os.path.join( # output_concat_states_dir, # 'updated_state.{}_{}.ens{}.nc'.format( # meas_times[0].strftime('%Y%m%d'), # meas_times[-1].strftime('%Y%m%d'), # i+1)) # to_netcdf_state_file_compress( # ds_state_all_times, out_nc) # Calculate and save cell-average states to netCDF file da_tile_frac = determine_tile_frac(global_template) da_state_cellAvg = (da_state_all_times * da_tile_frac).sum( dim='veg_class').sum(dim='snow_band') # [time, nlayer, lat, lon] da_swe_cellAvg = (da_swe_all_times * da_tile_frac).sum( dim='veg_class').sum(dim='snow_band') # [time, lat, lon] ds_state_cellAvg = xr.Dataset({ 'SOIL_MOISTURE': da_state_cellAvg, 'SWE': da_swe_cellAvg }) out_nc = os.path.join( output_concat_states_dir, 'updated_state_cellAvg.{}_{}.ens{}.nc'.format( meas_times[0].strftime('%Y%m%d'), meas_times[-1].strftime('%Y%m%d'), i + 1)) to_netcdf_state_file_compress(ds_state_cellAvg, out_nc)
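# Sketch of the state-stacking step above with placeholder values: per-time
# state fields are concatenated along a new 'time' dimension and then labelled
# with the measurement times.
import numpy as np
import pandas as pd
import xarray as xr

meas_times = pd.date_range("2000-01-01", periods=3, freq="D")
states = [xr.DataArray(np.random.rand(2, 2), dims=("lat", "lon"))
          for _ in meas_times]
da_states = xr.concat(states, dim="time")
da_states["time"] = meas_times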
def ensemble_common_handler(process: Process, request, response, subset_function): assert subset_function in [ finch_subset_bbox, finch_subset_gridpoint, finch_subset_shape, ] xci_inputs = process.xci_inputs_identifiers request_inputs_not_datasets = { k: v for k, v in request.inputs.items() if k in xci_inputs and not k.startswith('perc') } percentile_vars = { f"{k.split('_')[1]}_per" for k in xci_inputs if k.startswith('perc') } dataset_input_names = accepted_variables.intersection( percentile_vars.union(xci_inputs)) source_variable_names = bccaq_variables.intersection( get_sub_inputs(dataset_input_names)) convert_to_csv = request.inputs["output_format"][0].data == "csv" if not convert_to_csv: del process.status_percentage_steps["convert_to_csv"] percentiles_string = request.inputs["ensemble_percentiles"][0].data ensemble_percentiles = [ int(p.strip()) for p in percentiles_string.split(",") ] rcps = [r.data.strip() for r in request.inputs["rcp"]] write_log(process, f"Processing started ({len(rcps)} rcps)", process_step="start") models = [m.data.strip() for m in request.inputs["models"]] dataset_name = single_input_or_none(request.inputs, "dataset") if single_input_or_none(request.inputs, "average"): if subset_function == finch_subset_gridpoint: average_dims = ("region", ) else: average_dims = ("lat", "lon") else: average_dims = None write_log(process, f"Will average over {average_dims}") base_work_dir = Path(process.workdir) ensembles = [] for rcp in rcps: # Ensure no file name conflicts (i.e. if the rcp doesn't appear in the base filename) work_dir = base_work_dir / rcp work_dir.mkdir(exist_ok=True) process.set_workdir(str(work_dir)) write_log(process, f"Fetching datasets for rcp={rcp}") output_filename = make_output_filename(process, request.inputs, rcp=rcp) netcdf_inputs = get_datasets( dataset_name, workdir=process.workdir, variables=list(source_variable_names), rcp=rcp, models=models, ) write_log(process, f"Running subset rcp={rcp}", process_step="subset") subsetted_files = subset_function(process, netcdf_inputs=netcdf_inputs, request_inputs=request.inputs) if not subsetted_files: message = "No data was produced when subsetting using the provided bounds." 
raise ProcessError(message) subsetted_intermediate_files = compute_intermediate_variables( subsetted_files, dataset_input_names, process.workdir, request.inputs, ) write_log(process, f"Computing indices rcp={rcp}", process_step="compute_indices") input_groups = make_indicator_inputs(process.xci, request_inputs_not_datasets, subsetted_intermediate_files) n_groups = len(input_groups) indices_files = [] warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning) for n, inputs in enumerate(input_groups): write_log( process, f"Computing indices for file {n + 1} of {n_groups}, rcp={rcp}", subtask_percentage=n * 100 // n_groups, ) output_ds = compute_indices(process, process.xci, inputs) output_name = f"{output_filename}_{process.identifier}_{n}.nc" for variable in accepted_variables: if variable in inputs: input_name = Path(inputs.get(variable)[0].file).name output_name = input_name.replace(variable, process.identifier) output_path = Path(process.workdir) / output_name dataset_to_netcdf(output_ds, output_path) indices_files.append(output_path) warnings.filterwarnings("default", category=FutureWarning) warnings.filterwarnings("default", category=UserWarning) output_basename = Path( process.workdir) / (output_filename + "_ensemble") ensemble = make_ensemble(indices_files, ensemble_percentiles, average_dims) ensemble.attrs['source_datasets'] = '\n'.join( [dsinp.url for dsinp in netcdf_inputs]) ensembles.append(ensemble) process.set_workdir(str(base_work_dir)) ensemble = xr.concat(ensembles, dim=xr.DataArray(rcps, dims=('rcp', ), name='rcp')) if convert_to_csv: ensemble_csv = output_basename.with_suffix(".csv") prec = single_input_or_none(request.inputs, "csv_precision") if prec: ensemble = ensemble.round(prec) df = dataset_to_dataframe(ensemble) if average_dims is None: dims = ['lat', 'lon', 'time'] else: dims = ['time'] df = df.reset_index().set_index(dims) if "region" in df.columns: df.drop(columns="region", inplace=True) df.dropna().to_csv(ensemble_csv) metadata = format_metadata(ensemble) metadata_file = output_basename.parent / f"{ensemble_csv.stem}_metadata.txt" metadata_file.write_text(metadata) ensemble_output = Path(process.workdir) / (output_filename + ".zip") zip_files(ensemble_output, [metadata_file, ensemble_csv]) else: ensemble_output = output_basename.with_suffix(".nc") dataset_to_netcdf(ensemble, ensemble_output) response.outputs["output"].file = ensemble_output response.outputs["output_log"].file = str(log_file_path(process)) write_log(process, "Processing finished successfully", process_step="done") return response
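# Sketch of the final merge above, with made-up data: passing a named DataArray
# as `dim` concatenates the per-scenario ensembles along a new, labelled 'rcp'
# dimension in a single call.
import numpy as np
import xarray as xr

rcps = ["rcp26", "rcp45", "rcp85"]
ensembles = [xr.Dataset({"tx_max": ("time", np.random.rand(4))},
                        coords={"time": range(4)}) for _ in rcps]
merged = xr.concat(ensembles, dim=xr.DataArray(rcps, dims=("rcp",), name="rcp"))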
data = model_data[delta].drop(labels=["ix", "iy"]).transpose(
    "time", "y", "x")  # imod-python does not want other dimensions
data = data.assign_coords(y=data.y * -1)  # Flip y-coordinates
dsts = [xr.full_like(data, np.nan).sel(time=t) for t in data.time]

src_crs = dst_crs = rio.crs.CRS.from_epsg(32630)
dst_transform = imod.util.transform(dsts[0])

for i, t in enumerate(data.time):
    dsts[i].values, _ = rio.warp.reproject(data.sel(time=t).values,
                                           dsts[i].values,
                                           src_crs=src_crs,
                                           dst_crs=dst_crs,
                                           dst_transform=dst_transform,
                                           gcps=gcps[delta])

warp_data[delta] = xr.concat(dsts, dim="time")

#%% Clip out delta
print("...clipping...")
for delta, da in warp_data.items():
    dcell = geom.loc[delta, "dx"]
    targgrid["x"], targgrid["y"] = gm.get_targgrid(dcell, dcell,
                                                   df.loc[delta, "L_a"],
                                                   df.loc[delta, "phi_f"] / 2)
    coords = {"x": targgrid["x"][0], "y": targgrid["y"][:, 0]}
    like = xr.DataArray(np.ones(targgrid["x"].shape), coords=coords,
                        dims=["y", "x"])
    da = da.sortby(da.y)  # Ensure y is monotonically increasing
def fetch(self, grouped: VirtualDatasetBox, **load_settings: Dict[str, Any]) -> xarray.Dataset: """ Convert grouped datasets to `xarray.Dataset`. """ geobox = grouped.geobox measurements = self.output_measurements(grouped.product_definitions) band_settings = dict( zip( list(measurements), per_band_load_data_settings(measurements, resampling=self.get( 'resampling', 'nearest')))) boxes = [ VirtualDatasetBox(box_slice.box, None, True, box_slice.product_definitions, geopolygon=geobox.extent) for box_slice in grouped.split() ] dask_chunks = load_settings.get('dask_chunks') if dask_chunks is None: rasters = [ self._input.fetch(box, **load_settings) for box in boxes ] else: rasters = [ self._input.fetch(box, dask_chunks={ key: 1 for key in dask_chunks if key not in geobox.dims }, **reject_keys(load_settings, ['dask_chunks'])) for box in boxes ] result = xarray.Dataset() result.coords['time'] = grouped.box.coords['time'] for name, coord in grouped.geobox.coordinates.items(): result.coords[name] = (name, coord.values, { 'units': coord.units, 'resolution': coord.resolution }) for measurement in measurements: result[measurement] = xarray.concat([ reproject_band(raster[measurement], geobox, band_settings[measurement]['resampling_method'], grouped.box.dims + geobox.dims, dask_chunks) for raster in rasters ], dim='time') result.attrs['crs'] = geobox.crs return result
def plot_ess( idata, var_names=None, kind="local", relative=False, coords=None, figsize=None, textsize=None, rug=False, rug_kind="diverging", n_points=20, extra_methods=False, min_ess=400, ax=None, extra_kwargs=None, text_kwargs=None, hline_kwargs=None, rug_kwargs=None, backend=None, backend_kwargs=None, show=None, **kwargs ): """Plot quantile, local or evolution of effective sample sizes (ESS). Parameters ---------- idata : obj Any object that can be converted to an az.InferenceData object Refer to documentation of az.convert_to_dataset for details var_names : list of variable names, optional Variables to be plotted. kind : str, optional Options: ``local``, ``quantile`` or ``evolution``, specify the kind of plot. relative : bool Show relative ess in plot ``ress = ess / N``. coords : dict, optional Coordinates of var_names to be plotted. Passed to `Dataset.sel` figsize : tuple, optional Figure size. If None it will be defined automatically. textsize: float, optional Text size scaling factor for labels, titles and lines. If None it will be autoscaled based on figsize. rug : bool Plot rug plot of values diverging or that reached the max tree depth. rug_kind : bool Variable in sample stats to use as rug mask. Must be a boolean variable. n_points : int Number of points for which to plot their quantile/local ess or number of subsets in the evolution plot. extra_methods : bool, optional Plot mean and sd ESS as horizontal lines. Not taken into account in evolution kind min_ess : int Minimum number of ESS desired. ax: axes, optional Matplotlib axes or bokeh figures. extra_kwargs : dict, optional If evolution plot, extra_kwargs is used to plot ess tail and differentiate it from ess bulk. Otherwise, passed to extra methods lines. text_kwargs : dict, optional Only taken into account when ``extra_methods=True``. kwargs passed to ax.annotate for extra methods lines labels. It accepts the additional key ``x`` to set ``xy=(text_kwargs["x"], mcse)`` hline_kwargs : dict, optional kwargs passed to ax.axhline for the horizontal minimum ESS line. rug_kwargs : dict kwargs passed to rug plot. backend: str, optional Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib". backend_kwargs: bool, optional These are kwargs specific to the backend being used. For additional documentation check the plotting method of the backend. show : bool, optional Call backend show function. **kwargs Passed as-is to plt.hist() or plt.plot() function depending on the value of `kind`. Returns ------- axes : matplotlib axes or bokeh figures References ---------- * Vehtari et al. (2019) see https://arxiv.org/abs/1903.08008 Examples -------- Plot local ESS. This plot, together with the quantile ESS plot, is recommended to check that there are enough samples for all the explored regions of parameter space. Checking local and quantile ESS is particularly relevant when working with credible intervals as opposed to ESS bulk, which is relevant for point estimates. .. plot:: :context: close-figs >>> import arviz as az >>> idata = az.load_arviz_data("centered_eight") >>> coords = {"school": ["Choate", "Lawrenceville"]} >>> az.plot_ess( ... idata, kind="local", var_names=["mu", "theta"], coords=coords ... ) Plot quantile ESS. .. plot:: :context: close-figs >>> az.plot_ess( ... idata, kind="quantile", var_names=["mu", "theta"], coords=coords ... ) Plot ESS evolution as the number of samples increase. When the model is converging properly, both lines in this plot should be roughly linear. .. 
plot:: :context: close-figs >>> az.plot_ess( ... idata, kind="evolution", var_names=["mu", "theta"], coords=coords ... ) Customize local ESS plot to look like reference paper. .. plot:: :context: close-figs >>> az.plot_ess( ... idata, kind="local", var_names=["mu"], drawstyle="steps-mid", color="k", ... linestyle="-", marker=None, rug=True, rug_kwargs={"color": "r"} ... ) Customize ESS evolution plot to look like reference paper. .. plot:: :context: close-figs >>> extra_kwargs = {"color": "lightsteelblue"} >>> az.plot_ess( ... idata, kind="evolution", var_names=["mu"], ... color="royalblue", extra_kwargs=extra_kwargs ... ) """ valid_kinds = ("local", "quantile", "evolution") kind = kind.lower() if kind not in valid_kinds: raise ValueError("Invalid kind, kind must be one of {} not {}".format(valid_kinds, kind)) if coords is None: coords = {} if "chain" in coords or "draw" in coords: raise ValueError("chain and draw are invalid coordinates for this kind of plot") extra_methods = False if kind == "evolution" else extra_methods data = get_coords(convert_to_dataset(idata, group="posterior"), coords) var_names = _var_names(var_names, data) n_draws = data.dims["draw"] n_samples = n_draws * data.dims["chain"] ess_tail_dataset = None mean_ess = None sd_ess = None text_x = None text_va = None if kind == "quantile": probs = np.linspace(1 / n_points, 1 - 1 / n_points, n_points) xdata = probs ylabel = "{} for quantiles" ess_dataset = xr.concat( [ ess(data, var_names=var_names, relative=relative, method="quantile", prob=p) for p in probs ], dim="ess_dim", ) elif kind == "local": probs = np.linspace(0, 1, n_points, endpoint=False) xdata = probs ylabel = "{} for small intervals" ess_dataset = xr.concat( [ ess( data, var_names=var_names, relative=relative, method="local", prob=[p, p + 1 / n_points], ) for p in probs ], dim="ess_dim", ) else: first_draw = data.draw.values[0] ylabel = "{}" xdata = np.linspace(n_samples / n_points, n_samples, n_points) draw_divisions = np.linspace(n_draws // n_points, n_draws, n_points, dtype=int) ess_dataset = xr.concat( [ ess( data.sel(draw=slice(first_draw + draw_div)), var_names=var_names, relative=relative, method="bulk", ) for draw_div in draw_divisions ], dim="ess_dim", ) ess_tail_dataset = xr.concat( [ ess( data.sel(draw=slice(first_draw + draw_div)), var_names=var_names, relative=relative, method="tail", ) for draw_div in draw_divisions ], dim="ess_dim", ) plotters = filter_plotters_list( list(xarray_var_iter(ess_dataset, var_names=var_names, skip_dims={"ess_dim"})), "plot_ess" ) length_plotters = len(plotters) rows, cols = default_grid(length_plotters) (figsize, ax_labelsize, titlesize, xt_labelsize, _linewidth, _markersize) = _scale_fig_size( figsize, textsize, rows, cols ) kwargs = matplotlib_kwarg_dealiaser(kwargs, "plot") _linestyle = "-" if kind == "evolution" else "none" kwargs.setdefault("linestyle", _linestyle) kwargs.setdefault("linewidth", _linewidth) kwargs.setdefault("markersize", _markersize) kwargs.setdefault("marker", "o") kwargs.setdefault("zorder", 3) extra_kwargs = matplotlib_kwarg_dealiaser(extra_kwargs, "plot") if kind == "evolution": extra_kwargs = { **extra_kwargs, **{key: item for key, item in kwargs.items() if key not in extra_kwargs}, } kwargs.setdefault("label", "bulk") extra_kwargs.setdefault("label", "tail") else: extra_kwargs.setdefault("linestyle", "-") extra_kwargs.setdefault("linewidth", _linewidth / 2) extra_kwargs.setdefault("color", "k") extra_kwargs.setdefault("alpha", 0.5) kwargs.setdefault("label", kind) hline_kwargs = 
matplotlib_kwarg_dealiaser(hline_kwargs, "plot") hline_kwargs.setdefault("linewidth", _linewidth) hline_kwargs.setdefault("linestyle", "--") hline_kwargs.setdefault("color", "gray") hline_kwargs.setdefault("alpha", 0.7) if extra_methods: mean_ess = ess(data, var_names=var_names, method="mean", relative=relative) sd_ess = ess(data, var_names=var_names, method="sd", relative=relative) text_kwargs = matplotlib_kwarg_dealiaser(text_kwargs, "text") text_x = text_kwargs.pop("x", 1) text_kwargs.setdefault("fontsize", xt_labelsize * 0.7) text_kwargs.setdefault("alpha", extra_kwargs["alpha"]) text_kwargs.setdefault("color", extra_kwargs["color"]) text_kwargs.setdefault("horizontalalignment", "right") text_va = text_kwargs.pop("verticalalignment", None) essplot_kwargs = dict( ax=ax, plotters=plotters, xdata=xdata, ess_tail_dataset=ess_tail_dataset, mean_ess=mean_ess, sd_ess=sd_ess, idata=idata, data=data, text_x=text_x, text_va=text_va, kind=kind, extra_methods=extra_methods, rows=rows, cols=cols, figsize=figsize, kwargs=kwargs, extra_kwargs=extra_kwargs, text_kwargs=text_kwargs, _linewidth=_linewidth, _markersize=_markersize, n_samples=n_samples, relative=relative, min_ess=min_ess, xt_labelsize=xt_labelsize, titlesize=titlesize, ax_labelsize=ax_labelsize, ylabel=ylabel, rug=rug, rug_kind=rug_kind, rug_kwargs=rug_kwargs, hline_kwargs=hline_kwargs, backend_kwargs=backend_kwargs, show=show, ) if backend is None: backend = rcParams["plot.backend"] backend = backend.lower() # TODO: Add backend kwargs plot = get_plotting_function("plot_ess", "essplot", backend) ax = plot(**essplot_kwargs) return ax
def run_task(self): # {{{ ''' Compute time series of regional profiles ''' # Authors # ------- # Milena Veneziani, Mark Petersen, Phillip J. Wolfram, Xylar Asay-Davis self.logger.info("\nCompute time series of regional profiles...") startDate = '{:04d}-01-01_00:00:00'.format(self.startYear) endDate = '{:04d}-12-31_23:59:59'.format(self.endYear) timeSeriesName = self.parentTask.regionMaskSuffix outputDirectory = '{}/{}/'.format( build_config_full_path(self.config, 'output', 'timeseriesSubdirectory'), timeSeriesName) try: os.makedirs(outputDirectory) except OSError: pass outputFileName = '{}/{}_{:04d}-{:04d}.nc'.format( outputDirectory, timeSeriesName, self.startYear, self.endYear) inputFiles = sorted( self.historyStreams.readpath('timeSeriesStatsMonthlyOutput', startDate=startDate, endDate=endDate, calendar=self.calendar)) years, months = get_files_year_month(inputFiles, self.historyStreams, 'timeSeriesStatsMonthlyOutput') variableList = [field['mpas'] for field in self.parentTask.fields] outputExists = os.path.exists(outputFileName) outputValid = outputExists if outputExists: with open_mpas_dataset(fileName=outputFileName, calendar=self.calendar, timeVariableNames=None, variableList=None, startDate=startDate, endDate=endDate) as dsIn: for inIndex in range(dsIn.dims['Time']): mask = np.logical_and(dsIn.year[inIndex].values == years, dsIn.month[inIndex].values == months) if np.count_nonzero(mask) == 0: outputValid = False break if outputValid: self.logger.info(' Time series exists -- Done.') return # get areaCell restartFileName = \ self.runStreams.readpath('restart')[0] dsRestart = xr.open_dataset(restartFileName) dsRestart = dsRestart.isel(Time=0) areaCell = dsRestart.areaCell nVertLevels = dsRestart.sizes['nVertLevels'] vertIndex = \ xr.DataArray.from_dict({'dims': ('nVertLevels',), 'data': np.arange(nVertLevels)}) vertMask = vertIndex < dsRestart.maxLevelCell # get region masks regionMaskFileName = self.parentTask.masksSubtask.maskFileName dsRegionMask = xr.open_dataset(regionMaskFileName) # figure out the indices of the regions to plot regionNames = decode_strings(dsRegionMask.regionNames) regionIndices = [] for regionToPlot in self.parentTask.regionNames: for index, regionName in enumerate(regionNames): if regionToPlot == regionName: regionIndices.append(index) break # select only those regions we want to plot dsRegionMask = dsRegionMask.isel(nRegions=regionIndices) cellMasks = dsRegionMask.regionCellMasks regionNamesVar = dsRegionMask.regionNames totalArea = (cellMasks * areaCell * vertMask).sum('nCells') datasets = [] for timeIndex, fileName in enumerate(inputFiles): dsLocal = open_mpas_dataset(fileName=fileName, calendar=self.calendar, variableList=variableList, startDate=startDate, endDate=endDate) dsLocal = dsLocal.isel(Time=0) time = dsLocal.Time.values date = days_to_datetime(time, calendar=self.calendar) self.logger.info(' date: {:04d}-{:02d}'.format( date.year, date.month)) # for each region and variable, compute area-weighted sum and # squared sum for field in self.parentTask.fields: variableName = field['mpas'] prefix = field['prefix'] self.logger.info(' {}'.format(field['titleName'])) var = dsLocal[variableName].where(vertMask) meanName = '{}_mean'.format(prefix) dsLocal[meanName] = \ (cellMasks * areaCell * var).sum('nCells') / totalArea meanSquaredName = '{}_meanSquared'.format(prefix) dsLocal[meanSquaredName] = \ (cellMasks * areaCell * var**2).sum('nCells') / totalArea # drop the original variables dsLocal = dsLocal.drop(variableList) datasets.append(dsLocal) # combine 
    # combine the data sets into a single data set
    dsOut = xr.concat(datasets, 'Time')
    dsOut.coords['regionNames'] = regionNamesVar
    dsOut['totalArea'] = totalArea
    dsOut.coords['year'] = (('Time'), years)
    dsOut['year'].attrs['units'] = 'years'
    dsOut.coords['month'] = (('Time'), months)
    dsOut['month'].attrs['units'] = 'months'

    # Note: use a restart file, not a mesh file, because we need
    # refBottomDepth, which is not in a mesh file
    try:
        restartFile = self.runStreams.readpath('restart')[0]
    except ValueError:
        raise IOError('No MPAS-O restart file found: need at least one '
                      'restart file for plotting time series vs. depth')

    with xr.open_dataset(restartFile) as dsRestart:
        depths = dsRestart.refBottomDepth.values

    z = np.zeros(depths.shape)
    z[0] = -0.5 * depths[0]
    z[1:] = -0.5 * (depths[0:-1] + depths[1:])

    dsOut.coords['z'] = (('nVertLevels'), z)
    dsOut['z'].attrs['units'] = 'meters'

    write_netcdf(dsOut, outputFileName)
def merge_batches(gcm_name, output_sim_fp=input.output_sim_fp, rcp=None, option_remove_merged_files=0, option_remove_batch_files=0, debug=False): """ MERGE BATCHES """ #for gcm_name in ['CCSM4', 'GFDL-CM3', 'GFDL-ESM2M', 'GISS-E2-R', 'IPSL-CM5A-LR', 'MIROC5', 'MRI-CGCM3', 'NorESM1-M']: # debug=True # netcdf_fp = input.output_sim_fp + gcm_name + '/' splitter = '_batch' zipped_fp = output_sim_fp + 'spc_zipped/' merged_fp = output_sim_fp + 'spc_merged/' netcdf_fp = output_sim_fp + gcm_name + '/' # Check file path exists if os.path.exists(zipped_fp) == False: os.makedirs(zipped_fp) if os.path.exists(merged_fp) == False: os.makedirs(merged_fp) regions = [] rcps = [] for i in os.listdir(netcdf_fp): if i.endswith('.nc'): i_region = int(i.split('_')[0][1:]) if i_region not in regions: regions.append(i_region) if gcm_name not in ['ERA-Interim']: i_rcp = i.split('_')[2] if i_rcp not in rcps: rcps.append(i_rcp) regions = sorted(regions) rcps = sorted(rcps) # Set RCPs if not for GCM and/or if overriding with argument if len(rcps) == 0: rcps = [None] if rcp is not None: rcps = [rcp] if debug: print('Regions:', regions, '\nRCPs:', rcps) # Encoding # Add variables to empty dataset and merge together encoding = {} noencoding_vn = ['stats', 'glac_attrs'] if input.output_package == 2: for vn in input.output_variables_package2: # Encoding (specify _FillValue, offsets, etc.) if vn not in noencoding_vn: encoding[vn] = {'_FillValue': False} for reg in regions: check_str = 'R' + str(reg) + '_' + gcm_name for rcp in rcps: if rcp is not None: check_str = 'R' + str(reg) + '_' + gcm_name + '_' + rcp if debug: print('Region(s)', reg, 'RCP', rcp, ':', 'check_str:', check_str) output_list = [] merged_list = [] for i in os.listdir(netcdf_fp): if i.startswith(check_str) and splitter in i: output_list.append( [int(i.split(splitter)[1].split('.')[0]), i]) output_list = sorted(output_list) output_list = [i[1] for i in output_list] # Open datasets and combine count_ds = 0 for i in output_list: if debug: print(i) count_ds += 1 ds = xr.open_dataset(netcdf_fp + i) # Merge datasets of stats into one output if count_ds == 1: ds_all = ds else: ds_all = xr.concat([ds_all, ds], dim='glac') ds_all.glac.values = np.arange(0, len(ds_all.glac.values)) ds_all_fn = i.split(splitter)[0] + '.nc' # Export to netcdf ds_all.to_netcdf(merged_fp + ds_all_fn, encoding=encoding) print('Merged ', gcm_name, rcp, 'Region(s)', reg) merged_list.append(merged_fp + ds_all_fn) if debug: print(merged_list) # Zip file to reduce file size with zipfile.ZipFile(zipped_fp + ds_all_fn + '.zip', mode='w', compression=zipfile.ZIP_DEFLATED) as myzip: myzip.write(merged_fp + ds_all_fn, arcname=ds_all_fn) # Remove unzipped files if option_remove_merged_files == 1: for i in merged_list: os.remove(i) if option_remove_batch_files == 1: # Remove batch files for i in output_list: os.remove(netcdf_fp + i)
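# A hedged alternative to the incremental merge in the loop above: calling
# xr.concat([ds_all, ds], dim='glac') inside the loop copies the growing result
# on every iteration, so collecting the pieces in a list and concatenating once
# is usually cheaper. `merge_batch_files` is a hypothetical helper illustrating
# that variant.
import numpy as np
import xarray as xr


def merge_batch_files(netcdf_fp, output_list):
    pieces = [xr.open_dataset(netcdf_fp + fn) for fn in output_list]
    ds_all = xr.concat(pieces, dim="glac")
    ds_all["glac"] = np.arange(ds_all.sizes["glac"])
    return ds_all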
def merge_time(ds_list: list) -> xr.Dataset:
    return xr.concat(ds_list, dim='time', coords='minimal')
print("Reading %s" % (os.path.basename(file))) data = xr.open_dataset(file) # Define new time coordinates and drop the old one data['time'].reset_coords(drop=True) data.coords['time'] = ([ dt.datetime(data['YEAR'], data['MONTH'], 1), ]) if (data.lat.ndim > 1): print("Changing coordinates...") data.coords['x'] = data.lat.lon[0].data data.coords['y'] = data.lat[:, 0].data data = data.drop(('lon', 'lat')) data = data.rename({'x': 'lon', 'y': 'lat'}) data_list.append(data['dry_O3']) # Concatenating the list dry_dep_raw_data.append(xr.concat(data_list, dim='time')) # Extracting some general information # WARNING: cdo has summed these also up -> devide by number of days gridarea = data['gridarea'].isel(time=0) / 31. #molarweight = get_molarweight(data.isel(time=0))/31. ozone_raw_data = [] height_raw_data = [] for iexp in experiment: subdir = data_dir + iexp + mm_dir + '*.nc' print("Reading from path %s" % (os.path.abspath(subdir))) data_list = [] data_list_height = [] # Open dataset for file in sorted(glob.glob(subdir)): print("Reading %s" % (os.path.basename(file)))
def test_concat_loads_variables(self): # Test that concat() computes not-in-memory variables at most once # and loads them in the output, while leaving the input unaltered. d1 = build_dask_array('d1') c1 = build_dask_array('c1') d2 = build_dask_array('d2') c2 = build_dask_array('c2') d3 = build_dask_array('d3') c3 = build_dask_array('c3') # Note: c is a non-index coord. # Index coords are loaded by IndexVariable.__init__. ds1 = Dataset(data_vars={'d': ('x', d1)}, coords={'c': ('x', c1)}) ds2 = Dataset(data_vars={'d': ('x', d2)}, coords={'c': ('x', c2)}) ds3 = Dataset(data_vars={'d': ('x', d3)}, coords={'c': ('x', c3)}) assert kernel_call_count == 0 out = xr.concat([ds1, ds2, ds3], dim='n', data_vars='different', coords='different') # each kernel is computed exactly once assert kernel_call_count == 6 # variables are loaded in the output assert isinstance(out['d'].data, np.ndarray) assert isinstance(out['c'].data, np.ndarray) out = xr.concat([ds1, ds2, ds3], dim='n', data_vars='all', coords='all') # no extra kernel calls assert kernel_call_count == 6 assert isinstance(out['d'].data, dask.array.Array) assert isinstance(out['c'].data, dask.array.Array) out = xr.concat([ds1, ds2, ds3], dim='n', data_vars=['d'], coords=['c']) # no extra kernel calls assert kernel_call_count == 6 assert isinstance(out['d'].data, dask.array.Array) assert isinstance(out['c'].data, dask.array.Array) out = xr.concat([ds1, ds2, ds3], dim='n', data_vars=[], coords=[]) # variables are loaded once as we are validing that they're identical assert kernel_call_count == 12 assert isinstance(out['d'].data, np.ndarray) assert isinstance(out['c'].data, np.ndarray) out = xr.concat([ds1, ds2, ds3], dim='n', data_vars='different', coords='different', compat='identical') # compat=identical doesn't do any more kernel calls than compat=equals assert kernel_call_count == 18 assert isinstance(out['d'].data, np.ndarray) assert isinstance(out['c'].data, np.ndarray) # When the test for different turns true halfway through, # stop computing variables as it would not have any benefit ds4 = Dataset(data_vars={'d': ('x', [2.0])}, coords={'c': ('x', [2.0])}) out = xr.concat([ds1, ds2, ds4, ds3], dim='n', data_vars='different', coords='different') # the variables of ds1 and ds2 were computed, but those of ds3 didn't assert kernel_call_count == 22 assert isinstance(out['d'].data, dask.array.Array) assert isinstance(out['c'].data, dask.array.Array) # the data of ds1 and ds2 was loaded into numpy and then # concatenated to the data of ds3. Thus, only ds3 is computed now. out.compute() assert kernel_call_count == 24 # Finally, test that riginals are unaltered assert ds1['d'].data is d1 assert ds1['c'].data is c1 assert ds2['d'].data is d2 assert ds2['c'].data is c2 assert ds3['d'].data is d3 assert ds3['c'].data is c3
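# Small, hedged demonstration of the data_vars/coords options exercised in the
# test above, using plain in-memory data: 'all' always stacks along the new
# concat dimension, while 'different' leaves variables that compare equal
# untouched.
import xarray as xr

ds_a = xr.Dataset({"d": ("x", [1.0, 2.0])}, coords={"c": ("x", [0.0, 0.0])})
ds_b = xr.Dataset({"d": ("x", [3.0, 4.0])}, coords={"c": ("x", [0.0, 0.0])})
out_all = xr.concat([ds_a, ds_b], dim="n", data_vars="all", coords="all")
out_diff = xr.concat([ds_a, ds_b], dim="n", data_vars="different",
                     coords="different")
assert "n" in out_all["c"].dims       # coord stacked along the new dim
assert "n" not in out_diff["c"].dims  # identical coord taken from the first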