def test_create_index_variable_global(self):
    raise SkipTest('not implemented')

    dsrc_size = 5
    ddst_size = 7

    dsrc = Dimension('dsrc', dsrc_size, dist=True)
    src_dist = OcgDist()
    src_dist.add_dimension(dsrc)
    src_dist.update_dimension_bounds()

    ddst = Dimension('ddst', ddst_size, dist=True)
    dst_dist = OcgDist()
    dst_dist.add_dimension(ddst)
    dst_dist.update_dimension_bounds()

    if vm.rank == 0:
        np.random.seed(1)
        dst = np.random.rand(ddst_size)
        src = np.random.choice(dst, size=dsrc_size, replace=False)

        src = Variable(name='src', value=src, dimensions=dsrc.name)
        # TODO: move create_ugid_global to create_global_index on a standard variable object
        dst = GeometryVariable(name='dst', value=dst, dimensions=ddst.name)
    else:
        src, dst = [None] * 2

    src = variable_scatter(src, src_dist)
    dst = variable_scatter(dst, dst_dist)

    actual = create_index_variable_global('index_array', src, dst)

    self.assertNumpyAll(dst.get_value()[actual.get_value()], src.get_value())

def test_set_spatial_mask(self):
    dmap = DimensionMap()
    dims = Dimension('x', 3), Dimension('y', 7)
    mask_var = create_spatial_mask_variable('a_mask', None, dims)
    self.assertFalse(np.any(mask_var.get_mask()))
    dmap.set_spatial_mask(mask_var)
    self.assertEqual(dmap.get_spatial_mask(), mask_var.name)

    with self.assertRaises(DimensionMapError):
        dmap.set_variable(DMK.SPATIAL_MASK, mask_var)

    # Test custom variables may be used.
    dmap = DimensionMap()
    dims = Dimension('x', 3), Dimension('y', 7)
    mask_var = create_spatial_mask_variable('a_mask', None, dims)
    attrs = {'please keep me': 'no overwriting'}
    dmap.set_spatial_mask(mask_var, attrs=attrs)
    attrs = dmap.get_attrs(DMK.SPATIAL_MASK)
    self.assertIn('please keep me', attrs)

    # Test default attributes are not added.
    dmap = DimensionMap()
    dmap.set_spatial_mask('foo', default_attrs={'blue': 'world'})
    prop = dmap.get_property(DMK.SPATIAL_MASK)
    self.assertEqual(prop['attrs'], {'blue': 'world'})

def test_get_live_ranks_from_object(self):
    if MPI_SIZE != 4:
        raise SkipTest('MPI_SIZE != 4')

    vm = OcgVM()
    if MPI_RANK == 1:
        dim = Dimension('woot', is_empty=True, dist=True)
    else:
        dim = Dimension('woot', dist=True, size=3)
    actual = vm.get_live_ranks_from_object(dim)
    self.assertEqual(actual, (0, 2, 3))

    vm.finalize()

def get_grouping(self, grouping):
    """
    Create a temporally grouped variable using string group sequences.

    :param grouping: The temporal grouping to use when creating the temporal group dimension.

    >>> grouping = ['month']

    :type grouping: `sequence` of :class:`str`
    :rtype: :class:`~ocgis.variable.temporal.TemporalGroupVariable`
    """
    # There is no need to go through the process of breaking out datetime parts when the grouping is 'all'.
    if grouping == 'all':
        new_bounds, date_parts, repr_dt, dgroups = self._get_grouping_all_()
    # The process for getting "unique" seasons is also specialized.
    elif 'unique' in grouping:
        new_bounds, date_parts, repr_dt, dgroups = self._get_grouping_seasonal_unique_(grouping)
    # For standard groups ("['month']") or seasons across the entire time range.
    else:
        new_bounds, date_parts, repr_dt, dgroups = self._get_grouping_other_(grouping)

    new_name = 'climatology_bounds'
    time_dimension_name = self.dimensions[0].name
    if self.has_bounds:
        new_dimensions = [d.name for d in self.bounds.dimensions]
    else:
        new_dimensions = [time_dimension_name, 'bounds']

    # Create the new time dimension as unlimited if the original time variable also has an unlimited dimension.
    if self.dimensions[0].is_unlimited:
        new_time_dimension = Dimension(name=time_dimension_name, size_current=len(repr_dt))
    else:
        new_time_dimension = time_dimension_name
    new_dimensions[0] = new_time_dimension

    new_bounds = TemporalVariable(value=new_bounds, name=new_name, dimensions=new_dimensions)
    new_attrs = deepcopy(self.attrs)
    # new_attrs['climatology'] = new_bounds.name
    tgv = TemporalGroupVariable(grouping=grouping, date_parts=date_parts, bounds=new_bounds, dgroups=dgroups,
                                value=repr_dt, units=self.units, calendar=self.calendar, name=self.name,
                                attrs=new_attrs, dimensions=new_dimensions[0])
    tgv.attrs.pop(TemporalVariable._bounds_attribute_name, None)

    return tgv

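# A minimal usage sketch for ``get_grouping`` (hedged: the variable names and the
# monthly time series below are illustrative, not taken from the surrounding source;
# it assumes ``TemporalVariable`` is importable from the package root):
#
# >>> import datetime
# >>> from ocgis import TemporalVariable
# >>> value = [datetime.datetime(2000, month, 15) for month in range(1, 13)]
# >>> tvar = TemporalVariable(name='time', value=value, dimensions='time')
# >>> tgv = tvar.get_grouping(['month'])  # one representative datetime per month
# >>> tgv = tvar.get_grouping('all')      # collapse the entire time range into one group
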
def test_set_spatial_mask(self):
    dmap = DimensionMap()
    dims = Dimension('x', 3), Dimension('y', 7)
    mask_var = create_grid_mask_variable('a_mask', None, dims)
    self.assertFalse(np.any(mask_var.get_mask()))
    dmap.set_spatial_mask(mask_var)
    self.assertEqual(dmap.get_spatial_mask(), mask_var.name)

    with self.assertRaises(DimensionMapError):
        dmap.set_variable(DMK.SPATIAL_MASK, mask_var)

    # Test custom variables may be used.
    dmap = DimensionMap()
    dims = Dimension('x', 3), Dimension('y', 7)
    mask_var = create_grid_mask_variable('a_mask', None, dims)
    attrs = {'please keep me': 'no overwriting'}
    dmap.set_spatial_mask(mask_var, attrs=attrs)
    attrs = dmap.get_attrs(DMK.SPATIAL_MASK)
    self.assertIn('please keep me', attrs)

def test_system_grid_chunking(self):
    if vm.size != 4:
        raise SkipTest('vm.size != 4')

    from ocgis.spatial.grid_chunker import GridChunker

    path = self.path_esmf_unstruct
    rd_dst = RequestDataset(uri=path, driver=DriverESMFUnstruct, crs=Spherical(), grid_abstraction='point',
                            grid_is_isomorphic=True)
    rd_src = deepcopy(rd_dst)
    resolution = 0.28125
    chunk_wd = os.path.join(self.current_dir_output, 'chunks')
    if vm.rank == 0:
        os.mkdir(chunk_wd)
    vm.barrier()
    paths = {'wd': chunk_wd}

    gc = GridChunker(rd_src, rd_dst, nchunks_dst=[8], src_grid_resolution=resolution,
                     dst_grid_resolution=resolution, optimized_bbox_subset=True, paths=paths, genweights=True)
    gc.write_chunks()

    dist = OcgDist()
    local_ctr = Dimension(name='ctr', size=8, dist=True)
    dist.add_dimension(local_ctr)
    dist.update_dimension_bounds()
    for ctr in range(local_ctr.bounds_local[0], local_ctr.bounds_local[1]):
        ctr += 1
        s = os.path.join(chunk_wd, 'split_src_{}.nc'.format(ctr))
        d = os.path.join(chunk_wd, 'split_dst_{}.nc'.format(ctr))
        sf = Field.read(s, driver=DriverESMFUnstruct)
        df = Field.read(d, driver=DriverESMFUnstruct)
        self.assertGreater(sf.grid.shape[0], df.grid.shape[0])

        wgt = os.path.join(chunk_wd, 'esmf_weights_{}.nc'.format(ctr))
        f = Field.read(wgt)
        S = f['S'].v()
        self.assertAlmostEqual(S.min(), 1.0)
        self.assertAlmostEqual(S.max(), 1.0)

    with vm.scoped('merge weights', [0]):
        if not vm.is_null:
            merged_weights = self.get_temporary_file_path('merged_weights.nc')
            gc.create_merged_weight_file(merged_weights, strict=False)
            f = Field.read(merged_weights)
            S = f['S'].v()
            self.assertAlmostEqual(S.min(), 1.0)
            self.assertAlmostEqual(S.max(), 1.0)

def create_dimension(self, *args, **kwargs):
    from ocgis import Dimension

    group = kwargs.pop('group', None)
    dim = Dimension(*args, **kwargs)
    self.add_dimension(dim, group=group)

    # If rank counts are not the same, you have to get your own dimension. Ranks may not be contained in the
    # mapping.
    if self.size != MPI_SIZE:
        ret = None
    else:
        ret = self.get_dimension(dim.name, group=group)
    return ret

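# Hedged sketch of how ``create_dimension`` is typically exercised (the ``OcgDist``
# construction pattern is taken from the tests above; the name and size are illustrative):
#
# >>> dist = OcgDist()
# >>> dim = dist.create_dimension('elements', 10, dist=True)  # may be None when self.size != MPI_SIZE
# >>> dist.update_dimension_bounds()  # compute per-rank local bounds for distributed dimensions
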
def fixture_driver_scrip_netcdf_field(self):
    xvalue = np.arange(10., 35., step=5)
    yvalue = np.arange(45., 85., step=10)
    grid_size = xvalue.shape[0] * yvalue.shape[0]

    dim_grid_size = Dimension(name='grid_size', size=grid_size)
    x = Variable(name='grid_center_lon', dimensions=dim_grid_size)
    y = Variable(name='grid_center_lat', dimensions=dim_grid_size)

    for idx, (xv, yv) in enumerate(itertools.product(xvalue, yvalue)):
        x.get_value()[idx] = xv
        y.get_value()[idx] = yv

    gc = PointGC(x=x, y=y, crs=Spherical(), driver=DriverNetcdfSCRIP)
    grid = GridUnstruct(geoms=[gc])
    ret = Field(grid=grid, driver=DriverNetcdfSCRIP)

    grid_dims = Variable(name='grid_dims', value=[yvalue.shape[0], xvalue.shape[0]], dimensions='grid_rank')
    ret.add_variable(grid_dims)

    return ret

def write_subsets(self, src_template, dst_template, wgt_template, index_path):
    """
    Write grid subsets to netCDF files using the provided filename templates. Each template must contain the
    full file path with a single curly-brace pair used to insert the combination counter. ``wgt_template``
    should not be a full path; this name is used when generating weight files.

    >>> template_example = '/path/to/data_{}.nc'

    :param str src_template: The template for the source subset file.
    :param str dst_template: The template for the destination subset file.
    :param str wgt_template: The template for the weight filename.

    >>> wgt_template = 'esmf_weights_{}.nc'

    :param index_path: Path to the output indexing netCDF.
    """
    src_filenames = []
    dst_filenames = []
    wgt_filenames = []
    dst_slices = []

    # nzeros = len(str(reduce(lambda x, y: x * y, self.nsplits_dst)))

    for ctr, (sub_src, sub_dst, dst_slc) in enumerate(self.iter_src_grid_subsets(yield_dst=True), start=1):
        # padded = create_zero_padded_integer(ctr, nzeros)

        src_path = src_template.format(ctr)
        dst_path = dst_template.format(ctr)
        wgt_filename = wgt_template.format(ctr)

        src_filenames.append(os.path.split(src_path)[1])
        dst_filenames.append(os.path.split(dst_path)[1])
        wgt_filenames.append(wgt_filename)
        dst_slices.append(dst_slc)

        for target, path in zip([sub_src, sub_dst], [src_path, dst_path]):
            if target.is_empty:
                is_empty = True
                target = None
            else:
                is_empty = False
            field = Field(grid=target, is_empty=is_empty)
            ocgis_lh(msg='writing: {}'.format(path), level=logging.DEBUG)
            with vm.scoped_by_emptyable('field.write', field):
                if not vm.is_null:
                    field.write(path)
            ocgis_lh(msg='finished writing: {}'.format(path), level=logging.DEBUG)

    with vm.scoped('index write', [0]):
        if not vm.is_null:
            dim = Dimension('nfiles', len(src_filenames))
            vname = ['source_filename', 'destination_filename', 'weights_filename']
            values = [src_filenames, dst_filenames, wgt_filenames]
            grid_splitter_destination = GridSplitterConstants.IndexFile.NAME_DESTINATION_VARIABLE
            attrs = [{'esmf_role': 'grid_splitter_source'},
                     {'esmf_role': grid_splitter_destination},
                     {'esmf_role': 'grid_splitter_weights'}]

            vc = VariableCollection()

            grid_splitter_index = GridSplitterConstants.IndexFile.NAME_INDEX_VARIABLE
            vidx = Variable(name=grid_splitter_index)
            vidx.attrs['esmf_role'] = grid_splitter_index
            vidx.attrs['grid_splitter_source'] = 'source_filename'
            vidx.attrs[GridSplitterConstants.IndexFile.NAME_DESTINATION_VARIABLE] = 'destination_filename'
            vidx.attrs['grid_splitter_weights'] = 'weights_filename'
            x_bounds = GridSplitterConstants.IndexFile.NAME_X_BOUNDS_VARIABLE
            vidx.attrs[x_bounds] = x_bounds
            y_bounds = GridSplitterConstants.IndexFile.NAME_Y_BOUNDS_VARIABLE
            vidx.attrs[y_bounds] = y_bounds
            vc.add_variable(vidx)

            for idx in range(len(vname)):
                v = Variable(name=vname[idx], dimensions=dim, dtype=str, value=values[idx], attrs=attrs[idx])
                vc.add_variable(v)

            bounds_dimension = Dimension(name='bounds', size=2)
            xb = Variable(name=x_bounds, dimensions=[dim, bounds_dimension], dtype=int,
                          attrs={'esmf_role': 'x_split_bounds'})
            yb = Variable(name=y_bounds, dimensions=[dim, bounds_dimension], dtype=int,
                          attrs={'esmf_role': 'y_split_bounds'})

            x_name = self.dst_grid.x.dimensions[0].name
            y_name = self.dst_grid.y.dimensions[0].name
            for idx, slc in enumerate(dst_slices):
                xb.get_value()[idx, :] = slc[x_name].start, slc[x_name].stop
                yb.get_value()[idx, :] = slc[y_name].start, slc[y_name].stop
            vc.add_variable(xb)
            vc.add_variable(yb)

            vc.write(index_path)

    vm.barrier()

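# Hedged call sketch (``gs`` stands for an already-constructed grid splitter object;
# the paths are illustrative and follow the template convention documented above):
#
# >>> gs.write_subsets('/tmp/chunks/src_{}.nc', '/tmp/chunks/dst_{}.nc',
# ...                  'esmf_weights_{}.nc', '/tmp/chunks/index.nc')
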
def regrid_field(source, destination, regrid_method='auto', value_mask=None, split=True, fill_value=None,
                 weights_in=None, weights_out=None):
    """
    Regrid ``source`` data to match the grid of ``destination``.

    :param source: The source field.
    :type source: :class:`ocgis.Field`
    :param destination: The destination field.
    :type destination: :class:`ocgis.Field`
    :param regrid_method: See :func:`~ocgis.regrid.base.create_esmf_grid`.
    :param value_mask: See :func:`~ocgis.regrid.base.iter_esmf_fields`.
    :type value_mask: :class:`numpy.ndarray`
    :param bool split: See :func:`~ocgis.regrid.base.iter_esmf_fields`.
    :param fill_value: Destination fill value used to fill the destination field before regridding. If ``None``,
     then the default fill value for the destination field data type will be used.
    :type fill_value: int | float
    :rtype: :class:`ocgis.Field`
    :param weights_in: Optional path to an input weights file. The route handle will be created from weights in
     this file. Assumes a SCRIP-like structure for the input weight file.
    :type weights_in: str
    :param weights_out: Optional path to an output weight file. Does NOT do any regridding - just writes the
     weights.
    :type weights_out: str
    """
    # This function runs a series of asserts to make sure the source and destination are compatible.
    check_fields_for_regridding(source, destination, regrid_method=regrid_method)

    # Reference the destination grid.
    dst_grid = destination.grid

    # Spatial coordinate dimensions for the destination grid.
    dst_spatial_coordinate_dimensions = OrderedDict([(dim.name, dim) for dim in dst_grid.dimensions])
    # Spatial coordinate dimensions for the source grid.
    src_spatial_coordinate_dimensions = OrderedDict([(dim.name, dim) for dim in source.grid.dimensions])

    try:
        # Reference an archetype data variable.
        archetype = source.data_variables[0]
    except IndexError:
        # There may be no data variables. Use the grid as the reference instead.
        archetype = source.grid

    # Extra dimensions (like time or level) to iterate over or use for ndbounds depending on the split protocol.
    extra_dimensions = OrderedDict([(dim.name, dim) for dim in archetype.dimensions
                                    if dim.name not in dst_spatial_coordinate_dimensions
                                    and dim.name not in src_spatial_coordinate_dimensions])

    # If there are no extra dimensions, then there is no need to split fields.
    if len(extra_dimensions) == 0:
        split = False

    if split:
        # There are no extra, ungridded dimensions for ESMF to use.
        ndbounds = None
    else:
        # These are the extra, ungridded dimensions for ESMF to use (ndbounds).
        ndbounds = [len(dim) for dim in extra_dimensions.values()]
        ndbounds.reverse()  # Fortran order is used by ESMF

    # Regrid each source.
    ocgis_lh(logger='iter_regridded_fields', msg='starting source regrid loop', level=logging.DEBUG)
    build = True  # Flag for the first loop iteration
    fills = {}  # Holds destination field fill variables.

    # TODO: OPTIMIZE: The source and destination field objects should be reused and refilled when split=False
    # Main field iterator for use in the regridding loop.
    for variable_name, src_efield, current_slice in iter_esmf_fields(source, regrid_method=regrid_method,
                                                                     value_mask=value_mask, split=split):
        # We need to generate new variables given the change in shape.
        if variable_name not in fills:
            # Create the destination data variable dimensions. These are a combination of the extra dimensions
            # and the spatial coordinate dimensions.
            if len(extra_dimensions) > 0:
                new_dimensions = list(extra_dimensions.values())
            else:
                new_dimensions = []
            new_dimensions += list(dst_grid.dimensions)

            # Reverse the dimensions for the creation as we are working in Fortran ordering with ESMF.
            new_dimensions.reverse()

            # Create the destination fill variable and cache it.
            source_variable = source[variable_name]
            new_variable = Variable(name=variable_name, dimensions=new_dimensions, dtype=source_variable.dtype,
                                    fill_value=source_variable.fill_value, attrs=source_variable.attrs)
            fills[variable_name] = new_variable

        # Only build the ESMF/OCGIS destination grids and fields once.
        if build:
            # Build the destination grid once.
            ocgis_lh(logger='iter_regridded_fields', msg='before create_esmf_grid', level=logging.DEBUG)
            esmf_destination_grid = create_esmf_grid(destination.grid, regrid_method=regrid_method,
                                                     value_mask=value_mask)

            # Check for corners on the destination grid. If they exist, conservative regridding is possible.
            if regrid_method == 'auto':
                if esmf_grid_has_corners(esmf_destination_grid) and esmf_grid_has_corners(src_efield.grid):
                    regrid_method = ESMF.RegridMethod.CONSERVE
                else:
                    regrid_method = None

            # Prepare the regridded source field. This amounts to exchanging the grids between the objects.
            regridded_source = source.copy()
            regridded_source.grid.extract(clean_break=True)
            regridded_source.set_grid(destination.grid.extract())

            # Destination ESMF field.
            dst_efield = ESMF.Field(esmf_destination_grid, name='destination', ndbounds=ndbounds)

        # Reference the destination data variable object.
        fill_variable = fills[variable_name]
        if fill_value is None:
            # The fill value used for the variable data type.
            fv = fill_variable.fill_value
        else:
            fv = fill_value
        # Fill the ESMF destination field with that fill value to help track masks.
        dst_efield.data.fill(fv)

        # Construct the regrid object. Weight generation actually occurs in this call.
        ocgis_lh(logger='iter_regridded_fields', msg='before ESMF.Regrid', level=logging.DEBUG)
        if build:
            # Only create the regrid object once. It may be reused if split=True.
            if weights_in is None:
                if weights_out is None:
                    create_rh = False
                else:
                    create_rh = True
                # Create the weights and ESMF route handle from the grids.
                regrid = ESMF.Regrid(src_efield, dst_efield, unmapped_action=ESMF.UnmappedAction.IGNORE,
                                     regrid_method=regrid_method, src_mask_values=[0], dst_mask_values=[0],
                                     filename=weights_out, create_rh=create_rh)
            else:
                # Create the ESMF route handle with weights read from file.
                regrid = ESMF.RegridFromFile(src_efield, dst_efield, weights_in)
            build = False
        ocgis_lh(logger='iter_regridded_fields', msg='after ESMF.Regrid', level=logging.DEBUG)

        # If we are just writing the weights file, bail out after it is written.
        if weights_out is not None:
            destroy_esmf_objects([regrid, src_efield, dst_efield, esmf_destination_grid])
            return

        # Perform the regrid operation. "zero_region" only fills values involved with regridding.
        ocgis_lh(logger='iter_regridded_fields', msg='before regrid', level=logging.DEBUG)
        regridded_esmf_field = regrid(src_efield, dst_efield, zero_region=ESMF.Region.SELECT)
        e_data = regridded_esmf_field.data  # Regridded data values

        # These are the unmapped values coming out of the ESMF regrid operation.
        unmapped_mask = e_data[:] == fv

        # If all data is masked, raise an exception.
        if unmapped_mask.all():
            # Destroy ESMF objects.
            destroy_esmf_objects([regrid, dst_efield, esmf_destination_grid])
            msg = 'All regridded elements are masked. Do the input spatial extents overlap?'
            raise RegriddingError(msg)

        if current_slice is not None:
            # Create an OCGIS variable to use for setting on the destination. We want to use label-based slicing
            # since arbitrary dimensions are possible with the extra dimensions. First, set defaults for the
            # spatial coordinate slices.
            for k in dst_spatial_coordinate_dimensions.keys():
                current_slice[k] = slice(None)
            # The spatial coordinate dimension names for ESMF in Fortran order.
            e_data_dimensions = deepcopy(list(dst_spatial_coordinate_dimensions.keys()))
            e_data_dimensions.reverse()
            # The extra dimension names for ESMF in Fortran order.
            e_data_dimensions_extra = deepcopy(list(extra_dimensions.keys()))
            e_data_dimensions_extra.reverse()
            # Wrap the ESMF data in an OCGIS variable.
            e_data_var = Variable(name='e_data', value=e_data, dimensions=e_data_dimensions, mask=unmapped_mask)
            # Expand the new variable's dimensions to account for the extra dimensions.
            reshape_dims = list(e_data_var.dimensions) + [Dimension(name=n, size=1)
                                                          for n in e_data_dimensions_extra]
            e_data_var.reshape(reshape_dims)
            # Set the destination fill variable with the ESMF regridded data.
            fill_variable[current_slice] = e_data_var
        else:
            # ESMF and OCGIS dimensions align at this point, so just insert the data.
            fill_variable.v()[:] = e_data

    # Add the new variables to the output field.
    for v in list(fills.values()):
        regridded_source.add_variable(v, is_data=True, force=True)

    # Destroy ESMF objects.
    if weights_out is None:
        destroy_esmf_objects([regrid, dst_efield, src_efield, esmf_destination_grid])
    else:
        destroy_esmf_objects([dst_efield, src_efield, esmf_destination_grid])

    # Broadcast ESMF (Fortran) ordering to Python (C) ordering.
    dst_names = [dim.name for dim in new_dimensions]
    dst_names.reverse()
    for data_variable in regridded_source.data_variables:
        broadcast_variable(data_variable, dst_names)

    return regridded_source

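# Hedged usage sketch for ``regrid_field`` (assumes ``src_field`` and ``dst_field`` are
# spatially overlapping :class:`ocgis.Field` objects; the filenames are illustrative):
#
# >>> regridded = regrid_field(src_field, dst_field)  # regrid in memory
# >>> regrid_field(src_field, dst_field, weights_out='weights.nc')  # write weights only; returns None
# >>> regridded = regrid_field(src_field, dst_field, weights_in='weights.nc')  # reuse stored weights
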
def fixture_element_dimension(self):
    return Dimension('elements', size=6)

def write_subsets(self):
    """
    Write grid subsets to netCDF files using the provided filename templates.
    """
    src_filenames = []
    dst_filenames = []
    wgt_filenames = []
    dst_slices = []
    src_slices = []

    index_path = self.create_full_path_from_template('index_file')

    # nzeros = len(str(reduce(lambda x, y: x * y, self.nsplits_dst)))

    ctr = 1
    for sub_src, src_slc, sub_dst, dst_slc in self.iter_src_grid_subsets(yield_dst=True):
        # if vm.rank == 0:
        #     vm.rank_print('write_subset iterator count :: {}'.format(ctr))
        #     tstart = time.time()
        # padded = create_zero_padded_integer(ctr, nzeros)

        src_path = self.create_full_path_from_template('src_template', index=ctr)
        dst_path = self.create_full_path_from_template('dst_template', index=ctr)
        wgt_path = self.create_full_path_from_template('wgt_template', index=ctr)

        src_filenames.append(os.path.split(src_path)[1])
        dst_filenames.append(os.path.split(dst_path)[1])
        wgt_filenames.append(wgt_path)
        dst_slices.append(dst_slc)
        src_slices.append(src_slc)

        # Only write destinations if an iterator is not provided.
        if self.iter_dst is None:
            zip_args = [[sub_src, sub_dst], [src_path, dst_path]]
        else:
            zip_args = [[sub_src], [src_path]]

        for target, path in zip(*zip_args):
            with vm.scoped_by_emptyable('field.write', target):
                if not vm.is_null:
                    ocgis_lh(msg='writing: {}'.format(path), level=logging.DEBUG)
                    field = Field(grid=target)
                    field.write(path)
                    ocgis_lh(msg='finished writing: {}'.format(path), level=logging.DEBUG)

        # Increment the counter outside of the loop to avoid counting empty subsets.
        ctr += 1

        # if vm.rank == 0:
        #     tstop = time.time()
        #     vm.rank_print('timing::write_subset iteration::{}'.format(tstop - tstart))

    # Global shapes require a VM global scope to collect.
    src_global_shape = global_grid_shape(self.src_grid)
    dst_global_shape = global_grid_shape(self.dst_grid)

    # Gather and collapse source slices as some may be empty and we write on rank 0.
    gathered_src_grid_slice = vm.gather(src_slices)
    if vm.rank == 0:
        len_src_slices = len(src_slices)
        new_src_grid_slice = [None] * len_src_slices
        for idx in range(len_src_slices):
            for rank_src_grid_slice in gathered_src_grid_slice:
                if rank_src_grid_slice[idx] is not None:
                    new_src_grid_slice[idx] = rank_src_grid_slice[idx]
                    break
        src_slices = new_src_grid_slice

    with vm.scoped('index write', [0]):
        if not vm.is_null:
            dim = Dimension('nfiles', len(src_filenames))
            vname = ['source_filename', 'destination_filename', 'weights_filename']
            values = [src_filenames, dst_filenames, wgt_filenames]
            grid_splitter_destination = GridSplitterConstants.IndexFile.NAME_DESTINATION_VARIABLE
            attrs = [{'esmf_role': 'grid_splitter_source'},
                     {'esmf_role': grid_splitter_destination},
                     {'esmf_role': 'grid_splitter_weights'}]

            vc = VariableCollection()

            grid_splitter_index = GridSplitterConstants.IndexFile.NAME_INDEX_VARIABLE
            vidx = Variable(name=grid_splitter_index)
            vidx.attrs['esmf_role'] = grid_splitter_index
            vidx.attrs['grid_splitter_source'] = 'source_filename'
            vidx.attrs[GridSplitterConstants.IndexFile.NAME_DESTINATION_VARIABLE] = 'destination_filename'
            vidx.attrs['grid_splitter_weights'] = 'weights_filename'
            vidx.attrs[GridSplitterConstants.IndexFile.NAME_SRC_GRID_SHAPE] = src_global_shape
            vidx.attrs[GridSplitterConstants.IndexFile.NAME_DST_GRID_SHAPE] = dst_global_shape

            vc.add_variable(vidx)

            for idx in range(len(vname)):
                v = Variable(name=vname[idx], dimensions=dim, dtype=str, value=values[idx], attrs=attrs[idx])
                vc.add_variable(v)

            bounds_dimension = Dimension(name='bounds', size=2)
            # TODO: This needs to work with four dimensions.
            # Source -----------------------------------------------------------------------------------------
            self.src_grid._gs_create_index_bounds_(RegriddingRole.SOURCE, vidx, vc, src_slices, dim,
                                                   bounds_dimension)
            # Destination ------------------------------------------------------------------------------------
            self.dst_grid._gs_create_index_bounds_(RegriddingRole.DESTINATION, vidx, vc, dst_slices, dim,
                                                   bounds_dimension)

            vc.write(index_path)

    vm.barrier()

def create_merged_weight_file(self, merged_weight_filename, strict=False):
    """
    Merge weight file chunks to a single, global weight file.

    :param str merged_weight_filename: Path to the merged weight file.
    :param bool strict: If ``False``, allow "missing" files where the iterator index cannot create a found file.
     It is best to leave this ``False`` as not all sources and destinations are mapped. If ``True``, raise an
     ``IOError`` when an expected weight file is missing.
    """
    if vm.size > 1:
        raise ValueError("'create_merged_weight_file' does not work in parallel")

    index_filename = self.create_full_path_from_template('index_file')
    ifile = RequestDataset(uri=index_filename).get()
    ifile.load()
    ifc = GridSplitterConstants.IndexFile
    gidx = ifile[ifc.NAME_INDEX_VARIABLE].attrs

    src_global_shape = gidx[ifc.NAME_SRC_GRID_SHAPE]
    dst_global_shape = gidx[ifc.NAME_DST_GRID_SHAPE]

    # Get the global weight dimension size.
    n_s_size = 0
    weight_filename = ifile[gidx[ifc.NAME_WEIGHTS_VARIABLE]]
    wv = weight_filename.join_string_value()
    split_weight_file_directory = self.paths['wd']
    for wfn in map(lambda x: os.path.join(split_weight_file_directory, x), wv):
        if not os.path.exists(wfn):
            if strict:
                raise IOError(wfn)
            else:
                continue
        n_s_size += RequestDataset(wfn).get().dimensions['n_s'].size

    # Create the output weight file.
    wf_varnames = ['row', 'col', 'S']
    wf_dtypes = [np.int32, np.int32, np.float64]
    vc = VariableCollection()
    dim = Dimension('n_s', n_s_size)
    for w, wd in zip(wf_varnames, wf_dtypes):
        var = Variable(name=w, dimensions=dim, dtype=wd)
        vc.add_variable(var)
    vc.write(merged_weight_filename)

    # Transfer weights to the merged file.
    sidx = 0
    src_indices = self.src_grid._gs_create_global_indices_(src_global_shape)
    dst_indices = self.dst_grid._gs_create_global_indices_(dst_global_shape)

    out_wds = nc.Dataset(merged_weight_filename, 'a')
    for ii, wfn in enumerate(map(lambda x: os.path.join(split_weight_file_directory, x), wv)):
        if not os.path.exists(wfn):
            if strict:
                raise IOError(wfn)
            else:
                continue

        wdata = RequestDataset(wfn).get()
        for wvn in wf_varnames:
            odata = wdata[wvn].get_value()
            try:
                split_grids_directory = self.paths['wd']
                odata = self._gs_remap_weight_variable_(ii, wvn, odata, src_indices, dst_indices, ifile, gidx,
                                                        split_grids_directory=split_grids_directory)
            except IndexError as e:
                # Use str(e); exception objects have no "message" attribute in Python 3.
                msg = "Weight filename: '{}'; Weight Variable Name: '{}'. {}".format(wfn, wvn, str(e))
                raise IndexError(msg)
            out_wds[wvn][sidx:sidx + odata.size] = odata
            out_wds.sync()
        sidx += odata.size
    out_wds.close()
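
# Hedged usage sketch (serial execution only, as enforced above; ``gc`` is a configured
# chunker instance whose per-chunk weight files already exist under ``self.paths['wd']``,
# mirroring the call made in test_system_grid_chunking above):
#
# >>> gc.create_merged_weight_file('merged_weights.nc', strict=False)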