Example #1
    def __init__(self, model_path, expt, ensemble_member, max_lon=180):
        # Get path to one file on the tracer grid
        cmip_file = find_cmip6_files(model_path, expt, ensemble_member,
                                     'thetao', 'Omon')[0][0]
        self.lon_2d = fix_lon_range(read_netcdf(cmip_file, 'longitude'),
                                    max_lon=max_lon)
        self.lat_2d = read_netcdf(cmip_file, 'latitude')
        self.z = -1 * read_netcdf(cmip_file, 'lev')
        self.mask = read_netcdf(cmip_file, 'thetao', time_index=0).mask
        # And one on the u-grid
        cmip_file_u = find_cmip6_files(model_path, expt, ensemble_member, 'uo',
                                       'Omon')[0][0]
        self.lon_u_2d = fix_lon_range(read_netcdf(cmip_file_u, 'longitude'),
                                      max_lon=max_lon)
        self.lat_u_2d = read_netcdf(cmip_file_u, 'latitude')
        self.mask_u = read_netcdf(cmip_file_u, 'uo', time_index=0).mask
        # And one on the v-grid
        cmip_file_v = find_cmip6_files(model_path, expt, ensemble_member, 'vo',
                                       'Omon')[0][0]
        self.lon_v_2d = fix_lon_range(read_netcdf(cmip_file_v, 'longitude'),
                                      max_lon=max_lon)
        self.lat_v_2d = read_netcdf(cmip_file_v, 'latitude')
        self.mask_v = read_netcdf(cmip_file_v, 'vo', time_index=0).mask
        # Save grid dimensions too
        self.nx = self.lon_2d.shape[1]
        self.ny = self.lat_2d.shape[0]
        self.nz = self.z.size
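
The examples on this page all lean on a fix_lon_range helper that wraps longitudes into either (-180, 180] or (0, 360]. The helper itself is not shown here; the following is only a minimal sketch of the assumed behaviour, not the toolbox's actual implementation, which may also handle masked arrays and other edge cases:

import numpy as np

def fix_lon_range(lon, max_lon=180):
    # Sketch only: wrap longitudes into the range (max_lon - 360, max_lon].
    lon = np.asarray(lon, dtype=float).copy()
    lon[lon > max_lon] -= 360
    lon[lon <= max_lon - 360] += 360
    return lon
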
Example #2
    def __init__(self, path, x_is_lon=True, max_lon=None):

        if path.endswith('.nc'):
            use_netcdf = True
        elif os.path.isdir(path):
            use_netcdf = False
            path = real_dir(path)
            from MITgcmutils import rdmds
        else:
            print 'Error (Grid): ' + path + ' is neither a NetCDF file nor a directory'
            sys.exit()

        # Read variables
        # Note that some variables are capitalised differently in NetCDF versus binary, so can't make this more efficient...
        if use_netcdf:
            self.lon_2d = read_netcdf(path, 'XC')
            self.lat_2d = read_netcdf(path, 'YC')
            self.lon_corners_2d = read_netcdf(path, 'XG')
            self.lat_corners_2d = read_netcdf(path, 'YG')
            self.dx_s = read_netcdf(path, 'dxG')
            self.dy_w = read_netcdf(path, 'dyG')
            # I have no idea why this requires .data but it does, otherwise WSS breaks (?!?!)
            self.dA = read_netcdf(path, 'rA').data
            self.z = read_netcdf(path, 'Z')
            self.z_edges = read_netcdf(path, 'Zp1')
            self.dz = read_netcdf(path, 'drF')
            self.dz_t = read_netcdf(path, 'drC')
            self.hfac = read_netcdf(path, 'hFacC')
            self.hfac_w = read_netcdf(path, 'hFacW')
            self.hfac_s = read_netcdf(path, 'hFacS')
        else:
            self.lon_2d = rdmds(path + 'XC')
            self.lat_2d = rdmds(path + 'YC')
            self.lon_corners_2d = rdmds(path + 'XG')
            self.lat_corners_2d = rdmds(path + 'YG')
            self.dx_s = rdmds(path + 'DXG')
            self.dy_w = rdmds(path + 'DYG')
            self.dA = rdmds(path + 'RAC')
            # Remove singleton dimensions from 1D depth variables
            self.z = rdmds(path + 'RC').squeeze()
            self.z_edges = rdmds(path + 'RF').squeeze()
            self.dz = rdmds(path + 'DRF').squeeze()
            self.dz_t = rdmds(path + 'DRC').squeeze()
            self.hfac = rdmds(path + 'hFacC')
            self.hfac_w = rdmds(path + 'hFacW')
            self.hfac_s = rdmds(path + 'hFacS')

        # Make 1D versions of latitude and longitude arrays (only useful for regular lat-lon grids)
        if len(self.lon_2d.shape) == 2:
            self.lon_1d = self.lon_2d[0, :]
            self.lat_1d = self.lat_2d[:, 0]
            self.lon_corners_1d = self.lon_corners_2d[0, :]
            self.lat_corners_1d = self.lat_corners_2d[:, 0]
        elif len(self.lon_2d.shape) == 1:
            # xmitgcm output has these variables as 1D already. So make 2D ones.
            self.lon_1d = np.copy(self.lon_2d)
            self.lat_1d = np.copy(self.lat_2d)
            self.lon_corners_1d = np.copy(self.lon_corners_2d)
            self.lat_corners_1d = np.copy(self.lat_corners_2d)
            self.lon_2d, self.lat_2d = np.meshgrid(self.lon_1d, self.lat_1d)
            self.lon_corners_2d, self.lat_corners_2d = np.meshgrid(
                self.lon_corners_1d, self.lat_corners_1d)

        # Decide on longitude range
        if max_lon is None and x_is_lon:
            # Choose range automatically
            if np.amin(self.lon_1d) < 180 and np.amax(self.lon_1d) > 180:
                # Domain crosses 180E, so use the range (0, 360)
                max_lon = 360
            else:
                # Use the range (-180, 180)
                max_lon = 180
            # Do one array to test
            self.lon_1d = fix_lon_range(self.lon_1d, max_lon=max_lon)
            # Make sure it's strictly increasing now
            if not np.all(np.diff(self.lon_1d) > 0):
                print 'Error (Grid): Longitude is not strictly increasing either in the range (0, 360) or (-180, 180).'
                sys.exit()
        if max_lon == 360:
            self.split = 0
        elif max_lon == 180:
            self.split = 180
        self.lon_1d = fix_lon_range(self.lon_1d, max_lon=max_lon)
        self.lon_corners_1d = fix_lon_range(self.lon_corners_1d,
                                            max_lon=max_lon)
        self.lon_2d = fix_lon_range(self.lon_2d, max_lon=max_lon)
        self.lon_corners_2d = fix_lon_range(self.lon_corners_2d,
                                            max_lon=max_lon)

        # Save dimensions
        self.nx = self.lon_1d.size
        self.ny = self.lat_1d.size
        self.nz = self.z.size

        # Calculate volume
        self.dV = xy_to_xyz(self.dA, [self.nx, self.ny, self.nz]) * z_to_xyz(
            self.dz, [self.nx, self.ny, self.nz]) * self.hfac

        # Calculate bathymetry and ice shelf draft
        self.bathy = bdry_from_hfac('bathy', self.hfac, self.z_edges)
        self.draft = bdry_from_hfac('draft', self.hfac, self.z_edges)

        # Create masks on the t, u, and v grids
        # Land masks
        self.land_mask = self.build_land_mask(self.hfac)
        self.land_mask_u = self.build_land_mask(self.hfac_w)
        self.land_mask_v = self.build_land_mask(self.hfac_s)
        # Ice shelf masks
        self.ice_mask = self.build_ice_mask(self.hfac)
        self.ice_mask_u = self.build_ice_mask(self.hfac_w)
        self.ice_mask_v = self.build_ice_mask(self.hfac_s)
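
A hypothetical usage sketch for this constructor, assuming the class is named Grid (as its error messages suggest, and as Example #5 below confirms with Grid(mit_grid_dir)); the paths are placeholders:

# Read the grid either from a single NetCDF grid file or from a
# directory of MDS binaries written by MITgcm (placeholder paths).
grid_from_nc = Grid('grid.nc')
grid_from_mds = Grid('run/')
# Dimensions, masks, and topography are then available as attributes,
# e.g. grid_from_nc.nx, grid_from_nc.land_mask, grid_from_nc.bathy.
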
Example #3
    def __init__(self, file_path):

        # 1D lon and lat axes on regular grids
        # Make sure longitude is between -180 and 180
        # Cell centres
        self.lon_1d = fix_lon_range(read_netcdf(file_path, 'X'))
        self.lat_1d = read_netcdf(file_path, 'Y')
        # Cell corners (southwest)
        self.lon_corners_1d = fix_lon_range(read_netcdf(file_path, 'Xp1'))
        self.lat_corners_1d = read_netcdf(file_path, 'Yp1')

        # 2D lon and lat fields on any grid
        # Cell centres
        self.lon_2d = fix_lon_range(read_netcdf(file_path, 'XC'))
        self.lat_2d = read_netcdf(file_path, 'YC')
        # Cell corners
        self.lon_corners_2d = fix_lon_range(read_netcdf(file_path, 'XG'))
        self.lat_corners_2d = read_netcdf(file_path, 'YG')

        # 2D integrands of distance
        # Across faces
        self.dx = read_netcdf(file_path, 'dxF')
        self.dy = read_netcdf(file_path, 'dyF')
        # Between centres
        self.dx_t = read_netcdf(file_path, 'dxC')
        self.dy_t = read_netcdf(file_path, 'dyC')
        # Between u-points
        self.dx_u = self.dx  # Equivalent to distance across face
        self.dy_u = read_netcdf(file_path, 'dyU')
        # Between v-points
        self.dx_v = read_netcdf(file_path, 'dxV')
        self.dy_v = self.dy  # Equivalent to distance across face
        # Between corners
        self.dx_psi = read_netcdf(file_path, 'dxG')
        self.dy_psi = read_netcdf(file_path, 'dyG')

        # 2D integrands of area
        # Area of faces
        self.dA = read_netcdf(file_path, 'rA')
        # Centered on u-points
        self.dA_u = read_netcdf(file_path, 'rAw')
        # Centered on v-points
        self.dA_v = read_netcdf(file_path, 'rAs')
        # Centered on corners
        self.dA_psi = read_netcdf(file_path, 'rAz')

        # Vertical grid
        # Assumes we're in the ocean so using z-levels - not sure how this
        # would handle atmospheric pressure levels.
        # Depth axis at centres of z-levels
        self.z = read_netcdf(file_path, 'Z')
        # Depth axis at edges of z-levels
        self.z_edges = read_netcdf(file_path, 'Zp1')
        # Depth axis at w-points
        self.z_w = read_netcdf(file_path, 'Zl')

        # Vertical integrands of distance
        # Across cells
        self.dz = read_netcdf(file_path, 'drF')
        # Between centres
        self.dz_t = read_netcdf(file_path, 'drC')

        # Dimension lengths (on tracer grid)
        self.nx = self.lon_1d.size
        self.ny = self.lat_1d.size
        self.nz = self.z.size

        # Partial cell fractions
        # At centres
        self.hfac = read_netcdf(file_path, 'HFacC')
        # On western edges
        self.hfac_w = read_netcdf(file_path, 'HFacW')
        # On southern edges
        self.hfac_s = read_netcdf(file_path, 'HFacS')

        # Create masks on the t, u, and v grids
        # We can't do the psi grid because there is no hfac there
        # Land masks
        self.land_mask = build_land_mask(self.hfac)
        self.land_mask_u = build_land_mask(self.hfac_w)
        self.land_mask_v = build_land_mask(self.hfac_s)
        # Ice shelf masks
        self.zice_mask = build_zice_mask(self.hfac)
        self.zice_mask_u = build_zice_mask(self.hfac_w)
        self.zice_mask_v = build_zice_mask(self.hfac_s)
        # FRIS masks
        self.fris_mask = build_fris_mask(self.zice_mask, self.lon_2d,
                                         self.lat_2d)
        self.fris_mask_u = build_fris_mask(self.zice_mask_u,
                                           self.lon_corners_2d, self.lat_2d)
        self.fris_mask_v = build_fris_mask(self.zice_mask_v, self.lon_2d,
                                           self.lat_corners_2d)

        # Topography (as seen by the model after adjustment for eg hfacMin - not necessarily equal to what is specified by the user)
        # Bathymetry (bottom depth)
        self.bathy = read_netcdf(file_path, 'R_low')
        # Ice shelf draft (surface depth, enforce 0 in land or open-ocean points)
        self.zice = read_netcdf(file_path, 'Ro_surf')
        self.zice[np.invert(self.zice_mask)] = 0
        # Water column thickness
        self.wct = read_netcdf(file_path, 'Depth')
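
The masks above come from helpers such as build_land_mask that are not shown on this page. A minimal sketch of the assumed logic, where a horizontal point counts as land when its partial cell fractions are zero over the whole water column (the real helper may differ in detail):

import numpy as np

def build_land_mask(hfac):
    # Sketch only: land wherever hFac is zero at every vertical level,
    # assuming hfac has shape (nz, ny, nx).
    return np.sum(hfac, axis=0) == 0
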
Example #4
    def __init__(self, path, model_grid=None, split=0):

        self.split = split

        if path.endswith('.nc'):
            use_netcdf = True
        elif os.path.isdir(path):
            use_netcdf = False
            path = real_dir(path)
            from MITgcmutils import rdmds
        else:
            print 'Error (SOSEGrid): ' + path + ' is neither a NetCDF file nor a directory'
            sys.exit()

        self.trim_extend = True
        if model_grid is None:
            self.trim_extend = False

        if self.trim_extend:
            # Error checking for which longitude range we're in
            if split == 180:
                max_lon = 180
                if np.amax(model_grid.lon_2d) > max_lon:
                    print 'Error (SOSEGrid): split=180 does not match model grid'
                    sys.exit()
            elif split == 0:
                max_lon = 360
                if np.amin(model_grid.lon_2d) < 0:
                    print 'Error (SOSEGrid): split=0 does not match model grid'
                    sys.exit()
            else:
                print 'Error (SOSEGrid): split must be 180 or 0'
                sys.exit()
        else:
            max_lon = 360

        # Read variables
        if use_netcdf:
            # Make the 2D grid 1D so it's regular
            self.lon_1d = read_netcdf(path, 'XC')[0, :]
            self.lon_corners_1d = read_netcdf(path, 'XG')[0, :]
            self.lat_1d = read_netcdf(path, 'YC')[:, 0]
            self.lat_corners_1d = read_netcdf(path, 'YG')[:, 0]
            self.z = read_netcdf(path, 'Z')
            self.z_edges = read_netcdf(path, 'RF')
        else:
            self.lon_1d = rdmds(path + 'XC')[0, :]
            self.lon_corners_1d = rdmds(path + 'XG')[0, :]
            self.lat_1d = rdmds(path + 'YC')[:, 0]
            self.lat_corners_1d = rdmds(path + 'YG')[:, 0]
            self.z = rdmds(path + 'RC').squeeze()
            self.z_edges = rdmds(path + 'RF').squeeze()

        # Fix longitude range
        self.lon_1d = fix_lon_range(self.lon_1d, max_lon=max_lon)
        self.lon_corners_1d = fix_lon_range(self.lon_corners_1d,
                                            max_lon=max_lon)
        if split == 180:
            # Split the domain at 180E=180W and rearrange the two halves so longitude is strictly ascending
            self.i_split = np.nonzero(self.lon_1d < 0)[0][0]
        else:
            # Set i_split to 0 which won't actually do anything
            self.i_split = 0
        self.lon_1d = split_longitude(self.lon_1d, self.i_split)
        self.lon_corners_1d = split_longitude(self.lon_corners_1d,
                                              self.i_split)
        if self.lon_corners_1d[0] > 0:
            # The split happened between lon_corners[i_split] and lon[i_split].
            # Take mod 360 on this index of lon_corners to make sure it's strictly increasing.
            self.lon_corners_1d[0] -= 360
        # Make sure the longitude axes are strictly increasing after the splitting
        if not np.all(np.diff(self.lon_1d) > 0) or not np.all(
                np.diff(self.lon_corners_1d) > 0):
            print 'Error (SOSEGrid): longitude is not strictly increasing'
            sys.exit()

        # Save original dimensions
        sose_nx = self.lon_1d.size
        sose_ny = self.lat_1d.size
        sose_nz = self.z.size

        if self.trim_extend:

            # Trim and/or extend the axes
            # Notes about this:
            # Longitude can only be trimmed as SOSE considers all longitudes (someone doing a high-resolution circumpolar model with points in the gap might need to write a patch to wrap the SOSE grid around)
            # Latitude can be trimmed in both directions, or extended to the south (not extended to the north - if you need to do this, SOSE is not the right product for you!)
            # Depth can be extended by one level in both directions, and the deeper bound can also be trimmed
            # The indices i, j, and k will be kept track of with 4 variables each. For example, with longitude:
            # i0_before = first index we care about
            #           = how many cells to trim at beginning
            # i0_after = i0_before's position in the new grid
            #          = how many cells to extend at beginning
            # i1_before = first index we don't care about
            #           sose_nx - i1_before = how many cells to trim at end
            # i1_after = i1_before's position in the new grid
            #          = i1_before - i0_before + i0_after
            # nx = length of new grid
            #      nx - i1_after = how many cells to extend at end
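            # Worked example of this bookkeeping (illustrative numbers only):
            # trimming 10 cells at the western edge with no westward extension
            # gives i0_before = 10 and i0_after = 0; if i1_before = 2000, then
            # i1_after = 2000 - 10 + 0 = 1990, and with no eastward extension
            # nx = i1_after = 1990.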

            # Find bounds on model grid
            xmin = np.amin(model_grid.lon_corners_2d)
            xmax = np.amax(model_grid.lon_2d)
            ymin = np.amin(model_grid.lat_corners_2d)
            ymax = np.amax(model_grid.lat_2d)
            z_shallow = model_grid.z[0]
            z_deep = model_grid.z[-1]

            # Western bound (use longitude at cell centres to make sure all grid types clear the bound)
            if xmin == self.lon_1d[0]:
                # Nothing to do
                self.i0_before = 0
            elif xmin > self.lon_1d[0]:
                # Trim
                self.i0_before = np.nonzero(self.lon_1d > xmin)[0][0] - 1
            else:
                print 'Error (SOSEGrid): not allowed to extend westward'
                sys.exit()
            self.i0_after = 0

            # Eastern bound (use longitude at cell corners, i.e. western edge)
            if xmax == self.lon_corners_1d[-1]:
                # Nothing to do
                self.i1_before = sose_nx
            elif xmax < self.lon_corners_1d[-1]:
                # Trim
                self.i1_before = np.nonzero(
                    self.lon_corners_1d > xmax)[0][0] + 1
            else:
                print 'Error (SOSEGrid): not allowed to extend eastward'
                sys.exit()
            self.i1_after = self.i1_before - self.i0_before + self.i0_after
            self.nx = self.i1_after

            # Southern bound (use latitude at cell centres)
            if ymin == self.lat_1d[0]:
                # Nothing to do
                self.j0_before = 0
                self.j0_after = 0
            elif ymin > self.lat_1d[0]:
                # Trim
                self.j0_before = np.nonzero(self.lat_1d > ymin)[0][0] - 1
                self.j0_after = 0
            elif ymin < self.lat_1d[0]:
                # Extend
                self.j0_after = int(np.ceil(
                    (self.lat_1d[0] - ymin) / sose_res))
                self.j0_before = 0

            # Northern bound (use latitude at cell corners, i.e. southern edge)
            if ymax == self.lat_corners_1d[-1]:
                # Nothing to do
                self.j1_before = sose_ny
            elif ymax < self.lat_corners_1d[-1]:
                # Trim
                self.j1_before = np.nonzero(
                    self.lat_corners_1d > ymax)[0][0] + 1
            else:
                print 'Error (SOSEGrid): not allowed to extend northward'
                sys.exit()
            self.j1_after = self.j1_before - self.j0_before + self.j0_after
            self.ny = self.j1_after

            # Depth
            self.k0_before = 0
            if z_shallow <= self.z[0]:
                # Nothing to do
                self.k0_after = 0
            else:
                # Extend
                self.k0_after = 1
            if z_deep > self.z[-1]:
                # Trim
                self.k1_before = np.nonzero(self.z < z_deep)[0][0] + 1
            else:
                # Either extend or do nothing
                self.k1_before = sose_nz
            self.k1_after = self.k1_before + self.k0_after
            if z_deep < self.z[-1]:
                # Extend
                self.nz = self.k1_after + 1
            else:
                self.nz = self.k1_after

            # Now we have the indices we need, so trim/extend the axes as needed
            # Longitude: can only trim
            self.lon_1d = self.lon_1d[self.i0_before:self.i1_before]
            self.lon_corners_1d = self.lon_corners_1d[
                self.i0_before:self.i1_before]
            # Latitude: can extend on south side, trim on both sides
            lat_extend = np.flipud(-1 *
                                   (np.arange(self.j0_after) + 1) * sose_res +
                                   self.lat_1d[self.j0_before])
            lat_trim = self.lat_1d[self.j0_before:self.j1_before]
            self.lat_1d = np.concatenate((lat_extend, lat_trim))
            lat_corners_extend = np.flipud(-1 *
                                           (np.arange(self.j0_after) + 1) *
                                           sose_res +
                                           self.lat_corners_1d[self.j0_before])
            lat_corners_trim = self.lat_corners_1d[
                self.j0_before:self.j1_before]
            self.lat_corners_1d = np.concatenate(
                (lat_corners_extend, lat_corners_trim))
            # Depth: can extend on both sides (depth 0 at top and extrapolated at bottom to clear the deepest model depth), trim on deep side
            z_above = np.zeros([self.k0_after])  # Will either be [0] or empty
            z_middle = self.z[self.k0_before:self.k1_before]
            z_edges_middle = self.z_edges[self.k0_before:self.k1_before + 1]
            # Will either be [something deeper than z_deep] or empty
            z_below = (2 * model_grid.z[-1] - model_grid.z[-2]) * np.ones(
                [self.nz - self.k1_after])
            self.z = np.concatenate((z_above, z_middle, z_below))
            self.z_edges = np.concatenate((z_above, z_edges_middle, z_below))

            # Make sure we cleared those bounds
            if self.lon_corners_1d[0] > xmin:
                print 'Error (SOSEGrid): western bound not cleared'
                sys.exit()
            if self.lon_corners_1d[-1] < xmax:
                print 'Error (SOSEGrid): eastern bound not cleared'
                sys.exit()
            if self.lat_corners_1d[0] > ymin:
                print 'Error (SOSEGrid): southern bound not cleared'
                sys.exit()
            if self.lat_corners_1d[-1] < ymax:
                print 'Error (SOSEGrid): northern bound not cleared'
                sys.exit()
            if self.z[0] < z_shallow:
                print 'Error (SOSEGrid): shallow bound not cleared'
                sys.exit()
            if self.z[-1] > z_deep:
                print 'Error (SOSEGrid): deep bound not cleared'
                sys.exit()

        else:

            # Nothing fancy to do
            self.nx = sose_nx
            self.ny = sose_ny
            self.nz = sose_nz

        # Now read the rest of the variables we need, splitting/trimming/extending them if needed
        if use_netcdf:
            self.hfac = self.read_field(path,
                                        'xyz',
                                        var_name='hFacC',
                                        fill_value=0)
            self.hfac_w = self.read_field(path,
                                          'xyz',
                                          var_name='hFacW',
                                          fill_value=0)
            self.hfac_s = self.read_field(path,
                                          'xyz',
                                          var_name='hFacS',
                                          fill_value=0)
            self.dA = self.read_field(path, 'xy', var_name='rA', fill_value=0)
            self.dz = self.read_field(path, 'z', var_name='DRF', fill_value=0)
        else:
            self.hfac = self.read_field(path + 'hFacC', 'xyz', fill_value=0)
            self.hfac_w = self.read_field(path + 'hFacW', 'xyz', fill_value=0)
            self.hfac_s = self.read_field(path + 'hFacS', 'xyz', fill_value=0)
            self.dA = self.read_field(path + 'RAC', 'xy', fill_value=0)
            self.dz = self.read_field(path + 'DRF', 'z', fill_value=0)
        # Calculate volume
        self.dV = xy_to_xyz(self.dA, [self.nx, self.ny, self.nz]) * z_to_xyz(
            self.dz, [self.nx, self.ny, self.nz]) * self.hfac

        # Mesh lat and lon
        self.lon_2d, self.lat_2d = np.meshgrid(self.lon_1d, self.lat_1d)
        self.lon_corners_2d, self.lat_corners_2d = np.meshgrid(
            self.lon_corners_1d, self.lat_corners_1d)

        # Calculate bathymetry
        self.bathy = bdry_from_hfac('bathy', self.hfac, self.z_edges)

        # Create land masks
        self.land_mask = self.build_land_mask(self.hfac)
        self.land_mask_u = self.build_land_mask(self.hfac_w)
        self.land_mask_v = self.build_land_mask(self.hfac_s)
        # Dummy ice mask with all False
        self.ice_mask = np.zeros(self.land_mask.shape).astype(bool)
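
The split_longitude calls above rotate the longitude axis so that, after re-wrapping, it is monotonically increasing. The helper is not shown on this page; a minimal sketch of the assumed behaviour, rolling the block starting at i_split to the front of the last axis:

import numpy as np

def split_longitude(data, i_split):
    # Sketch only: move data[..., i_split:] to the front; works for the 1D
    # axes here and for fields whose last dimension is longitude.
    return np.concatenate((data[..., i_split:], data[..., :i_split]), axis=-1)
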
Example #5
def process_forcing_for_correction(source,
                                   var,
                                   mit_grid_dir,
                                   out_file,
                                   in_dir=None,
                                   start_year=1979,
                                   end_year=None):

    # Set parameters based on source dataset
    if source == 'ERA5':
        if in_dir is None:
            # Path on BAS servers
            in_dir = '/data/oceans_input/processed_input_data/ERA5/'
        file_head = 'ERA5_'
        gtype = ['t', 't', 't', 't', 't']
    elif source == 'UKESM':
        if in_dir is None:
            # Path on JASMIN
            in_dir = '/badc/cmip6/data/CMIP6/CMIP/MOHC/UKESM1-0-LL/'
        expt = 'historical'
        ensemble_member = 'r1i1p1f2'
        if var == 'wind':
            var_names_in = ['uas', 'vas']
            gtype = ['u', 'v']
        elif var == 'thermo':
            var_names_in = ['tas', 'huss', 'pr', 'ssrd', 'strd']
            gtype = ['t', 't', 't', 't', 't']
        days_per_year = 12 * 30
    elif source == 'PACE':
        if in_dir is None:
            # Path on BAS servers
            in_dir = '/data/oceans_input/processed_input_data/CESM/PACE_new/'
        file_head = 'PACE_ens'
        num_ens = 20
        missing_ens = 13
        if var == 'wind':
            var_names_in = ['UBOT', 'VBOT']
            monthly = [False, False]
        elif var == 'thermo':
            var_names_in = ['TREFHT', 'QBOT', 'PRECT', 'FSDS', 'FLDS']
            monthly = [False, False, False, True, True]
        gtype = ['t', 't', 't', 't', 't']
    else:
        print 'Error (process_forcing_for_correction): invalid source ' + source
        sys.exit()
    # Set parameters based on variable type
    if var == 'wind':
        var_names = ['uwind', 'vwind']
        units = ['m/s', 'm/s']
    elif var == 'thermo':
        var_names = ['atemp', 'aqh', 'precip', 'swdown', 'lwdown']
        units = ['degC', '1', 'm/s', 'W/m^2', 'W/m^2']
    else:
        print 'Error (process_forcing_for_correction): invalid var ' + var
        sys.exit()
    # Check end_year is defined
    if end_year is None:
        print 'Error (process_forcing_for_correction): must set end_year. Typically use 2014 for WSFRIS and 2013 for PACE.'
        sys.exit()

    mit_grid_dir = real_dir(mit_grid_dir)
    in_dir = real_dir(in_dir)

    print 'Building grids'
    if source == 'ERA5':
        forcing_grid = ERA5Grid()
    elif source == 'UKESM':
        forcing_grid = UKESMGrid()
    elif source == 'PACE':
        forcing_grid = PACEGrid()
    mit_grid = Grid(mit_grid_dir)

    ncfile = NCfile(out_file, mit_grid, 'xy')

    # Loop over variables
    for n in range(len(var_names)):
        print 'Processing variable ' + var_names[n]
        # Read the data, time-integrating as we go
        data = None
        num_time = 0

        if source == 'ERA5':
            # Loop over years
            for year in range(start_year, end_year + 1):
                file_path = in_dir + file_head + var_names[n] + '_' + str(year)
                data_tmp = read_binary(file_path,
                                       [forcing_grid.nx, forcing_grid.ny],
                                       'xyt')
                if data is None:
                    data = np.sum(data_tmp, axis=0)
                else:
                    data += np.sum(data_tmp, axis=0)
                num_time += data_tmp.shape[0]

        elif source == 'UKESM':
            in_files, start_years, end_years = find_cmip6_files(
                in_dir, expt, ensemble_member, var_names_in[n], 'day')
            # Loop over each file
            for t in range(len(in_files)):
                file_path = in_files[t]
                print 'Processing ' + file_path
                print 'Covers years ' + str(start_years[t]) + ' to ' + str(
                    end_years[t])
                # Loop over years
                t_start = 0  # Time index in file
                t_end = t_start + days_per_year
                for year in range(start_years[t], end_years[t] + 1):
                    if year >= start_year and year <= end_year:
                        print 'Processing ' + str(year)
                        # Read data
                        print 'Reading ' + str(year) + ' from indices ' + str(
                            t_start) + '-' + str(t_end)
                        data_tmp = read_netcdf(file_path,
                                               var_names_in[n],
                                               t_start=t_start,
                                               t_end=t_end)
                        if data is None:
                            data = np.sum(data_tmp, axis=0)
                        else:
                            data += np.sum(data_tmp, axis=0)
                        num_time += days_per_year
                    # Update time range for next time
                    t_start = t_end
                    t_end = t_start + days_per_year
            if var_names[n] == 'atemp':
                # Convert from K to C
                data -= temp_C2K
            elif var_names[n] == 'precip':
                # Convert from kg/m^2/s to m/s
                data /= rho_fw
            elif var_names[n] in ['swdown', 'lwdown']:
                # Swap sign on radiation fluxes
                data *= -1

        elif source == 'PACE':
            # Loop over years
            for year in range(start_year, end_year + 1):
                # Loop over ensemble members
                data_tmp = None
                num_ens_tmp = 0
                for ens in range(1, num_ens + 1):
                    if ens == missing_ens:
                        continue
                    file_path = in_dir + file_head + str(ens).zfill(
                        2) + '_' + var_names_in[n] + '_' + str(year)
                    data_tmp_ens = read_binary(
                        file_path, [forcing_grid.nx, forcing_grid.ny], 'xyt')
                    if data_tmp is None:
                        data_tmp = data_tmp_ens
                    else:
                        data_tmp += data_tmp_ens
                    num_ens_tmp += 1
                # Ensemble mean for this year
                data_tmp /= num_ens_tmp
                # Now accumulate time integral
                if monthly[n]:
                    # Weighting for different number of days per month
                    for month in range(data_tmp.shape[0]):
                        # Get number of days per month with no leap years
                        ndays = days_per_month(month + 1, 1979)
                        data_tmp[month, :] *= ndays
                        num_time += ndays
                else:
                    num_time += data_tmp.shape[0]
                if data is None:
                    data = np.sum(data_tmp, axis=0)
                else:
                    data += np.sum(data_tmp, axis=0)

        # Now convert from time-integral to time-average
        data /= num_time

        forcing_lon, forcing_lat = forcing_grid.get_lon_lat(gtype=gtype[n],
                                                            dim=1)
        # Get longitude in the range -180 to 180, then split and rearrange so it's monotonically increasing
        forcing_lon = fix_lon_range(forcing_lon)
        i_split = np.nonzero(forcing_lon < 0)[0][0]
        forcing_lon = split_longitude(forcing_lon, i_split)
        data = split_longitude(data, i_split)
        # Now interpolate to MITgcm tracer grid
        mit_lon, mit_lat = mit_grid.get_lon_lat(gtype='t', dim=1)
        print 'Interpolating'
        data_interp = interp_reg_xy(forcing_lon, forcing_lat, data, mit_lon,
                                    mit_lat)
        print 'Saving to ' + out_file
        ncfile.add_variable(var_names[n], data_interp, 'xy', units=units[n])

    ncfile.close()
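
A hypothetical call for the UKESM case, with placeholder paths; the function's own error message suggests 2014 as a typical end year for this source:

# Placeholder paths; adjust to the local grid directory and output file.
process_forcing_for_correction('UKESM', 'thermo', 'mitgcm_grid/',
                               'ukesm_thermo_1979-2014_mean.nc',
                               start_year=1979, end_year=2014)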