Example #1
0
def iceberg_meltwater(grid_path, input_dir, output_file, nc_out=None, prec=32):
    """Interpolate 12 monthly iceberg melt fields to the model grid.

    Reads 'berg_total_melt' (time index 0) from the files
    icebergs_01.nc ... icebergs_12.nc in input_dir, interpolates each one
    to the model grid built from grid_path, zeroes it over the land and
    ice shelf masks, and writes the stacked (12, ny, nx) array to
    output_file with write_binary. The annual mean is then plotted.

    Arguments:
    grid_path: passed to Grid to build the model grid
    input_dir: directory holding the monthly icebergs_MM.nc files
    output_file: path handed to write_binary

    Optional keyword arguments:
    nc_out: if not None, path of a NetCDF file to also write the result to
    prec: precision passed through to write_binary
    """

    from plot_latlon import latlon_plot

    input_dir = real_dir(input_dir)

    def monthly_file(month_number):
        # Path to the input file for a 1-based month number
        return input_dir + 'icebergs_' + ('%02d' % month_number) + '.nc'

    print('Building grids')
    # The NEMO grid comes from the first monthly file; its longitude is
    # in the range -180 to 180, so build the model grid to match.
    first_file = monthly_file(1)
    nemo_lon = read_netcdf(first_file, 'nav_lon')
    nemo_lat = read_netcdf(first_file, 'nav_lat')
    model_grid = Grid(grid_path, max_lon=180)

    print('Interpolating')
    icebergs_interp = np.zeros((12, model_grid.ny, model_grid.nx))
    for month_number in range(1, 13):
        print('...month ' + str(month_number))
        raw_melt = read_netcdf(monthly_file(month_number),
                               'berg_total_melt',
                               time_index=0)
        melt_on_model = interp_nonreg_xy(nemo_lon,
                                         nemo_lat,
                                         raw_melt,
                                         model_grid.lon_1d,
                                         model_grid.lat_1d,
                                         fill_value=0)
        # Land and ice shelf cavities must not receive any iceberg melt
        excluded = model_grid.land_mask + model_grid.ice_mask
        melt_on_model[excluded] = 0
        icebergs_interp[month_number - 1, :] = melt_on_model

    write_binary(icebergs_interp, output_file, prec=prec)

    print('Plotting')
    # Annual mean, with land and ice shelves masked for display
    annual_mean = mask_land_ice(np.mean(icebergs_interp, axis=0), model_grid)
    latlon_plot(annual_mean,
                model_grid,
                include_shelf=False,
                vmin=0,
                title=r'Annual mean iceberg melt (kg/m$^2$/s)')

    if nc_out is None:
        return
    # Optional NetCDF copy of the interpolated fields
    print('Writing ' + nc_out)
    ncfile = NCfile(nc_out, model_grid, 'xyt')
    ncfile.add_time(np.arange(1, 13), units='months')
    ncfile.add_variable('iceberg_melt',
                        icebergs_interp,
                        'xyt',
                        units='kg/m^2/s')
    ncfile.close()
Example #2
0
def calc_ice_prod (file_path, out_file, monthly=True):
    """Write the net sea ice production from a model output file to a new file.

    The variables SIdHbOCN, SIdHbATC, SIdHbATO and SIdHbFLO are read from
    file_path and summed, negative values of the sum are set to zero, and
    the result is saved as 'ice_prod' in out_file together with the time
    axis.

    Arguments:
    file_path: NetCDF file to read; also used to build the Grid
    out_file: path of the NetCDF file to create

    Optional keyword arguments:
    monthly: passed through to netcdf_time
    """

    # The grid is built from the same file that holds the data
    grid = Grid(file_path)

    # Accumulate the production terms in order, at every time index
    term_names = ['SIdHbOCN', 'SIdHbATC', 'SIdHbATO', 'SIdHbFLO']
    ice_prod = None
    for term_name in term_names:
        term = read_netcdf(file_path, term_name)
        if ice_prod is None:
            ice_prod = term
        else:
            ice_prod = ice_prod + term
    time = netcdf_time(file_path, monthly=monthly)

    # Keep only the positive part
    ice_prod = np.maximum(ice_prod, 0)

    # Save the result
    ncfile = NCfile(out_file, grid, 'xyt')
    ncfile.add_time(time)
    ncfile.add_variable('ice_prod',
                        ice_prod,
                        'xyt',
                        long_name='Net sea ice production',
                        units='m/s')
    ncfile.close()
Example #3
0
def sose_ics (grid_path, sose_dir, output_dir, nc_out=None, constant_t=-1.9, constant_s=34.4, split=180, prec=64):
    """Build initial condition files from the SOSE climatology.

    For each of THETA, SALT, SIarea and SIheff, the first time index of
    <sose_dir>/<field>_climatology.data is read, filled into missing
    regions, interpolated to the model grid, zeroed where the model hfac
    is zero, and written to <output_dir>/<field>_SOSE.ini.

    Arguments:
    grid_path: passed to grid_check_split to build the model grid
    sose_dir: directory with the SOSE climatology files and a grid/ subdirectory
    output_dir: directory for the output binary files

    Optional keyword arguments:
    nc_out: if not None, path of a NetCDF file that also receives every field
    constant_t, constant_s: values assigned to THETA and SALT at ice shelf
                            cavity points before interpolation
    split: passed to grid_check_split and SOSEGrid
    prec: precision passed through to write_binary
    """

    from grid import SOSEGrid
    from file_io import NCfile
    from interpolation import interp_reg

    sose_dir = real_dir(sose_dir)
    output_dir = real_dir(output_dir)

    # One entry per field: (name, 2 or 3 dimensions, constant value for
    # ice shelf cavities)
    field_specs = [('THETA', 3, constant_t),
                   ('SALT', 3, constant_s),
                   ('SIarea', 2, 0),
                   ('SIheff', 2, 0)]
    # Filename endings for input and output
    infile_tail = '_climatology.data'
    outfile_tail = '_SOSE.ini'
    write_nc = nc_out is not None

    print('Building grids')
    # Model grid first (this also checks the value of split), then SOSE
    model_grid = grid_check_split(grid_path, split)
    sose_grid = SOSEGrid(sose_dir+'grid/', model_grid=model_grid, split=split)
    # SOSE land mask
    sose_mask = sose_grid.hfac == 0

    print('Building mask for SOSE points to fill')
    # Model open cells, interpolated to the SOSE grid
    model_open = np.ceil(interp_reg(model_grid, sose_grid, np.ceil(model_grid.hfac), fill_value=1))
    # Model ice shelf cavity points, interpolated to the SOSE grid
    cavity_3d = xy_to_xyz(model_grid.ice_mask, model_grid)
    model_cavity = np.ceil(interp_reg(model_grid, sose_grid, cavity_3d, fill_value=0)).astype(bool)
    # Open, non-cavity cells, extended into the mask a few times so there
    # are no artifacts near the coast
    fill = extend_into_mask(model_open * ~model_cavity, missing_val=0, use_3d=True, num_iters=3)

    # NetCDF file so the user can check the results
    if write_nc:
        ncfile = NCfile(nc_out, model_grid, 'xyz')

    for field, ndim, cavity_value in field_specs:
        is_3d = ndim == 3
        print('Processing ' + field)
        in_file = sose_dir + field + infile_tail
        out_file = output_dir + field + outfile_tail

        print('...reading ' + in_file)
        # Only the first time index (January) of the climatology is kept
        if is_3d:
            sose_data = sose_grid.read_field(in_file, 'xyzt')[0,:]
        else:
            # Missing regions of sea ice fields become zero, since they
            # will not be extrapolated later
            sose_data = sose_grid.read_field(in_file, 'xyt', fill_value=0)[0,:]

        # Drop the land mask and extrapolate slightly into missing regions
        # so the interpolation doesn't get messed up
        print('...extrapolating into missing regions')
        if is_3d:
            sose_data = discard_and_fill(sose_data, sose_mask, fill)
            # Cavity points get the constant value
            sose_data[model_cavity] = cavity_value
        else:
            # Surface layer only
            sose_data = discard_and_fill(sose_data, sose_mask[0,:], fill[0,:], use_3d=False)

        print('...interpolating to model grid')
        data_interp = interp_reg(sose_grid, model_grid, sose_data, dim=ndim)
        # Zeros in the model land mask
        if is_3d:
            land = model_grid.hfac==0
        else:
            land = model_grid.hfac[0,:]==0
        data_interp[land] = 0
        write_binary(data_interp, out_file, prec=prec)

        if write_nc:
            print('...adding to ' + nc_out)
            ncfile.add_variable(field, data_interp, 'xyz' if is_3d else 'xy')

    if write_nc:
        ncfile.close()
Example #4
0
def make_obcs (location, grid_path, input_path, output_dir, source='SOSE', use_seaice=True, nc_out=None, prec=32, split=180):
    """Build open boundary condition files for one boundary of the model domain.

    For every field, the 12-month source data is sliced/interpolated to the
    latitude or longitude of the chosen boundary, interpolated along the
    boundary to the model grid with interp_bdry, and written to
    <output_dir>/<field>_<source>.OBCS_<location>.

    Arguments:
    location: 'N', 'S', 'E' or 'W'; which boundary to process
    grid_path: path used to build the model grid
    input_path: for source='SOSE', the directory holding the
                <field>_climatology.data files and a grid/ subdirectory;
                for source='MIT', a path passed to both Grid and read_netcdf
    output_dir: directory for the output binary files

    Optional keyword arguments:
    source: 'SOSE' or 'MIT'. With 'MIT' the field SIhsnow is also processed.
    use_seaice: if False, fields whose names start with 'SI' are skipped
    nc_out: if not None, path of a NetCDF file that also receives every field
    prec: precision passed through to write_binary
    split: passed to grid_check_split and SOSEGrid (source='SOSE' only)

    An invalid source or location prints an error message and calls
    sys.exit().
    """

    from grid import SOSEGrid
    from file_io import NCfile, read_netcdf
    from interpolation import interp_bdry

    # Validate source before anything depends on it. Previously this check
    # sat after the model grid had been used, so an invalid source raised
    # an unbound-variable error on model_grid instead of this message.
    if source not in ['SOSE', 'MIT']:
        print('Error (make_obcs): invalid source ' + source)
        sys.exit()

    if source == 'SOSE':
        input_path = real_dir(input_path)
    output_dir = real_dir(output_dir)

    # Fields to interpolate
    # Important: SIarea has to be before SIuice and SIvice so it can be used for masking
    fields = ['THETA', 'SALT', 'UVEL', 'VVEL', 'SIarea', 'SIheff', 'SIuice', 'SIvice', 'ETAN']
    # Flag for 2D or 3D
    dim = [3, 3, 3, 3, 2, 2, 2, 2, 2]
    # Flag for grid type
    gtype = ['t', 't', 'u', 'v', 't', 't', 'u', 'v', 't']
    if source == 'MIT':
        # Also consider snow thickness
        fields += ['SIhsnow']
        dim += [2]
        gtype += ['t']
    # End of filenames for input
    infile_tail = '_climatology.data'
    # End of filenames for output
    outfile_tail = '_'+source+'.OBCS_'+location

    print('Building MITgcm grid')
    if source == 'SOSE':
        model_grid = grid_check_split(grid_path, split)
    else:
        model_grid = Grid(grid_path)
    # Figure out what the latitude or longitude is on the boundary, both on the centres and outside edges of those cells
    if location == 'S':
        lat0 = model_grid.lat_1d[0]
        lat0_e = model_grid.lat_corners_1d[0]
        print('Southern boundary at ' + str(lat0) + ' (cell centre), ' + str(lat0_e) + ' (cell edge)')
    elif location == 'N':
        lat0 = model_grid.lat_1d[-1]
        # Extrapolate one cell beyond the last corner to get the outside edge
        lat0_e = 2*model_grid.lat_corners_1d[-1] - model_grid.lat_corners_1d[-2]
        print('Northern boundary at ' + str(lat0) + ' (cell centre), ' + str(lat0_e) + ' (cell edge)')
    elif location == 'W':
        lon0 = model_grid.lon_1d[0]
        lon0_e = model_grid.lon_corners_1d[0]
        print('Western boundary at ' + str(lon0) + ' (cell centre), ' + str(lon0_e) + ' (cell edge)')
    elif location == 'E':
        lon0 = model_grid.lon_1d[-1]
        # Extrapolate one cell beyond the last corner to get the outside edge
        lon0_e = 2*model_grid.lon_corners_1d[-1] - model_grid.lon_corners_1d[-2]
        print('Eastern boundary at ' + str(lon0) + ' (cell centre), ' + str(lon0_e) + ' (cell edge)')
    else:
        print('Error (make_obcs): invalid location ' + str(location))
        sys.exit()

    if source == 'SOSE':
        print('Building SOSE grid')
        source_grid = SOSEGrid(input_path+'grid/', model_grid=model_grid, split=split)
    else:
        print('Building grid from source model')
        source_grid = Grid(input_path)
    # Calculate interpolation indices and coefficients to the boundary latitude or longitude
    if location in ['N', 'S']:
        # Cell centre
        j1, j2, c1, c2 = interp_slice_helper(source_grid.lat_1d, lat0)
        # Cell edge
        j1_e, j2_e, c1_e, c2_e = interp_slice_helper(source_grid.lat_corners_1d, lat0_e)
    else:
        # Pass lon=True to consider the possibility of boundary near 0E
        i1, i2, c1, c2 = interp_slice_helper(source_grid.lon_1d, lon0, lon=True)
        i1_e, i2_e, c1_e, c2_e = interp_slice_helper(source_grid.lon_corners_1d, lon0_e, lon=True)

    # Set up a NetCDF file so the user can check the results
    if nc_out is not None:
        ncfile = NCfile(nc_out, model_grid, 'xyzt')
        ncfile.add_time(np.arange(12)+1, units='months')

    # Process fields
    for field, ndim, grid_type in zip(fields, dim, gtype):
        if field.startswith('SI') and not use_seaice:
            continue

        print('Processing ' + field)
        out_file = output_dir + field + outfile_tail
        # Read the monthly climatology at all points
        if source == 'SOSE':
            in_file = input_path + field + infile_tail
            if ndim == 3:
                source_data = source_grid.read_field(in_file, 'xyzt')
            else:
                source_data = source_grid.read_field(in_file, 'xyt')
        else:
            source_data = read_netcdf(input_path, field)

        if field == 'SIarea' and source == 'SOSE':
            # We'll need this field later for SIuice and SIvice, as SOSE didn't mask those variables properly
            print('Interpolating sea ice area to u and v grids for masking of sea ice velocity')
            source_aice_u = interp_grid(source_data, source_grid, 't', 'u', time_dependent=True, mask_with_zeros=True, periodic=True)
            source_aice_v = interp_grid(source_data, source_grid, 't', 'v', time_dependent=True, mask_with_zeros=True, periodic=True)
        # Set sea ice velocity to zero wherever sea ice area is zero
        if field in ['SIuice', 'SIvice'] and source == 'SOSE':
            print('Masking sea ice velocity with sea ice area')
            if field == 'SIuice':
                index = source_aice_u==0
            else:
                index = source_aice_v==0
            source_data[index] = 0

        # Choose the correct grid for lat, lon, hfac
        source_lon, source_lat = source_grid.get_lon_lat(gtype=grid_type, dim=1)
        source_hfac = source_grid.get_hfac(gtype=grid_type)
        model_lon, model_lat = model_grid.get_lon_lat(gtype=grid_type, dim=1)
        model_hfac = model_grid.get_hfac(gtype=grid_type)
        # Interpolate to the correct grid and choose the correct horizontal axis
        if location in ['N', 'S']:
            if grid_type == 'v':
                # The v-grid uses the cell-edge indices and coefficients
                source_data = c1_e*source_data[...,j1_e,:] + c2_e*source_data[...,j2_e,:]
                # Multiply hfac by the ceiling of hfac on each side, to make sure we're not averaging over land
                source_hfac = (c1_e*source_hfac[...,j1_e,:] + c2_e*source_hfac[...,j2_e,:])*np.ceil(source_hfac[...,j1_e,:])*np.ceil(source_hfac[...,j2_e,:])
            else:
                source_data = c1*source_data[...,j1,:] + c2*source_data[...,j2,:]
                source_hfac = (c1*source_hfac[...,j1,:] + c2*source_hfac[...,j2,:])*np.ceil(source_hfac[...,j1,:])*np.ceil(source_hfac[...,j2,:])
            source_haxis = source_lon
            model_haxis = model_lon
            if location == 'S':
                model_hfac = model_hfac[:,0,:]
            else:
                model_hfac = model_hfac[:,-1,:]
        else:
            if grid_type == 'u':
                # The u-grid uses the cell-edge indices and coefficients
                source_data = c1_e*source_data[...,i1_e] + c2_e*source_data[...,i2_e]
                source_hfac = (c1_e*source_hfac[...,i1_e] + c2_e*source_hfac[...,i2_e])*np.ceil(source_hfac[...,i1_e])*np.ceil(source_hfac[...,i2_e])
            else:
                source_data = c1*source_data[...,i1] + c2*source_data[...,i2]
                source_hfac = (c1*source_hfac[...,i1] + c2*source_hfac[...,i2])*np.ceil(source_hfac[...,i1])*np.ceil(source_hfac[...,i2])
            source_haxis = source_lat
            model_haxis = model_lat
            if location == 'W':
                model_hfac = model_hfac[...,0]
            else:
                model_hfac = model_hfac[...,-1]
        if source == 'MIT' and model_haxis[0] < source_haxis[0]:
            # Need to extend source data to the west or south. Just add one row.
            source_haxis = np.concatenate(([model_haxis[0]-0.1], source_haxis))
            source_data = np.concatenate((np.expand_dims(source_data[:,...,0], -1), source_data), axis=-1)
            source_hfac = np.concatenate((np.expand_dims(source_hfac[:,0], 1), source_hfac), axis=1)
        # For 2D variables, just need surface hfac
        if ndim == 2:
            source_hfac = source_hfac[0,:]
            model_hfac = model_hfac[0,:]

        # Now interpolate each month to the model grid
        if ndim == 3:
            data_interp = np.zeros([12, model_grid.nz, model_haxis.size])
        else:
            data_interp = np.zeros([12, model_haxis.size])
        for month in range(12):
            print('...interpolating month ' + str(month+1))
            data_interp_tmp = interp_bdry(source_haxis, source_grid.z, source_data[month,:], source_hfac, model_haxis, model_grid.z, model_hfac, depth_dependent=(ndim==3))
            if field not in ['THETA', 'SALT']:
                # Zero in land mask is more physical than extrapolated data
                index = model_hfac==0
                data_interp_tmp[index] = 0
            data_interp[month,:] = data_interp_tmp

        write_binary(data_interp, out_file, prec=prec)

        if nc_out is not None:
            print('...adding to ' + nc_out)
            # Construct the dimension code
            if location in ['S', 'N']:
                dimension = 'x'
            else:
                dimension = 'y'
            if ndim == 3:
                dimension += 'z'
            dimension += 't'
            ncfile.add_variable(field + '_' + location, data_interp, dimension)

    if nc_out is not None:
        ncfile.close()
Example #5
0
def crash_to_netcdf(crash_dir, grid_path):
    """Convert the binary crash dump files in a directory to one NetCDF file.

    Every file in crash_dir named state<Name>crash*.data for a known
    <Name> is read with read_binary and added as a variable to
    <crash_dir>/crash.nc.

    Arguments:
    crash_dir: directory containing the crash files; crash.nc is created here
    grid_path: passed to Grid to build the model grid
    """

    # One row per recognised crash file:
    # (filename prefix, NetCDF variable name, dimension code,
    #  extra keyword arguments for add_variable)
    crash_fields = [
        ('stateThetacrash', 'THETA', 'xyz', {'units': 'C'}),
        ('stateSaltcrash', 'SALT', 'xyz', {'units': 'psu'}),
        ('stateUvelcrash', 'UVEL', 'xyz', {'gtype': 'u', 'units': 'm/s'}),
        ('stateVvelcrash', 'VVEL', 'xyz', {'gtype': 'v', 'units': 'm/s'}),
        ('stateWvelcrash', 'WVEL', 'xyz', {'gtype': 'w', 'units': 'm/s'}),
        ('stateEtacrash', 'ETAN', 'xy', {'units': 'm'}),
        ('stateAreacrash', 'SIarea', 'xy', {'units': 'fraction'}),
        ('stateHeffcrash', 'SIheff', 'xy', {'units': 'm'}),
        ('stateUicecrash', 'SIuice', 'xy', {'gtype': 'u', 'units': 'm/s'}),
        ('stateVicecrash', 'SIvice', 'xy', {'gtype': 'v', 'units': 'm/s'}),
        ('stateQnetcrash', 'Qnet', 'xy', {'units': 'W/m^2'}),
        ('stateMxlcrash', 'MXLDEPTH', 'xy', {'units': 'm'}),
        ('stateEmpmrcrash', 'Empmr', 'xy', {'units': 'kg/m^2/s'}),
    ]

    # Make sure crash_dir ends with a separator so paths can be concatenated
    if not crash_dir.endswith('/'):
        crash_dir = crash_dir + '/'

    grid = Grid(grid_path)
    ncfile = NCfile(crash_dir + 'crash.nc', grid, 'xyz')

    # Scan the directory and convert every recognised crash file
    for fname in os.listdir(crash_dir):
        if not fname.endswith('.data'):
            continue
        for prefix, var_name, dimensions, attributes in crash_fields:
            if fname.startswith(prefix):
                # Read from binary, then write to NetCDF
                data = read_binary(crash_dir + fname, grid, dimensions)
                ncfile.add_variable(var_name, data, dimensions, **attributes)

    # NOTE(review): other functions in this file finish with ncfile.close();
    # confirm that NCfile provides finished() in the version in use.
    ncfile.finished()
Example #6
0
def sose_sss_restoring (grid_path, sose_dir, output_salt_file, output_mask_file, nc_out=None, h0=-1250, obcs_sponge=0, split=180, prec=64):
    """Build surface salinity restoring files from the SOSE climatology.

    Two binary files are written: a (12, nz, ny, nx) salinity array whose
    surface layer holds the interpolated SOSE surface salinity (all deeper
    layers are zero), and a (nz, ny, nx) restoring mask whose surface
    layer is the smoothed mask (all deeper layers are zero).

    Arguments:
    grid_path: passed to grid_check_split to build the model grid
    sose_dir: directory with SALT_climatology.data and a grid/ subdirectory
    output_salt_file: path for the salinity binary file
    output_mask_file: path for the restoring mask binary file

    Optional keyword arguments:
    nc_out: if not None, path of a NetCDF file that also receives both arrays
    h0: the mask is set to zero wherever model_grid.bathy > h0
    obcs_sponge: if positive, number of cells along each edge of the domain
                 where the mask is set to zero
    split: passed to grid_check_split and SOSEGrid
    prec: precision passed through to write_binary
    """

    sose_dir = real_dir(sose_dir)

    print('Building grids')
    # Model grid first (this also checks the value of split), then SOSE
    model_grid = grid_check_split(grid_path, split)
    sose_grid = SOSEGrid(sose_dir+'grid/', model_grid=model_grid, split=split)
    # Surface land mask of SOSE
    sose_mask = sose_grid.hfac[0,:] == 0

    print('Building mask')
    ny, nx = model_grid.ny, model_grid.nx
    mask_surface = np.ones([ny, nx])
    # Zero wherever the surface layer is closed (land and ice shelves)
    mask_surface[model_grid.hfac[0,:]==0] = 0
    # Keep a copy of the land/ice shelf mask before the shelf is removed
    open_surface = np.copy(mask_surface)
    # Remove the continental shelf
    mask_surface[model_grid.bathy > h0] = 0
    # Smooth, then reapply the land/ice shelf mask
    mask_surface = smooth_xy(mask_surface, sigma=2)*open_surface
    if obcs_sponge > 0:
        # Zero the cells affected by OBCS and/or its sponge on all four edges
        near_edge = slice(None, obcs_sponge)
        far_edge = slice(-obcs_sponge, None)
        for edge in (near_edge, far_edge):
            mask_surface[edge,:] = 0
        for edge in (near_edge, far_edge):
            mask_surface[:,edge] = 0
    # 3D version with zeros in deeper layers
    mask_3d = np.zeros([model_grid.nz, ny, nx])
    mask_3d[0,:] = mask_surface

    print('Reading SOSE salinity')
    # Surface layer only
    sose_salt = sose_grid.read_field(sose_dir+'SALT_climatology.data', 'xyzt')
    sose_sss = sose_salt[:,0,:,:]

    # SOSE points needed for interpolation: the restoring mask interpolated
    # to the SOSE grid, extended into the mask a few times so there are no
    # artifacts near the coast
    fill = np.ceil(interp_reg(model_grid, sose_grid, mask_3d[0,:], dim=2, fill_value=1))
    fill = extend_into_mask(fill, missing_val=0, num_iters=3)

    # One month at a time
    sss_interp = np.zeros([12, model_grid.nz, ny, nx])
    for month in range(12):
        print('Month ' + str(month+1))
        print('...filling missing values')
        sose_sss_filled = discard_and_fill(sose_sss[month,:], sose_mask, fill, use_3d=False)
        print('...interpolating')
        # Multiplying by the land/ice shelf mask zeroes those points
        sss_interp[month,0,:] = interp_reg(sose_grid, model_grid, sose_sss_filled, dim=2)*open_surface

    write_binary(sss_interp, output_salt_file, prec=prec)
    write_binary(mask_3d, output_mask_file, prec=prec)

    if nc_out is None:
        return
    print('Writing ' + nc_out)
    ncfile = NCfile(nc_out, model_grid, 'xyzt')
    ncfile.add_time(np.arange(1, 13), units='months')
    ncfile.add_variable('salinity', sss_interp, 'xyzt', units='psu')
    ncfile.add_variable('restoring_mask', mask_3d, 'xyz')
    ncfile.close()
Example #7
0
def precompute_timeseries (mit_file, timeseries_file, timeseries_types=None, monthly=True, lon0=None, lat0=None):
    """Compute timeseries from one model output file and store them.

    If timeseries_file already exists, the new time values and data are
    appended to its existing variables; otherwise the file is created.

    Arguments:
    mit_file: model output file to process; also used to build the Grid
    timeseries_file: NetCDF file to create or append to

    Optional keyword arguments:
    timeseries_types: list of timeseries names understood by set_parameters
                      and calc_special_timeseries. Defaults to the FRIS mass
                      balance, eta_avg, seaice_area, fris_temp, fris_salt
                      and fris_age.
    monthly, lon0, lat0: passed through to calc_special_timeseries
    """

    if timeseries_types is None:
        timeseries_types = ['fris_mass_balance', 'eta_avg', 'seaice_area', 'fris_temp', 'fris_salt', 'fris_age']

    grid = Grid(mit_file)

    # Either create a fresh file or open the existing one for appending
    file_exists = os.path.isfile(timeseries_file)
    if not file_exists:
        ncfile = NCfile(timeseries_file, grid, 't')
    else:
        dataset = nc.Dataset(timeseries_file, 'a')

    # Time axis of the model file, with its units
    time, time_units = netcdf_time(mit_file, return_units=True)
    if not file_exists:
        ncfile.add_time(time, units=time_units)
    else:
        time_var = dataset.variables['time']
        # Use the units already in the file so old and new values agree
        time_units = time_var.units
        # Number of time indices already stored; new data starts here
        num_time = time_var.size
        time = nc.date2num(time, time_units)
        dataset.variables['time'][num_time:] = time

    def write_var (data, var_name, title, units):
        # Append to an existing variable, or define a new one
        if not file_exists:
            ncfile.add_variable(var_name, data, 't', long_name=title, units=units)
        else:
            dataset.variables[var_name][num_time:] = data

    for ts_name in timeseries_types:
        print('Processing ' + ts_name)
        # Only the title and units are needed from the parameter set
        title, units = set_parameters(ts_name)[2:4]
        if ts_name != 'fris_mass_balance':
            data = calc_special_timeseries(ts_name, mit_file, grid=grid, lon0=lon0, lat0=lat0, monthly=monthly)[1]
            write_var(data, ts_name, title, units)
            continue
        # The mass balance produces two series, each with its own title
        melt, freeze = calc_special_timeseries(ts_name, mit_file, grid=grid, monthly=monthly)[1:]
        write_var(melt, 'fris_total_melt', 'Total melting beneath FRIS', units)
        write_var(freeze, 'fris_total_freeze', 'Total refreezing beneath FRIS', units)

    # Close whichever handle was opened
    if not file_exists:
        ncfile.close()
    else:
        dataset.close()
Example #8
0
def process_forcing_for_correction(source,
                                   var,
                                   mit_grid_dir,
                                   out_file,
                                   in_dir=None,
                                   start_year=1979,
                                   end_year=None):
    """Time-average atmospheric forcing and interpolate it to the MITgcm grid.

    For each variable in the chosen group, the source data is summed over
    all time records from start_year to end_year (inclusive), divided by
    the number of records to give a time mean, interpolated to the MITgcm
    tracer grid, and written to out_file.

    Arguments:
    source: 'ERA5', 'UKESM' or 'PACE'
    var: 'wind' (uwind, vwind) or 'thermo' (atemp, aqh, precip, swdown, lwdown)
    mit_grid_dir: path used to build the MITgcm Grid
    out_file: path of the NetCDF file to create

    Optional keyword arguments:
    in_dir: directory holding the source data; defaults to a hard-coded
            path that depends on source
    start_year: first year to include
    end_year: last year to include; must be set

    An invalid source or var, or a missing end_year, prints an error
    message and calls sys.exit().
    """

    # Set parameters based on source dataset
    if source == 'ERA5':
        if in_dir is None:
            # Path on BAS servers
            in_dir = '/data/oceans_input/processed_input_data/ERA5/'
        file_head = 'ERA5_'
        gtype = ['t', 't', 't', 't', 't']
    elif source == 'UKESM':
        if in_dir is None:
            # Path on JASMIN
            in_dir = '/badc/cmip6/data/CMIP6/CMIP/MOHC/UKESM1-0-LL/'
        expt = 'historical'
        ensemble_member = 'r1i1p1f2'
        if var == 'wind':
            var_names_in = ['uas', 'vas']
            gtype = ['u', 'v']
        elif var == 'thermo':
            var_names_in = ['tas', 'huss', 'pr', 'ssrd', 'strd']
            gtype = ['t', 't', 't', 't', 't']
        # Each year is treated as 12 months of 30 days
        days_per_year = 12 * 30
    elif source == 'PACE':
        if in_dir is None:
            # Path on BAS servers
            in_dir = '/data/oceans_input/processed_input_data/CESM/PACE_new/'
        file_head = 'PACE_ens'
        num_ens = 20
        # This ensemble member is skipped when reading
        missing_ens = 13
        if var == 'wind':
            var_names_in = ['UBOT', 'VBOT']
            monthly = [False, False]
        elif var == 'thermo':
            var_names_in = ['TREFHT', 'QBOT', 'PRECT', 'FSDS', 'FLDS']
            monthly = [False, False, False, True, True]
        gtype = ['t', 't', 't', 't', 't']
    else:
        print('Error (process_forcing_for_correction): invalid source ' + source)
        sys.exit()
    # Set parameters based on variable type
    if var == 'wind':
        var_names = ['uwind', 'vwind']
        units = ['m/s', 'm/s']
    elif var == 'thermo':
        var_names = ['atemp', 'aqh', 'precip', 'swdown', 'lwdown']
        units = ['degC', '1', 'm/s', 'W/m^2', 'W/m^2']
    else:
        print('Error (process_forcing_for_correction): invalid var ' + var)
        sys.exit()
    # Check end_year is defined
    if end_year is None:
        print('Error (process_forcing_for_correction): must set end_year. Typically use 2014 for WSFRIS and 2013 for PACE.')
        sys.exit()

    mit_grid_dir = real_dir(mit_grid_dir)
    in_dir = real_dir(in_dir)

    print('Building grids')
    if source == 'ERA5':
        forcing_grid = ERA5Grid()
    elif source == 'UKESM':
        forcing_grid = UKESMGrid()
    elif source == 'PACE':
        forcing_grid = PACEGrid()
    mit_grid = Grid(mit_grid_dir)

    ncfile = NCfile(out_file, mit_grid, 'xy')

    # Loop over variables. An index is used because several parallel lists
    # (var_names_in, monthly, gtype, units) are only defined for some sources.
    for n in range(len(var_names)):
        print('Processing variable ' + var_names[n])
        # Read the data, time-integrating as we go
        data = None
        num_time = 0

        if source == 'ERA5':
            # Loop over years
            for year in range(start_year, end_year + 1):
                file_path = in_dir + file_head + var_names[n] + '_' + str(year)
                data_tmp = read_binary(file_path,
                                       [forcing_grid.nx, forcing_grid.ny],
                                       'xyt')
                if data is None:
                    data = np.sum(data_tmp, axis=0)
                else:
                    data += np.sum(data_tmp, axis=0)
                num_time += data_tmp.shape[0]

        elif source == 'UKESM':
            # This comparison used to read ' UKESM' (leading space), which
            # never matched, so no UKESM data was ever read.
            in_files, start_years, end_years = find_cmip6_files(
                in_dir, expt, ensemble_member, var_names_in[n], 'day')
            # Loop over each file
            for t in range(len(in_files)):
                file_path = in_files[t]
                print('Processing ' + file_path)
                print('Covers years ' + str(start_years[t]) + ' to ' + str(
                    end_years[t]))
                # Loop over years
                t_start = 0  # Time index in file
                t_end = t_start + days_per_year
                for year in range(start_years[t], end_years[t] + 1):
                    if year >= start_year and year <= end_year:
                        print('Processing ' + str(year))
                        # Read data
                        print('Reading ' + str(year) + ' from indices ' + str(
                            t_start) + '-' + str(t_end))
                        data_tmp = read_netcdf(file_path,
                                               var_names_in[n],
                                               t_start=t_start,
                                               t_end=t_end)
                        if data is None:
                            data = np.sum(data_tmp, axis=0)
                        else:
                            data += np.sum(data_tmp, axis=0)
                        num_time += days_per_year
                    # Update time range for next time
                    t_start = t_end
                    t_end = t_start + days_per_year

        elif source == 'PACE':
            # Loop over years
            for year in range(start_year, end_year + 1):
                # Loop over ensemble members
                data_tmp = None
                num_ens_tmp = 0
                for ens in range(1, num_ens + 1):
                    if ens == missing_ens:
                        continue
                    file_path = in_dir + file_head + str(ens).zfill(
                        2) + '_' + var_names_in[n] + '_' + str(year)
                    data_tmp_ens = read_binary(
                        file_path, [forcing_grid.nx, forcing_grid.ny], 'xyt')
                    if data_tmp is None:
                        data_tmp = data_tmp_ens
                    else:
                        data_tmp += data_tmp_ens
                    num_ens_tmp += 1
                # Ensemble mean for this year
                data_tmp /= num_ens_tmp
                # Now accumulate time integral
                if monthly[n]:
                    # Weighting for different number of days per month
                    for month in range(data_tmp.shape[0]):
                        # Get number of days per month with no leap years
                        ndays = days_per_month(month + 1, 1979)
                        data_tmp[month, :] *= ndays
                        num_time += ndays
                else:
                    num_time += data_tmp.shape[0]
                if data is None:
                    data = np.sum(data_tmp, axis=0)
                else:
                    data += np.sum(data_tmp, axis=0)

        # Now convert from time-integral to time-average
        data /= num_time

        if source == 'UKESM':
            # Unit conversions are applied to the time mean, not the time
            # integral: subtracting the Kelvin offset before dividing by
            # num_time would only remove temp_C2K/num_time from the mean.
            if var_names[n] == 'atemp':
                # Convert from K to C
                data -= temp_C2K
            elif var_names[n] == 'precip':
                # Convert from kg/m^2/s to m/s
                data /= rho_fw
            elif var_names[n] in ['swdown', 'lwdown']:
                # Swap sign on radiation fluxes
                data *= -1

        forcing_lon, forcing_lat = forcing_grid.get_lon_lat(gtype=gtype[n],
                                                            dim=1)
        # Get longitude in the range -180 to 180, then split and rearrange so it's monotonically increasing
        forcing_lon = fix_lon_range(forcing_lon)
        i_split = np.nonzero(forcing_lon < 0)[0][0]
        forcing_lon = split_longitude(forcing_lon, i_split)
        data = split_longitude(data, i_split)
        # Now interpolate to MITgcm tracer grid
        mit_lon, mit_lat = mit_grid.get_lon_lat(gtype='t', dim=1)
        print('Interpolating')
        data_interp = interp_reg_xy(forcing_lon, forcing_lat, data, mit_lon,
                                    mit_lat)
        print('Saving to ' + out_file)
        ncfile.add_variable(var_names[n], data_interp, 'xy', units=units[n])

    ncfile.close()