def _initialize_section_trsp_data_array(cds):
    """Create a zero-valued Dataset with (optional) time and depth dims

    Parameters
    ----------
    cds : xarray Dataset
        contains LLC coordinates 'k' and (optionally) 'time', plus the
        variable 'Z' that is copied into the result

    Returns
    -------
    ds_out : xarray Dataset
        Dataset with the variables
            'trsp_z'
                zero-valued DataArray with time (optional) and
                depth dimensions
            'Z'
                the original depth coordinate
    """
    coords = OrderedDict()
    dims = ()
    shape = ()

    # Grow dims and shape in lockstep so the zero array can never disagree
    # with the dimension names handed to the DataArray.
    if 'time' in cds:
        coords['time'] = cds['time'].values
        dims += ('time',)
        shape += (len(coords['time']),)

    coords['k'] = cds['k'].values
    dims += ('k',)
    shape += (len(coords['k']),)

    xda = xr.DataArray(data=np.zeros(shape), coords=coords, dims=dims)

    # Convert to dataset to add Z coordinate
    xds = xda.to_dataset(name='trsp_z')
    xds['Z'] = cds['Z']
    xds = xds.set_coords('Z')

    return xds
def _initialize_rho_trsp_dataset(cds, rho, lat_vals=None):
    """Create a zero-valued Dataset with time, density-bin, and latitude dims

    Parameters
    ----------
    cds : xarray Dataset
        Checked for the (optional) coordinate 'time'
    rho : xarray DataArray
        Containing the density field to be binned and made into our new
        vertical coordinate
    lat_vals : int or array of ints, optional
        latitude value(s) rounded to the nearest degree
        specifying where to compute transport

    Returns
    -------
    ds : xarray Dataset
        Dataset with the variables
            'trsp'
                zero-valued DataArray with time (optional), 'k_rho', and
                'lat' (only when lat_vals is given) dimensions
            'rho_c', 'rho_f'
                density bin centers and edges from ``get_rho_bins``
            'k_rho_f'
                integer index along the bin edges
    """
    # Create density bins
    # NOTE(review): the tail of this call was cut off in the reviewed source;
    # rho.max() as the upper bound is the presumed second argument. Confirm
    # against get_rho_bins' signature (e.g. a bin-count argument).
    rho_bin_edges, rho_bin_centers = get_rho_bins(rho.min().values,
                                                  rho.max().values)
    Nrho = len(rho_bin_centers)
    k_rho = np.arange(Nrho)
    k_rho_f = np.arange(len(rho_bin_edges))

    coords = OrderedDict()
    dims = ()
    shape = ()

    # Grow dims and shape in lockstep so the zero array always matches the
    # dimension names, whichever optional axes are present.
    if 'time' in cds:
        coords['time'] = cds['time'].values
        dims += ('time',)
        shape += (len(coords['time']),)

    coords['k_rho'] = k_rho
    dims += ('k_rho',)
    shape += (Nrho,)

    if lat_vals is not None:
        # A bare scalar latitude has no len(); treat it as one-element.
        if np.isscalar(lat_vals):
            lat_vals = [lat_vals]
        coords['lat'] = lat_vals
        dims += ('lat',)
        shape += (len(lat_vals),)

    da = xr.DataArray(data=np.zeros(shape), coords=coords, dims=dims)

    # This could be much cleaner, and should mirror the
    # xgcm notation.
    ds = da.to_dataset(name='trsp')
    ds['rho_c'] = rho_bin_centers
    ds['rho_f'] = rho_bin_edges
    ds['k_rho_f'] = k_rho_f

    return ds
def _initialize_trsp_data_array(cds, lat_vals):
    """Create a zero-valued Dataset with time, depth, and latitude dims

    Parameters
    ----------
    cds : xarray Dataset
        contains LLC coordinates 'k' and (optionally) 'time', plus the
        variable 'Z' that is copied into the result
    lat_vals : int or array of ints
        latitude value(s) rounded to the nearest degree
        specifying where to compute transport

    Returns
    -------
    ds_out : xarray Dataset
        Dataset with the variables
            'trsp_z'
                zero-valued DataArray with time (optional),
                depth, and latitude dimensions
            'Z'
                the original depth coordinate
    """
    # A bare scalar latitude has no len(); treat it as one-element.
    if np.isscalar(lat_vals):
        lat_vals = [lat_vals]

    coords = OrderedDict()
    dims = ()
    shape = ()

    # Grow dims and shape in lockstep so the zero array can never disagree
    # with the dimension names handed to the DataArray.
    if 'time' in cds:
        coords['time'] = cds['time'].values
        dims += ('time',)
        shape += (len(coords['time']),)

    coords['k'] = cds['k'].values
    coords['lat'] = lat_vals
    dims += ('k', 'lat')
    shape += (len(coords['k']), len(lat_vals))

    xda = xr.DataArray(data=np.zeros(shape), coords=coords, dims=dims)

    # Convert to dataset to add Z coordinate
    xds = xda.to_dataset(name='trsp_z')
    xds['Z'] = cds['Z']
    xds = xds.set_coords('Z')

    return xds
class BPCHFile(object):
    """ A file object for representing BPCH data on disk

    Attributes
    ----------
    fp : FortranFile
        A pointer to the open unformatted Fortran binary output (the original
        bpch file)
    var_data, var_attrs : dict
        Containers of `BPCHDataBundle`s and dicts, respectively, holding
        the accessor functions to the raw bpch data and their associated
        metadata

    """

    def __init__(self, filename, mode='rb', endian='>',
                 diaginfo_file='', tracerinfo_file='', eager=False,
                 use_mmap=False, dask_delayed=False):
        """ Load a BPCHFile

        Parameters
        ----------
        filename : str
            Path to the bpch file on disk
        mode : str
            Mode string to pass to the file opener; it must start with "r",
            any other value raises ValueError
        endian : str {">", "<", "="}
            Endian-ness of the Fortran output file
        {tracerinfo, diaginfo}_file : str
            Path to the tracerinfo.dat and diaginfo.dat files containing
            metadata pertaining to the output in the bpch file being read.
            When empty, a file of that name next to `filename` is used if
            it exists.
        eager : bool
            Flag to immediately read variable data; if "False", then nothing
            will be read from the file and you'll need to do so manually
        use_mmap : bool
            Use memory-mapping to read data from file
        dask_delayed : bool
            Use dask to create delayed references to the data-reading functions

        Raises
        ------
        ValueError
            If `mode` does not start with "r".

        """

        self.mode = mode
        if not mode.startswith('r'):
            raise ValueError("Currently only know how to 'r(b)'ead bpch files.")

        self.filename = filename
        self.fsize = os.path.getsize(self.filename)
        self.endian = endian

        # Open a pointer to the file
        self.fp = FortranFile(self.filename, self.mode, self.endian)

        # abspath('') already resolves to the current working directory, so
        # no separate empty-directory fallback is needed.
        dir_path = os.path.abspath(os.path.dirname(filename))
        self.tracerinfo_file = self._find_info_file(
            tracerinfo_file, dir_path, "tracerinfo.dat")
        self.diaginfo_file = self._find_info_file(
            diaginfo_file, dir_path, "diaginfo.dat")

        # Container to record file metadata
        self._attributes = OrderedDict()

        # Don't necessarily need to save diag/tracer_dict yet
        self.diaginfo_df, _ = get_diaginfo(self.diaginfo_file)
        self.tracerinfo_df, _ = get_tracerinfo(self.tracerinfo_file)

        # Container for bundles contained in the output file.
        self.var_data = {}
        self.var_attrs = {}

        # Critical information for accessing file contents
        self._header_pos = None

        # Data loading strategy
        self.use_mmap = use_mmap
        self.dask_delayed = dask_delayed

        # Control eager versus deferring reading; the mode was already
        # validated as a read mode above.
        self.eager = eager
        if self.eager:
            self._read()

    @staticmethod
    def _find_info_file(given, dir_path, default_name):
        """ Return `given` if non-empty; otherwise the file `default_name`
        inside `dir_path` if it exists, else an empty string. """
        if given:
            return given
        candidate = os.path.join(dir_path, default_name)
        return candidate if os.path.exists(candidate) else ''

    def close(self):
        """ Close this bpch file.

        Drops every entry of `var_data` and closes the underlying file
        pointer. Calling it again on an already-closed file is a no-op.

        """

        if not self.fp.closed:
            for v in list(self.var_data):
                del self.var_data[v]

            self.fp.close()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Always release the file; exceptions from the with-body propagate.
        self.close()

    def _read(self):
        """ Parse the entire bpch file on disk and set up easy access to meta-
        and data blocks.

        """

        # The filetype/title lines are consumed first; _read_header then
        # rewinds to its own start position before the data blocks are walked.
        self._read_metadata()
        self._read_header()
        self._read_var_data()

    def _read_metadata(self):
        """ Read the main metadata packaged within a bpch file, indicating
        the output filetype and its title.

        Sets the `filetype` and `filetitle` attributes.

        """

        filetype = self.fp.readline().strip()
        filetitle = self.fp.readline().strip()
        # Decode to UTF string, if possible
        try:
            filetype = str(filetype, 'utf-8')
            filetitle = str(filetitle, 'utf-8')
        except (TypeError, UnicodeDecodeError):
            # TODO: Handle this edge-case of converting file metadata more elegantly.
            # Best effort: keep whatever could not be decoded as it was read.
            pass

        self.filetype = filetype
        self.filetitle = filetitle

    def _read_header(self):
        """ Process the header information (data model / grid spec) """

        self._header_pos = self.fp.tell()

        line = self.fp.readline('20sffii')
        modelname, res0, res1, halfpolar, center180 = line
        self._attributes.update({
            "modelname": str(modelname, 'utf-8').strip(),
            "halfpolar": halfpolar,
            "center180": center180,
            "res": (res0, res1)
        })
        # Note: the instance attribute keeps the raw, undecoded model name.
        self.modelname = modelname
        self.res = (res0, res1)
        self.halfpolar = halfpolar
        self.center180 = center180

        # Re-wind the file
        self.fp.seek(self._header_pos)

    def _read_var_data(self):
        """ Iterate over the block of this bpch file and return handlers
        in the form of `BPCHDataBundle`s for access to the data contained
        therein.

        Populates `var_data` (name -> list of bundles, one per occurrence in
        the file) and `var_attrs` (name -> attributes of the first
        occurrence).

        """

        var_bundles = OrderedDict()
        var_attrs = OrderedDict()

        while self.fp.tell() < self.fsize:

            var_attr = OrderedDict()

            # read first and second header lines; the first one (model name
            # and grid spec) is consumed but not needed per variable
            self.fp.readline('20sffii')

            line = self.fp.readline('40si40sdd40s7i')
            category_name, number, unit, tau0, tau1, _reserved = line[:6]
            dim0, dim1, dim2, dim3, dim4, dim5, _skip = line[6:]
            var_attr['number'] = number

            # Decode byte-strings to utf-8
            category_name = str(category_name, 'utf-8')
            var_attr['category'] = category_name.strip()
            unit = str(unit, 'utf-8')

            # get additional metadata from tracerinfo / diaginfo
            try:
                cat_df = self.diaginfo_df[
                    self.diaginfo_df.name == category_name.strip()
                ]
                # TODO: Safer logic for handling case where more than one
                #       category metadata match was made
                cat = cat_df.T.squeeze()

                tracer_num = int(cat.offset) + int(number)
                diag_df = self.tracerinfo_df[
                    self.tracerinfo_df.tracer == tracer_num
                ]
                # TODO: Safer logic for handling case where more than one
                #       tracer metadata match was made
                diag = diag_df.T.squeeze()
                diag_attr = diag.to_dict()

                if not unit.strip():  # unit may be empty in bpch
                    unit = diag_attr['unit']  # but not in tracerinfo
                var_attr.update(diag_attr)
            except Exception:
                # Deliberate best effort: with missing or non-matching
                # tracerinfo/diaginfo tables, fall back to an unnamed,
                # unscaled variable instead of failing the whole read.
                diag = {'name': '', 'scale': 1}
                var_attr.update(diag)
            var_attr['unit'] = unit

            vname = diag['name']
            fullname = category_name.strip() + "_" + vname

            # parse metadata, get data or set a data proxy
            if dim2 == 1:
                data_shape = (dim0, dim1)         # 2D field
            else:
                data_shape = (dim0, dim1, dim2)
            var_attr['original_shape'] = data_shape

            # Add proxy time dimension to shape
            data_shape = (1, ) + data_shape
            var_attr['origin'] = (dim3, dim4, dim5)

            timelo, timehi = cf.tau2time(tau0), cf.tau2time(tau1)

            pos = self.fp.tell()
            # Note that we don't pass a dtype, and assume everything is
            # single-fp floats with the correct endian, as hard-coded
            # NOTE(review): `metadata=var_attr` is restored from a gap in the
            # reviewed text; confirm against BPCHDataBundle's signature.
            var_bundle = BPCHDataBundle(
                data_shape, self.endian, self.filename, pos, [timelo, timehi],
                metadata=var_attr,
                use_mmap=self.use_mmap, dask_delayed=self.dask_delayed
            )
            # NOTE(review): the data record must be stepped over before the
            # next header can be read; `skipline` is the presumed FortranFile
            # call for that -- confirm.
            self.fp.skipline()

            # Save the data as a "bundle" for concatenating in the final step
            if fullname in var_bundles:
                var_bundles[fullname].append(var_bundle)
            else:
                var_bundles[fullname] = [var_bundle, ]
                var_attrs[fullname] = var_attr

        self.var_data = var_bundles
        self.var_attrs = var_attrs
class BPCHDataStore(AbstractDataStore):
    """ Store for reading data from binary punch files.

    Note that this is intended as a backend only; to open and read a given
    bpch file, use :meth:`open_bpchdataset`.

    Examples of other extensions using the core DataStore API can be found at:

    - https://github.com/pydata/xarray/blob/master/xarray/conventions.py
    - https://github.com/xgcm/xmitgcm/blob/master/xmitgcm/mds_store.py

    def __init__(self,
        # NOTE(review): the remainder of this signature is missing from the
        # reviewed text. The body below reads `filename`, `mode`, `endian`,
        # `tracerinfo_file`, `diaginfo_file`, `use_mmap`, `dask_delayed`,
        # `fields` and `categories`, so those were presumably parameters --
        # restore the signature (and defaults) from version control.

        # Track the metadata accompanying this dataset.
        dir_path = os.path.abspath(os.path.dirname(filename))
        if not dir_path:
            dir_path = os.getcwd()
        if not tracerinfo_file:
            tracerinfo_file = os.path.join(dir_path, 'tracerinfo.dat')
            if not os.path.exists(tracerinfo_file):
                tracerinfo_file = ''
        self.tracerinfo_file = tracerinfo_file
        if not diaginfo_file:
            diaginfo_file = os.path.join(dir_path, 'diaginfo.dat')
            if not os.path.exists(diaginfo_file):
                diaginfo_file = ''
        self.diaginfo_file = diaginfo_file

        self.filename = filename
        self.fsize = os.path.getsize(self.filename)
        self.mode = mode
        if not mode.startswith('r'):
            raise ValueError(
                "Currently only know how to 'r(b)'ead bpch files.")

        # Check endianness flag
        if endian not in ['>', '<', '=']:
            raise ValueError("Invalid byte order (endian={})".format(endian))
        self.endian = endian

        # Open the raw output file, but don't yet read all the data
        self._mmap = use_mmap
        self._dask = dask_delayed
        self._bpch = BPCHFile(self.filename,
        # NOTE(review): the remaining arguments and the closing parenthesis
        # of the BPCHFile(...) call above are missing.
        self.fields = fields
        self.categories = categories

        # Peek into the raw output file and read the header and metadata
        # so that we can get a head start at building the output grid
        # NOTE(review): no statement follows the comment above, nor the
        # "Parse the binary file" comment below; the calls they describe
        # appear to be missing.

        # Parse the binary file and prepare to add variables to the DataStore

        # Create storage dicts for variables and attributes, to be used later
        # when xarray needs to access the data
        self._variables = OrderedDict()
        self._attributes = OrderedDict()
        self._dimensions = [d for d in BASE_DIMENSIONS]

        # Begin constructing the coordinate dimensions shared by the
        # output dataset variables
        dim_coords = {}
        # NOTE(review): `self._attributes` was created empty a few lines up,
        # so the 'modelname' lookup below needs it populated first; that step
        # is not visible here. The from_model(...) call is also cut off.
        self.ctm_info = CTMGrid.from_model(self._attributes['modelname'],

        # Add vertical dimensions
        # NOTE(review): the three lines below are only the tails of three
        # truncated statements; their beginnings are missing.
        ], attrs={'axis': 'Z'}))
        ], attrs={'axis': 'Z'}))
        ], attrs={'axis': 'Z'}))
        eta_centers = self.ctm_info.eta_centers
        sigma_centers = self.ctm_info.sigma_centers

        # Add time dimensions
        # NOTE(review): the attribute entries below (time, then lon/lat) are
        # orphaned; the statements that contained them are missing.
                     'axis': 'T',
                     'long_name': 'time',
                     'standard_name': 'time'

        # Add lat/lon dimensions
                     'axis': 'X',
                     'long_name': 'longitude coordinate',
                     'standard_name': 'longitude'
        # NOTE(review): 'y' is lowercase while the other axis attributes in
        # this method are uppercase -- confirm whether 'Y' was intended.
                     'axis': 'y',
                     'long_name': 'latitude coordinate',
                     'standard_name': 'latitude'

        # Pick the vertical coordinate: eta centers when the grid provides
        # them, otherwise sigma centers.
        # NOTE(review): the `else:` line and the closing braces of both
        # `lev_attrs` dicts are missing. Both branches carry the same
        # hybrid-sigma-pressure standard_name -- confirm that is intended for
        # the sigma case.
        if eta_centers is not None:
            lev_vals = eta_centers
            lev_attrs = {
                'standard_name': 'atmosphere_hybrid_sigma_pressure_coordinate',
                'axis': 'Z'
            lev_vals = sigma_centers
            lev_attrs = {
                'standard_name': 'atmosphere_hybrid_sigma_pressure_coordinate',
                'axis': 'Z'
        # NOTE(review): the dimension-name list passed to xr.Variable below
        # is empty in the reviewed text; its contents appear to be missing.
        self._variables['lev'] = xr.Variable([
        ], lev_vals, lev_attrs)

        ## Latitude / Longitude
        # TODO: Add lon/lat bounds

        # Detect if we're on a nested grid; in that case, we'll have a displaced
        # origin set in the variable attributes we previously read
        ref_key = list(self._bpch.var_attrs.keys())[0]
        ref_attrs = self._bpch.var_attrs[ref_key]
        self.is_nested = (ref_attrs['origin'] != (1, 1, 1))

        lon_centers = self.ctm_info.lon_centers
        lat_centers = self.ctm_info.lat_centers

        if self.is_nested:
            ix, iy, _ = ref_attrs['origin']
            nx, ny, *_ = ref_attrs['original_shape']
            # Correct i{x,y} for IDL->Python indexing (1-indexed -> 0-indexed)
            ix -= 1
            iy -= 1
            lon_centers = lon_centers[ix:ix + nx]
            lat_centers = lat_centers[iy:iy + ny]

        # NOTE(review): the closing `})` of both xr.Variable calls below is
        # missing.
        self._variables['lon'] = xr.Variable(['lon'], lon_centers, {
            'long_name': 'longitude',
            'units': 'degrees_east'
        self._variables['lat'] = xr.Variable(['lat'], lat_centers, {
            'long_name': 'latitude',
            'units': 'degrees_north'
        # TODO: Fix longitudes if ctm_grid.center180

        # Add variables from the parsed BPCH file to our DataStore
        for vname in list(self._bpch.var_data.keys()):

            var_data = self._bpch.var_data[vname]
            var_attr = self._bpch.var_attrs[vname]

            # Optional filters on variable name and category.
            # NOTE(review): both `if` statements below have lost their bodies
            # (presumably a `continue` to skip the variable) -- confirm.
            if fields and (var_attr['name'] not in fields):
            if categories and (var_attr['category'] not in categories):

            # Process dimensions
            # NOTE(review): the contents of the `dims` list are missing.
            dims = [
            dshape = var_attr['original_shape']
            if len(dshape) == 3:
                # Process the vertical coordinate. A few things can happen here:
                # 1) We have cell-centered values on the "Nlayer" grid; we can take these variables and map them to 'lev'
                # 2) We have edge value on an "Nlayer" + 1 grid; we can take these and use them with 'lev_edge'
                # 3) We have troposphere values on "Ntrop"; we can take these and use them with 'lev_trop', but we won't have coordinate information yet
                # All other cases we do not handle yet; this includes the aircraft emissions and a few other things. Note that tracer sources do not have a vertical coord to worry about!
                nlev = dshape[-1]
                grid_nlev = self.ctm_info.Nlayers
                grid_ntrop = self.ctm_info.Ntrop
                # NOTE(review): the `try:` line and the body of each branch
                # below are missing; only the conditions and the
                # AttributeError handler survive.
                    if nlev == grid_nlev:
                    elif nlev == grid_nlev + 1:
                    elif nlev == grid_ntrop:
                except AttributeError:
                    warnings.warn("Couldn't resolve grid_spec vertical layout")

            # xarray Variables are thin wrappers for numpy.ndarrays, or really
            # any object that extends the ndarray interface. A critical part of
            # the original ndarray interface is that the underlying data has to
            # be contiguous in memory. We can enforce this to happen by
            # concatenating each bundle in the variable data bundles we read
            # from the bpch file
            data = self._concat([v.data for v in var_data])

            # Is the variable time-invariant? If it is, kill the time dim.
            # Here, we mean it only as one sample in the dataset.
            if data.shape[0] == 1:
                dims = dims[1:]
                data = data.squeeze()

            # Create a variable containing this data
            var = xr.Variable(dims, data, var_attr)

            # Shuffle dims for CF/COARDS compliance if requested
            # TODO: For this to work, we have to force a load of the data.
            #       Is there a way to re-write BPCHDataProxy so that that's not
            #       necessary?
            #       Actually, we can't even force a load becase var.data is a
            #       numpy.ndarray. Weird.
            # if fix_dims:
            #     target_dims = [d for d in DIM_ORDER_PRIORITY if d in dims]
            #     var = var.transpose(*target_dims)

            self._variables[vname] = var

            # Try to add a time dimension
            # TODO: Time units?
            # The time axis is taken from the first variable that has more
            # than one bundle; each bundle's `time` supplies one bounds entry
            # and the first column of the bounds becomes the time coordinate.
            if (len(var_data) > 1) and 'time' not in self._variables:
                time_bnds = np.asarray([v.time for v in var_data])
                times = time_bnds[:, 0]

                # NOTE(review): the dimension lists and the closing brackets
                # of the 'time' and 'nv' variables below are missing.
                self._variables['time'] = xr.Variable(
                    ], times, {
                        'bounds': 'time_bnds',
                        'units': cf.CTM_TIME_UNIT_STR
                self._variables['time_bnds'] = xr.Variable(
                    ['time', 'nv'], time_bnds, {'units': cf.CTM_TIME_UNIT_STR})
                self._variables['nv'] = xr.Variable([
                ], [0, 1])

        # Create the dimension variables; we have a lot of options
        # here with regards to the vertical coordinate. For now,
        # we'll just use the sigma or eta coordinates.
        # Useful CF info: http://cfconventions.org/cf-conventions/v1.6.0/cf-conventions.html#_atmosphere_hybrid_sigma_pressure_coordinate
        # self._variables['Ap'] =
        # self._variables['Bp'] =
        # self._variables['altitude'] =

        # Time dimensions
        # self._times = self.ds.times
        # self._time_bnds = self.ds.time_bnds

    def _concat(self, *args, **kwargs):
        """Concatenate arrays with the backend matching the loading strategy.

        All arguments are forwarded unchanged. When ``self._dask`` is truthy
        the call goes to ``da.concatenate`` (``da`` is presumably
        ``dask.array``, imported at file level -- confirm); otherwise it goes
        to ``np.concatenate``.
        """
        if self._dask:
            return da.concatenate(*args, **kwargs)
        return np.concatenate(*args, **kwargs)

    def get_variables(self):
        """Return the store's mapping of variable names to variables.

        The internal container itself is returned, not a copy.
        """
        variables = self._variables
        return variables

    def get_attrs(self):
        """Return the store's attributes wrapped in ``Frozen``."""
        attributes = self._attributes
        return Frozen(attributes)

    def get_dimensions(self):
        """Return the store's dimension names wrapped in ``Frozen``."""
        dimensions = self._dimensions
        return Frozen(dimensions)

    def close(self):
        """Drop every variable held by this store.

        The container object is kept and emptied entry by entry.
        """
        # Snapshot the keys first so the mapping is not mutated mid-iteration.
        names = tuple(self._variables)
        for name in names:
            del self._variables[name]

    def __exit__(self, type, value, traceback):