Exemplo n.º 1
0
    def test_create_index_variable_global(self):
        """
        Check that indexing the scattered destination values with the global index variable reproduces the scattered
        source values. Currently skipped because the feature is not implemented.
        """
        raise SkipTest('not implemented')

        n_src, n_dst = 5, 7

        dim_src = Dimension('dsrc', n_src, dist=True)
        dist_src = OcgDist()
        dist_src.add_dimension(dim_src)
        dist_src.update_dimension_bounds()

        dim_dst = Dimension('ddst', n_dst, dist=True)
        dist_dst = OcgDist()
        dist_dst.add_dimension(dim_dst)
        dist_dst.update_dimension_bounds()

        # Only rank 0 builds the full variables; the other ranks receive their pieces through the scatter.
        src_var, dst_var = None, None
        if vm.rank == 0:
            np.random.seed(1)
            dst_values = np.random.rand(n_dst)
            # Source values are a sample (without replacement) of the destination values.
            src_values = np.random.choice(dst_values, size=n_src, replace=False)

            src_var = Variable(name='src', value=src_values, dimensions=dim_src.name)
            # TODO: move create_ugid_global to create_global_index on a standard variable object
            dst_var = GeometryVariable(name='dst', value=dst_values, dimensions=dim_dst.name)

        src_var = variable_scatter(src_var, dist_src)
        dst_var = variable_scatter(dst_var, dist_dst)

        index_var = create_index_variable_global('index_array', src_var, dst_var)

        self.assertNumpyAll(dst_var.get_value()[index_var.get_value()], src_var.get_value())
Exemplo n.º 2
0
    def test_set_spatial_mask(self):
        """
        Exercise ``DimensionMap.set_spatial_mask``. The test checks three things: the mask is recorded by name,
        setting it again via ``set_variable`` is rejected, and caller-supplied and default attributes are kept.
        """
        xy_dims = Dimension('x', 3), Dimension('y', 7)
        dimension_map = DimensionMap()
        spatial_mask = create_spatial_mask_variable('a_mask', None, xy_dims)
        # A newly created mask variable has nothing masked.
        self.assertFalse(np.any(spatial_mask.get_mask()))
        dimension_map.set_spatial_mask(spatial_mask)
        self.assertEqual(dimension_map.get_spatial_mask(), spatial_mask.name)

        # Setting the spatial mask again through the generic variable setter raises.
        with self.assertRaises(DimensionMapError):
            dimension_map.set_variable(DMK.SPATIAL_MASK, spatial_mask)

        # Test custom variables may be used.
        dimension_map = DimensionMap()
        xy_dims = Dimension('x', 3), Dimension('y', 7)
        spatial_mask = create_spatial_mask_variable('a_mask', None, xy_dims)
        dimension_map.set_spatial_mask(spatial_mask, attrs={'please keep me': 'no overwriting'})
        self.assertIn('please keep me', dimension_map.get_attrs(DMK.SPATIAL_MASK))

        # The stored attributes are exactly the supplied ``default_attrs``; nothing else is added.
        dimension_map = DimensionMap()
        dimension_map.set_spatial_mask('foo', default_attrs={'blue': 'world'})
        self.assertEqual(dimension_map.get_property(DMK.SPATIAL_MASK)['attrs'], {'blue': 'world'})
Exemplo n.º 3
0
    def test_get_live_ranks_from_object(self):
        """
        Check that the rank holding an empty distributed dimension is excluded from the live ranks.

        Requires exactly four MPI processes. Rank 1 creates an empty dimension, so the live ranks are ``(0, 2, 3)``.
        The virtual machine is finalized even when the assertion fails, so every rank always reaches ``finalize``.
        """
        if MPI_SIZE != 4:
            raise SkipTest('MPI_SIZE != 4')

        vm = OcgVM()
        try:
            if MPI_RANK == 1:
                dim = Dimension('woot', is_empty=True, dist=True)
            else:
                dim = Dimension('woot', dist=True, size=3)

            actual = vm.get_live_ranks_from_object(dim)
            self.assertEqual(actual, (0, 2, 3))
        finally:
            # Always release the VM; a failed assertion previously skipped this on the failing rank only.
            vm.finalize()
Exemplo n.º 4
0
    def get_grouping(self, grouping):
        """
        Create a temporally grouped variable using string group sequences.

        :param grouping: The temporal grouping to use when creating the temporal group dimension. The value ``'all'``
         and groupings containing ``'unique'`` are handled by dedicated routines.

        >>> grouping = ['month']

        :type grouping: `sequence` of :class:`str`
        :rtype: :class:`~ocgis.variable.temporal.TemporalGroupVariable`
        """
        # Choose the routine that produces (new_bounds, date_parts, repr_dt, dgroups):
        # - 'all' skips breaking out datetime parts entirely.
        # - "unique" seasons use a specialized routine.
        # - Standard groups (e.g. ['month']) and seasons across the entire time range use the general routine.
        if grouping == 'all':
            grouping_parts = self._get_grouping_all_()
        elif 'unique' in grouping:
            grouping_parts = self._get_grouping_seasonal_unique_(grouping)
        else:
            grouping_parts = self._get_grouping_other_(grouping)
        new_bounds, date_parts, repr_dt, dgroups = grouping_parts

        time_name = self.dimensions[0].name
        bounds_dimensions = ([dim.name for dim in self.bounds.dimensions] if self.has_bounds
                             else [time_name, 'bounds'])

        # Keep an unlimited time dimension unlimited by rebuilding it with the grouped length as its current size.
        if self.dimensions[0].is_unlimited:
            bounds_dimensions[0] = Dimension(name=time_name, size_current=len(repr_dt))
        else:
            bounds_dimensions[0] = time_name

        climatology = TemporalVariable(value=new_bounds, name='climatology_bounds', dimensions=bounds_dimensions)
        grouped_attrs = deepcopy(self.attrs)
        tgv = TemporalGroupVariable(grouping=grouping, date_parts=date_parts, bounds=climatology, dgroups=dgroups,
                                    value=repr_dt, units=self.units, calendar=self.calendar, name=self.name,
                                    attrs=grouped_attrs, dimensions=bounds_dimensions[0])
        # Drop any bounds attribute carried over from the copied source attributes.
        tgv.attrs.pop(TemporalVariable._bounds_attribute_name, None)

        return tgv
Exemplo n.º 5
0
    def test_set_spatial_mask(self):
        """
        Check that ``DimensionMap.set_spatial_mask`` records a grid mask variable by name and rejects setting it again
        through ``set_variable``. Also check that caller-supplied attributes are kept.
        """
        mask_dims = Dimension('x', 3), Dimension('y', 7)
        dimension_map = DimensionMap()
        grid_mask = create_grid_mask_variable('a_mask', None, mask_dims)
        # A newly created mask variable has nothing masked.
        self.assertFalse(np.any(grid_mask.get_mask()))
        dimension_map.set_spatial_mask(grid_mask)
        self.assertEqual(dimension_map.get_spatial_mask(), grid_mask.name)

        # Setting the spatial mask again through the generic variable setter raises.
        with self.assertRaises(DimensionMapError):
            dimension_map.set_variable(DMK.SPATIAL_MASK, grid_mask)

        # Test custom variables may be used.
        dimension_map = DimensionMap()
        mask_dims = Dimension('x', 3), Dimension('y', 7)
        grid_mask = create_grid_mask_variable('a_mask', None, mask_dims)
        custom_attrs = {'please keep me': 'no overwriting'}
        dimension_map.set_spatial_mask(grid_mask, attrs=custom_attrs)
        self.assertIn('please keep me', dimension_map.get_attrs(DMK.SPATIAL_MASK))
Exemplo n.º 6
0
    def test_system_grid_chunking(self):
        """
        Chunk an ESMF unstructured grid onto itself with weight generation, then check the per-chunk outputs and the
        merged weight file.

        Requires exactly four ranks. Each rank checks the chunk numbers it owns from a distributed eight-element
        dimension:
        - Every source chunk has more grid elements than its destination chunk.
        - Every chunk weight ``S`` has a minimum and maximum of ``1.0``.

        Rank 0 then merges the chunk weights and checks the merged values the same way.
        """
        if vm.size != 4:
            raise SkipTest('vm.size != 4')

        from ocgis.spatial.grid_chunker import GridChunker

        rd_dst = RequestDataset(uri=self.path_esmf_unstruct,
                                driver=DriverESMFUnstruct,
                                crs=Spherical(),
                                grid_abstraction='point',
                                grid_is_isomorphic=True)
        rd_src = deepcopy(rd_dst)
        resolution = 0.28125

        # Rank 0 creates the shared chunk directory; all ranks wait for it before chunking.
        chunk_wd = os.path.join(self.current_dir_output, 'chunks')
        if vm.rank == 0:
            os.mkdir(chunk_wd)
        vm.barrier()

        chunker = GridChunker(rd_src,
                              rd_dst,
                              nchunks_dst=[8],
                              src_grid_resolution=resolution,
                              dst_grid_resolution=resolution,
                              optimized_bbox_subset=True,
                              paths={'wd': chunk_wd},
                              genweights=True)
        chunker.write_chunks()

        # Distribute the eight chunk numbers across ranks.
        dist = OcgDist()
        chunk_dim = Dimension(name='ctr', size=8, dist=True)
        dist.add_dimension(chunk_dim)
        dist.update_dimension_bounds()

        # The local bounds are shifted by one to match the chunk file numbering.
        first_chunk = chunk_dim.bounds_local[0] + 1
        stop_chunk = chunk_dim.bounds_local[1] + 1
        for chunk_number in range(first_chunk, stop_chunk):
            src_field = Field.read(os.path.join(chunk_wd, 'split_src_{}.nc'.format(chunk_number)),
                                   driver=DriverESMFUnstruct)
            dst_field = Field.read(os.path.join(chunk_wd, 'split_dst_{}.nc'.format(chunk_number)),
                                   driver=DriverESMFUnstruct)
            self.assertGreater(src_field.grid.shape[0], dst_field.grid.shape[0])

            weight_path = os.path.join(chunk_wd, 'esmf_weights_{}.nc'.format(chunk_number))
            chunk_weights = Field.read(weight_path)['S'].v()
            self.assertAlmostEqual(chunk_weights.min(), 1.0)
            self.assertAlmostEqual(chunk_weights.max(), 1.0)

        # create_merged_weight_file refuses to run in parallel, so only rank 0 merges.
        with vm.scoped('merge weights', [0]):
            if not vm.is_null:
                merged_path = self.get_temporary_file_path('merged_weights.nc')
                chunker.create_merged_weight_file(merged_path, strict=False)
                merged_weights = Field.read(merged_path)['S'].v()
                self.assertAlmostEqual(merged_weights.min(), 1.0)
                self.assertAlmostEqual(merged_weights.max(), 1.0)
Exemplo n.º 7
0
Arquivo: mpi.py Projeto: NCPP/ocgis
    def create_dimension(self, *args, **kwargs):
        """
        Create an :class:`ocgis.Dimension` and add it to this object with :meth:`add_dimension`.

        :param args: Positional arguments forwarded to :class:`ocgis.Dimension`.
        :param kwargs: Keyword arguments forwarded to :class:`ocgis.Dimension`. The optional ``group`` keyword is
         removed first and passed to :meth:`add_dimension` and :meth:`get_dimension` instead.
        :return: The dimension fetched back with :meth:`get_dimension`. Returns ``None`` when ``self.size`` differs
         from ``MPI_SIZE``.
        """
        from ocgis import Dimension

        group = kwargs.pop('group', None)
        new_dim = Dimension(*args, **kwargs)
        self.add_dimension(new_dim, group=group)

        # With mismatched rank counts, ranks may not be contained in the mapping, so callers must get their own
        # dimension.
        if self.size == MPI_SIZE:
            return self.get_dimension(new_dim.name, group=group)
        return None
Exemplo n.º 8
0
    def fixture_driver_scrip_netcdf_field(self):
        """
        Build a SCRIP-driver field whose unstructured point grid is every (x, y) combination of fixed coordinate
        ranges. A ``grid_dims`` variable holding ``[y count, x count]`` is attached to the field.
        """
        lon_values = np.arange(10., 35., step=5)
        lat_values = np.arange(45., 85., step=10)
        n_lon = lon_values.shape[0]
        n_lat = lat_values.shape[0]

        grid_size_dim = Dimension(name='grid_size', size=n_lon * n_lat)
        lon = Variable(name='grid_center_lon', dimensions=grid_size_dim)
        lat = Variable(name='grid_center_lat', dimensions=grid_size_dim)

        # Fill the centers in itertools.product order: x varies slowest, y fastest.
        lon_buffer = lon.get_value()
        lat_buffer = lat.get_value()
        for position, (lon_value, lat_value) in enumerate(itertools.product(lon_values, lat_values)):
            lon_buffer[position] = lon_value
            lat_buffer[position] = lat_value

        points = PointGC(x=lon, y=lat, crs=Spherical(), driver=DriverNetcdfSCRIP)
        field = Field(grid=GridUnstruct(geoms=[points]), driver=DriverNetcdfSCRIP)

        field.add_variable(Variable(name='grid_dims', value=[n_lat, n_lon], dimensions='grid_rank'))

        return field
Exemplo n.º 9
0
    def write_subsets(self, src_template, dst_template, wgt_template, index_path):
        """
        Write grid subsets to netCDF files using the provided filename templates. The template must contain the full
        file path with a single curly-brace pair to insert the combination counter. ``wgt_template`` should not be a
        full path. This name is used when generating weight files. The combination counter starts at ``1``.

        After all subsets are written, an index netCDF file is written inside a VM scope containing only rank 0. The
        index lists the source, destination, and weight filenames for each subset, plus the destination x/y slice
        bounds.

        >>> template_example = '/path/to/data_{}.nc'

        :param str src_template: The template for the source subset file.
        :param str dst_template: The template for the destination subset file.
        :param str wgt_template: The template for the weight filename.

        >>> wgt_template = 'esmf_weights_{}.nc'

        :param str index_path: Path to the output indexing netCDF.
        """

        src_filenames = []
        dst_filenames = []
        wgt_filenames = []
        dst_slices = []

        # nzeros = len(str(reduce(lambda x, y: x * y, self.nsplits_dst)))

        # Each iteration yields a source subset, its destination subset, and the destination slice mapping (indexed
        # by dimension name when the bounds are written below).
        for ctr, (sub_src, sub_dst, dst_slc) in enumerate(self.iter_src_grid_subsets(yield_dst=True), start=1):
            # padded = create_zero_padded_integer(ctr, nzeros)

            src_path = src_template.format(ctr)
            dst_path = dst_template.format(ctr)
            wgt_filename = wgt_template.format(ctr)

            # Only basenames are recorded for the subset files; the weight template is already a bare filename.
            src_filenames.append(os.path.split(src_path)[1])
            dst_filenames.append(os.path.split(dst_path)[1])
            wgt_filenames.append(wgt_filename)
            dst_slices.append(dst_slc)

            for target, path in zip([sub_src, sub_dst], [src_path, dst_path]):
                # Empty subsets are written as an empty field with no grid.
                if target.is_empty:
                    is_empty = True
                    target = None
                else:
                    is_empty = False
                field = Field(grid=target, is_empty=is_empty)
                ocgis_lh(msg='writing: {}'.format(path), level=logging.DEBUG)
                # NOTE(review): presumably restricts the write to ranks where ``field`` is non-empty; confirm against
                # ``scoped_by_emptyable``.
                with vm.scoped_by_emptyable('field.write', field):
                    if not vm.is_null:
                        field.write(path)
                ocgis_lh(msg='finished writing: {}'.format(path), level=logging.DEBUG)

        # The index file is written inside a VM scope containing only rank 0.
        with vm.scoped('index write', [0]):
            if not vm.is_null:
                dim = Dimension('nfiles', len(src_filenames))
                vname = ['source_filename', 'destination_filename', 'weights_filename']
                values = [src_filenames, dst_filenames, wgt_filenames]
                grid_splitter_destination = GridSplitterConstants.IndexFile.NAME_DESTINATION_VARIABLE
                attrs = [{'esmf_role': 'grid_splitter_source'},
                         {'esmf_role': grid_splitter_destination},
                         {'esmf_role': 'grid_splitter_weights'}]

                vc = VariableCollection()

                # Index variable whose attributes name the other variables in the index file.
                grid_splitter_index = GridSplitterConstants.IndexFile.NAME_INDEX_VARIABLE
                vidx = Variable(name=grid_splitter_index)
                vidx.attrs['esmf_role'] = grid_splitter_index
                vidx.attrs['grid_splitter_source'] = 'source_filename'
                vidx.attrs[GridSplitterConstants.IndexFile.NAME_DESTINATION_VARIABLE] = 'destination_filename'
                vidx.attrs['grid_splitter_weights'] = 'weights_filename'
                x_bounds = GridSplitterConstants.IndexFile.NAME_X_BOUNDS_VARIABLE
                vidx.attrs[x_bounds] = x_bounds
                y_bounds = GridSplitterConstants.IndexFile.NAME_Y_BOUNDS_VARIABLE
                vidx.attrs[y_bounds] = y_bounds
                vc.add_variable(vidx)

                # String filename variables, one entry per subset along the ``nfiles`` dimension.
                for idx in range(len(vname)):
                    v = Variable(name=vname[idx], dimensions=dim, dtype=str, value=values[idx], attrs=attrs[idx])
                    vc.add_variable(v)

                # Integer (start, stop) of each subset's destination slice along the destination x and y dimensions.
                bounds_dimension = Dimension(name='bounds', size=2)
                xb = Variable(name=x_bounds, dimensions=[dim, bounds_dimension], attrs={'esmf_role': 'x_split_bounds'},
                              dtype=int)
                yb = Variable(name=y_bounds, dimensions=[dim, bounds_dimension], attrs={'esmf_role': 'y_split_bounds'},
                              dtype=int)

                x_name = self.dst_grid.x.dimensions[0].name
                y_name = self.dst_grid.y.dimensions[0].name
                for idx, slc in enumerate(dst_slices):
                    xb.get_value()[idx, :] = slc[x_name].start, slc[x_name].stop
                    yb.get_value()[idx, :] = slc[y_name].start, slc[y_name].stop
                vc.add_variable(xb)
                vc.add_variable(yb)

                vc.write(index_path)

        # Keep all ranks together until the index write scope has finished.
        vm.barrier()
Exemplo n.º 10
0
def regrid_field(source,
                 destination,
                 regrid_method='auto',
                 value_mask=None,
                 split=True,
                 fill_value=None,
                 weights_in=None,
                 weights_out=None):
    """
    Regrid ``source`` data to match the grid of ``destination``.

    :param source: The source field.
    :type source: :class:`ocgis.Field`
    :param destination: The destination field.
    :type destination: :class:`ocgis.Field`
    :param regrid_method: See :func:`~ocgis.regrid.base.create_esmf_grid`. When ``'auto'``, conservative regridding is
     selected if both the destination and source ESMF grids have corners; otherwise ``None`` is passed to ESMF.
    :param value_mask: See :func:`~ocgis.regrid.base.iter_esmf_fields`.
    :type value_mask: :class:`numpy.ndarray`
    :param bool split: See :func:`~ocgis.regrid.base.iter_esmf_fields`. Forced to ``False`` when the source has no
     extra (non-spatial) dimensions.
    :param fill_value: Destination fill value used to fill the destination field before regridding. If ``None``, then
     the default fill value for the destination field data type will be used.
    :type fill_value: int | float
    :param weights_in: Optional path to an input weights file. The route handle will be created from weights in this
     file. Assumes a SCRIP-like structure for the input weight file.
    :type weights_in: str
    :param weights_out: Optional path to an output weight file. Does NOT do any regridding - just writes the weights.
    :type weights_out: str
    :returns: A copy of ``source`` carrying the destination grid and the regridded data variables. Returns ``None``
     when ``weights_out`` is provided.
    :rtype: :class:`ocgis.Field` | None
    :raises RegriddingError: If every regridded destination element equals the fill value (nothing was mapped).
    """

    # This function runs a series of asserts to make sure the sources and destination are compatible.
    check_fields_for_regridding(source,
                                destination,
                                regrid_method=regrid_method)

    dst_grid = destination.grid  # Reference the destination grid
    # Spatial coordinate dimensions for the destination grid
    dst_spatial_coordinate_dimensions = OrderedDict([
        (dim.name, dim) for dim in dst_grid.dimensions
    ])
    # Spatial coordinate dimensions for the source grid
    src_spatial_coordinate_dimensions = OrderedDict([
        (dim.name, dim) for dim in source.grid.dimensions
    ])

    try:
        # Reference an archetype data variable.
        archetype = source.data_variables[0]
    except IndexError:
        # There may be no data variables. Use the grid as reference instead.
        archetype = source.grid

    # Extra dimensions (like time or level) to iterate over or use for ndbounds depending on the split protocol
    extra_dimensions = OrderedDict([
        (dim.name, dim) for dim in archetype.dimensions
        if dim.name not in dst_spatial_coordinate_dimensions
        and dim.name not in src_spatial_coordinate_dimensions
    ])

    # If there are no extra dimensions, then there is no need to split fields.
    if len(extra_dimensions) == 0:
        split = False

    if split:
        # There are no extra, ungridded dimensions for ESMF to use.
        ndbounds = None
    else:
        # These are the extra, ungridded dimensions for ESMF to use (ndbounds).
        ndbounds = [len(dim) for dim in extra_dimensions.values()]
        ndbounds.reverse()  # Fortran order is used by ESMF

    # Regrid each source.
    ocgis_lh(logger='iter_regridded_fields',
             msg='starting source regrid loop',
             level=logging.DEBUG)
    build = True  # Flag for first loop
    fills = {}  # Holds destination field fill variables.

    # TODO: OPTIMIZE: The source and destination field objects should be reused and refilled when split=False
    # Main field iterator for use in the regridding loop
    for variable_name, src_efield, current_slice in iter_esmf_fields(
            source,
            regrid_method=regrid_method,
            value_mask=value_mask,
            split=split):
        # We need to generate new variables given the change in shape
        if variable_name not in fills:
            # Create the destination data variable dimensions. These are a combination of the extra dimensions and
            # spatial coordinate dimensions.
            if len(extra_dimensions) > 0:
                new_dimensions = list(extra_dimensions.values())
            else:
                new_dimensions = []
            new_dimensions += list(dst_grid.dimensions)

            # Reverse the dimensions for the creation as we are working in Fortran ordering with ESMF.
            new_dimensions.reverse()

            # Create the destination fill variable and cache it
            source_variable = source[variable_name]
            new_variable = Variable(name=variable_name,
                                    dimensions=new_dimensions,
                                    dtype=source_variable.dtype,
                                    fill_value=source_variable.fill_value,
                                    attrs=source_variable.attrs)
            fills[variable_name] = new_variable

        # Only build the ESMF/OCGIS destination grids and fields once.
        if build:
            # Build the destination grid once.
            ocgis_lh(logger='iter_regridded_fields',
                     msg='before create_esmf_grid',
                     level=logging.DEBUG)
            esmf_destination_grid = create_esmf_grid(
                destination.grid,
                regrid_method=regrid_method,
                value_mask=value_mask)

            # Check for corners on the destination grid. If they exist, conservative regridding is possible.
            if regrid_method == 'auto':
                if esmf_grid_has_corners(esmf_destination_grid) and esmf_grid_has_corners(src_efield.grid):
                    regrid_method = ESMF.RegridMethod.CONSERVE
                else:
                    regrid_method = None

            # Prepare the regridded sourced field. This amounts to exchanging the grids between the objects.
            regridded_source = source.copy()
            regridded_source.grid.extract(clean_break=True)
            regridded_source.set_grid(destination.grid.extract())

        # Destination ESMF field
        dst_efield = ESMF.Field(esmf_destination_grid,
                                name='destination',
                                ndbounds=ndbounds)
        # Reference the destination data variable object
        fill_variable = fills[variable_name]

        if fill_value is None:
            fv = fill_variable.fill_value  # The fill value used for the variable data type
        else:
            fv = fill_value

        # Fill the ESMF destination field with that fill value to help track masks
        dst_efield.data.fill(fv)

        # Construct the regrid object. Weight generation actually occurs in this call.
        ocgis_lh(logger='iter_regridded_fields',
                 msg='before ESMF.Regrid',
                 level=logging.DEBUG)

        if build:  # Only create the regrid object once. It may be reused if split=True.
            if weights_in is None:
                if weights_out is None:
                    create_rh = False
                else:
                    create_rh = True
                # Create the weights and ESMF route handle from the grids
                regrid = ESMF.Regrid(
                    src_efield,
                    dst_efield,
                    unmapped_action=ESMF.UnmappedAction.IGNORE,
                    regrid_method=regrid_method,
                    src_mask_values=[0],
                    dst_mask_values=[0],
                    filename=weights_out,
                    create_rh=create_rh)
            else:
                # Create ESMF route handle with weights read from file
                regrid = ESMF.RegridFromFile(src_efield, dst_efield,
                                             weights_in)
            build = False
        ocgis_lh(logger='iter_regridded_fields',
                 msg='after ESMF.Regrid',
                 level=logging.DEBUG)

        # If we are just writing the weights file, bail out after it is written.
        if weights_out is not None:
            destroy_esmf_objects(
                [regrid, src_efield, dst_efield, esmf_destination_grid])
            return

        # Perform the regrid operation. "zero_region" only fills values involved with regridding.
        ocgis_lh(logger='iter_regridded_fields',
                 msg='before regrid',
                 level=logging.DEBUG)
        regridded_esmf_field = regrid(src_efield,
                                      dst_efield,
                                      zero_region=ESMF.Region.SELECT)
        e_data = regridded_esmf_field.data  # Regridded data values

        # These are the unmapped values coming out of the ESMF regrid operation.
        unmapped_mask = e_data[:] == fv

        # If all data is masked, raise an exception.
        if unmapped_mask.all():
            # Destroy ESMF objects, including the source ESMF field, which would otherwise leak when the exception
            # propagates.
            destroy_esmf_objects(
                [regrid, dst_efield, src_efield, esmf_destination_grid])
            msg = 'All regridded elements are masked. Do the input spatial extents overlap?'
            raise RegriddingError(msg)

        if current_slice is not None:
            # Create an OCGIS variable to use for setting on the destination. We want to use label-based slicing since
            # arbitrary dimensions are possible with the extra dimensions. First, set defaults for the spatial
            # coordinate slices.
            for k in dst_spatial_coordinate_dimensions.keys():
                current_slice[k] = slice(None)
            # The spatial coordinate dimension names for ESMF in Fortran order
            e_data_dimensions = deepcopy(
                list(dst_spatial_coordinate_dimensions.keys()))
            e_data_dimensions.reverse()
            # The extra dimension names for ESMF in Fortran order
            e_data_dimensions_extra = deepcopy(list(extra_dimensions.keys()))
            e_data_dimensions_extra.reverse()
            # Wrap the ESMF data in an OCGIS variable
            e_data_var = Variable(name='e_data',
                                  value=e_data,
                                  dimensions=e_data_dimensions,
                                  mask=unmapped_mask)
            # Expand the new variable's dimension to account for the extra dimensions
            reshape_dims = list(e_data_var.dimensions) + [
                Dimension(name=n, size=1) for n in e_data_dimensions_extra
            ]
            e_data_var.reshape(reshape_dims)
            # Set the destination fill variable with the ESMF regridded data
            fill_variable[current_slice] = e_data_var
        else:
            # ESMF and OCGIS dimensions align at this point, so just insert the data
            fill_variable.v()[:] = e_data

        # Create a new variable collection and add the variables to the output field.
        for v in list(fills.values()):
            regridded_source.add_variable(v, is_data=True, force=True)

    # Destroy ESMF objects. The weights-only path (``weights_out``) always returns inside the loop, so the regrid
    # object exists here and is released together with the fields and grid.
    destroy_esmf_objects(
        [regrid, dst_efield, src_efield, esmf_destination_grid])

    # Broadcast ESMF (Fortran) ordering to Python (C) ordering.
    dst_names = [dim.name for dim in new_dimensions]
    dst_names.reverse()
    for data_variable in regridded_source.data_variables:
        broadcast_variable(data_variable, dst_names)

    return regridded_source
Exemplo n.º 11
0
 def fixture_element_dimension(self):
     """Return a new dimension named ``elements`` with size 6."""
     name, size = 'elements', 6
     return Dimension(name, size=size)
Exemplo n.º 12
0
    def write_subsets(self):
        """
        Write grid subsets to netCDF files using the provided filename templates.

        Subset paths come from ``create_full_path_from_template``. The ``'src_template'``, ``'dst_template'``, and
        ``'wgt_template'`` keys are used with a counter that starts at ``1``. Destination subsets are written only when
        ``self.iter_dst`` is ``None``.

        Afterwards, an index netCDF file (``'index_file'`` template) is written inside a VM scope containing only
        rank 0. It records:
        - the source and destination subset basenames;
        - the full weight paths;
        - the global source and destination grid shapes;
        - the index bounds built by each grid's ``_gs_create_index_bounds_``.
        """
        src_filenames = []
        dst_filenames = []
        wgt_filenames = []
        dst_slices = []
        src_slices = []
        index_path = self.create_full_path_from_template('index_file')

        # nzeros = len(str(reduce(lambda x, y: x * y, self.nsplits_dst)))

        ctr = 1
        for sub_src, src_slc, sub_dst, dst_slc in self.iter_src_grid_subsets(yield_dst=True):
            # if vm.rank == 0:
            #     vm.rank_print('write_subset iterator count :: {}'.format(ctr))
            #     tstart = time.time()
            # padded = create_zero_padded_integer(ctr, nzeros)

            src_path = self.create_full_path_from_template('src_template', index=ctr)
            dst_path = self.create_full_path_from_template('dst_template', index=ctr)
            wgt_path = self.create_full_path_from_template('wgt_template', index=ctr)

            # Subset files are recorded by basename; the weight entry keeps the full templated path.
            src_filenames.append(os.path.split(src_path)[1])
            dst_filenames.append(os.path.split(dst_path)[1])
            wgt_filenames.append(wgt_path)
            dst_slices.append(dst_slc)
            src_slices.append(src_slc)

            # Only write destinations if an iterator is not provided.
            if self.iter_dst is None:
                zip_args = [[sub_src, sub_dst], [src_path, dst_path]]
            else:
                zip_args = [[sub_src], [src_path]]

            for target, path in zip(*zip_args):
                # NOTE(review): presumably restricts the write to ranks where ``target`` is non-empty; confirm
                # against ``scoped_by_emptyable``.
                with vm.scoped_by_emptyable('field.write', target):
                    if not vm.is_null:
                        ocgis_lh(msg='writing: {}'.format(path), level=logging.DEBUG)
                        field = Field(grid=target)
                        field.write(path)
                        ocgis_lh(msg='finished writing: {}'.format(path), level=logging.DEBUG)

            # Advance the one-based template counter. This runs for every yielded subset pair, including pairs that
            # are empty on this rank, so numbering stays aligned across ranks.
            ctr += 1

            # if vm.rank == 0:
            #     tstop = time.time()
            #     vm.rank_print('timing::write_subset iteration::{}'.format(tstop - tstart))

        # Global shapes require a VM global scope to collect.
        src_global_shape = global_grid_shape(self.src_grid)
        dst_global_shape = global_grid_shape(self.dst_grid)

        # Gather and collapse source slices as some may be empty and we write on rank 0.
        gathered_src_grid_slice = vm.gather(src_slices)
        if vm.rank == 0:
            # For each subset position, keep the first non-None slice found across the gathered ranks.
            len_src_slices = len(src_slices)
            new_src_grid_slice = [None] * len_src_slices
            for idx in range(len_src_slices):
                for rank_src_grid_slice in gathered_src_grid_slice:
                    if rank_src_grid_slice[idx] is not None:
                        new_src_grid_slice[idx] = rank_src_grid_slice[idx]
                        break
            src_slices = new_src_grid_slice

        # The index file is written inside a VM scope containing only rank 0.
        with vm.scoped('index write', [0]):
            if not vm.is_null:
                dim = Dimension('nfiles', len(src_filenames))
                vname = ['source_filename', 'destination_filename', 'weights_filename']
                values = [src_filenames, dst_filenames, wgt_filenames]
                grid_splitter_destination = GridSplitterConstants.IndexFile.NAME_DESTINATION_VARIABLE
                attrs = [{'esmf_role': 'grid_splitter_source'},
                         {'esmf_role': grid_splitter_destination},
                         {'esmf_role': 'grid_splitter_weights'}]

                vc = VariableCollection()

                # Index variable whose attributes name the other index-file variables and hold the global grid shapes.
                grid_splitter_index = GridSplitterConstants.IndexFile.NAME_INDEX_VARIABLE
                vidx = Variable(name=grid_splitter_index)
                vidx.attrs['esmf_role'] = grid_splitter_index
                vidx.attrs['grid_splitter_source'] = 'source_filename'
                vidx.attrs[GridSplitterConstants.IndexFile.NAME_DESTINATION_VARIABLE] = 'destination_filename'
                vidx.attrs['grid_splitter_weights'] = 'weights_filename'
                vidx.attrs[GridSplitterConstants.IndexFile.NAME_SRC_GRID_SHAPE] = src_global_shape
                vidx.attrs[GridSplitterConstants.IndexFile.NAME_DST_GRID_SHAPE] = dst_global_shape

                vc.add_variable(vidx)

                # String filename variables, one entry per subset along the ``nfiles`` dimension.
                for idx in range(len(vname)):
                    v = Variable(name=vname[idx], dimensions=dim, dtype=str, value=values[idx], attrs=attrs[idx])
                    vc.add_variable(v)

                bounds_dimension = Dimension(name='bounds', size=2)
                # TODO: This needs to work with four dimensions.
                # Source -----------------------------------------------------------------------------------------------
                self.src_grid._gs_create_index_bounds_(RegriddingRole.SOURCE, vidx, vc, src_slices, dim,
                                                       bounds_dimension)

                # Destination ------------------------------------------------------------------------------------------
                self.dst_grid._gs_create_index_bounds_(RegriddingRole.DESTINATION, vidx, vc, dst_slices, dim,
                                                       bounds_dimension)

                vc.write(index_path)

        # Keep all ranks together until the index write scope has finished.
        vm.barrier()
Exemplo n.º 13
0
    def create_merged_weight_file(self, merged_weight_filename, strict=False):
        """
        Merge weight file chunks to a single, global weight file.

        The index file (``'index_file'`` template) is read to find the chunk weight filenames and the global source
        and destination grid shapes. For each chunk file, every weight variable (``row``, ``col``, ``S``) is passed
        through ``_gs_remap_weight_variable_`` together with the global source/destination indices. The result is then
        appended to the merged file.

        :param str merged_weight_filename: Path to the merged weight file.
        :param bool strict: If ``False``, allow "missing" files where the iterator index cannot create a found file.
         It is best to leave these ``False`` as not all source and destinations are mapped. If ``True``, raise an
         ``IOError`` when a chunk weight file does not exist.
        :raises ValueError: If the virtual machine has more than one process.
        :raises IOError: If ``strict`` is ``True`` and a chunk weight file is missing.
        :raises IndexError: If remapping a weight variable fails. The message names the weight file and variable.
        """

        if vm.size > 1:
            raise ValueError("'create_merged_weight_file' does not work in parallel")

        index_filename = self.create_full_path_from_template('index_file')
        ifile = RequestDataset(uri=index_filename).get()
        ifile.load()
        ifc = GridSplitterConstants.IndexFile
        gidx = ifile[ifc.NAME_INDEX_VARIABLE].attrs

        src_global_shape = gidx[ifc.NAME_SRC_GRID_SHAPE]
        dst_global_shape = gidx[ifc.NAME_DST_GRID_SHAPE]

        weight_filename = ifile[gidx[ifc.NAME_WEIGHTS_VARIABLE]]
        wv = weight_filename.join_string_value()
        split_weight_file_directory = self.paths['wd']
        # Weight file paths in index-file order. A file's position is the chunk index used when remapping, so
        # missing files are skipped in the loops rather than removed here.
        weight_paths = [os.path.join(split_weight_file_directory, x) for x in wv]

        # Get the global weight dimension size.
        n_s_size = 0
        for wfn in weight_paths:
            if not os.path.exists(wfn):
                if strict:
                    raise IOError(wfn)
                continue
            n_s_size += RequestDataset(wfn).get().dimensions['n_s'].size

        # Create output weight file.
        wf_varnames = ['row', 'col', 'S']
        wf_dtypes = [np.int32, np.int32, np.float64]
        vc = VariableCollection()
        dim = Dimension('n_s', n_s_size)
        for w, wd in zip(wf_varnames, wf_dtypes):
            var = Variable(name=w, dimensions=dim, dtype=wd)
            vc.add_variable(var)
        vc.write(merged_weight_filename)

        # Transfer weights to the merged file.
        sidx = 0
        src_indices = self.src_grid._gs_create_global_indices_(src_global_shape)
        dst_indices = self.dst_grid._gs_create_global_indices_(dst_global_shape)

        out_wds = nc.Dataset(merged_weight_filename, 'a')
        try:
            for ii, wfn in enumerate(weight_paths):
                if not os.path.exists(wfn):
                    if strict:
                        raise IOError(wfn)
                    continue
                wdata = RequestDataset(wfn).get()
                for wvn in wf_varnames:
                    odata = wdata[wvn].get_value()
                    try:
                        odata = self._gs_remap_weight_variable_(ii, wvn, odata, src_indices, dst_indices, ifile, gidx,
                                                                split_grids_directory=split_weight_file_directory)
                    except IndexError as e:
                        # Format the exception itself: Python 3 exceptions have no ``message`` attribute, and
                        # accessing it would replace this IndexError with an AttributeError.
                        msg = "Weight filename: '{}'; Weight Variable Name: '{}'. {}".format(wfn, wvn, e)
                        raise IndexError(msg)
                    out_wds[wvn][sidx:sidx + odata.size] = odata
                    out_wds.sync()
                # All weight variables share the ``n_s`` dimension, so the last variable's size advances the offset.
                sidx += odata.size
        finally:
            # Close the merged dataset even when a chunk fails, so the file handle is not leaked.
            out_wds.close()