Пример #1
0
    def run(self):
        """
        Build a doubly periodic planar hex mesh for the GOTM test case, cull
        and convert it, push the vertical-grid settings into the namelist and
        run the model.
        """
        config = self.config
        logger = self.logger
        gotm = config['gotm']

        # periodic in both x and y
        base_mesh = make_planar_hex_mesh(nx=gotm.getint('nx'),
                                         ny=gotm.getint('ny'),
                                         dc=gotm.getfloat('dc'),
                                         nonperiodic_x=False,
                                         nonperiodic_y=False)
        write_netcdf(base_mesh, 'grid.nc')

        culled_mesh = cull(base_mesh, logger=logger)
        mpas_mesh = convert(culled_mesh, graphInfoFileName='graph.info',
                            logger=logger)
        write_netcdf(mpas_mesh, 'mesh.nc')

        # the raw config strings are forwarded unchanged to the namelist
        self.update_namelist_at_runtime(options={
            'config_periodic_planar_vert_levels':
                config.get('gotm', 'vert_levels'),
            'config_periodic_planar_bottom_depth':
                config.get('gotm', 'bottom_depth'),
        })

        run_model(self)
Пример #2
0
    def run(self):
        """
        Build the land-ice grid for the enthalpy benchmark.

        A non-periodic planar hex mesh is culled and converted into
        ``mpas_grid.nc``, turned into ``landice_grid.nc`` (with thermal
        fields) by ``create_landice_grid_from_generic_MPAS_grid.py``, given a
        graph file, and finally handed to ``_setup_initial_conditions``.
        """
        logger = self.logger
        section = self.config['enthalpy_benchmark']
        nx, ny = section.getint('nx'), section.getint('ny')
        dc = section.getfloat('dc')
        levels = section.get('levels')

        grid = make_planar_hex_mesh(nx=nx, ny=ny, dc=dc,
                                    nonperiodic_x=True, nonperiodic_y=True)
        write_netcdf(grid, 'grid.nc')

        mpas_grid = convert(cull(grid, logger=logger), logger=logger)
        write_netcdf(mpas_grid, 'mpas_grid.nc')

        check_call(['create_landice_grid_from_generic_MPAS_grid.py',
                    '-i', 'mpas_grid.nc',
                    '-o', 'landice_grid.nc',
                    '-l', levels,
                    '--thermal'],
                   logger)

        make_graph_file(mesh_filename='landice_grid.nc',
                        graph_filename='graph.info')

        _setup_initial_conditions(section, 'landice_grid.nc')
Пример #3
0
    def run(self):
        """
        Run this step of the test case

        Builds a doubly periodic planar hex mesh and converts it to
        ``mpas_grid.nc``. It then runs ``define_cullMask.py`` on that file
        with the configured ``radius``, culls and converts the re-read mesh
        into ``mpas_grid2.nc``, and creates ``landice_grid.nc`` (with
        ``--thermal`` and ``--beta``) plus its ``graph.info`` file.
        """
        logger = self.logger
        section = self.config['eismint2']

        nx = section.getint('nx')
        ny = section.getint('ny')
        dc = section.getfloat('dc')

        dsMesh = make_planar_hex_mesh(nx=nx,
                                      ny=ny,
                                      dc=dc,
                                      nonperiodic_x=False,
                                      nonperiodic_y=False)

        dsMesh = convert(dsMesh, logger=logger)
        write_netcdf(dsMesh, 'mpas_grid.nc')
        dsMesh.close()

        # define_cullMask.py is pointed at mpas_grid.nc, which is re-read
        # below -- presumably it adds the cull mask to that file in place
        radius = section.get('radius')
        args = [
            'define_cullMask.py', '-f', 'mpas_grid.nc', '-m', 'radius', '-d',
            radius
        ]

        check_call(args, logger)

        # Close the file handle once culling has produced a new dataset;
        # previously the re-opened mpas_grid.nc was never closed.
        with xarray.open_dataset('mpas_grid.nc') as dsMasked:
            dsMesh = cull(dsMasked, logger=logger)
        dsMesh = convert(dsMesh, logger=logger)
        write_netcdf(dsMesh, 'mpas_grid2.nc')

        levels = section.get('levels')
        args = [
            'create_landice_grid_from_generic_MPAS_grid.py', '-i',
            'mpas_grid2.nc', '-o', 'landice_grid.nc', '-l', levels,
            '--thermal', '--beta'
        ]

        check_call(args, logger)

        make_graph_file(mesh_filename='landice_grid.nc',
                        graph_filename='graph.info')
Пример #4
0
    def run(self):
        """
        Build the dome land-ice grid, its graph file and initial conditions.

        Only the ``2000m`` mesh type generates ``mpas_grid.nc`` here, from a
        non-periodic planar hex mesh. Other mesh types read an existing
        ``mpas_grid.nc`` -- NOTE(review): presumably supplied as an input
        file; confirm.
        """
        mesh_type = self.mesh_type
        logger = self.logger
        config = self.config
        section = config['dome']

        if mesh_type == '2000m':
            ds_mesh = make_planar_hex_mesh(nx=section.getint('nx'),
                                           ny=section.getint('ny'),
                                           dc=section.getfloat('dc'),
                                           nonperiodic_x=True,
                                           nonperiodic_y=True)
            write_netcdf(ds_mesh, 'grid.nc')

            ds_mesh = convert(cull(ds_mesh, logger=logger), logger=logger)
            write_netcdf(ds_mesh, 'mpas_grid.nc')

        check_call(['create_landice_grid_from_generic_MPAS_grid.py',
                    '-i', 'mpas_grid.nc',
                    '-o', 'landice_grid.nc',
                    '-l', section.get('levels')],
                   logger)

        make_graph_file(mesh_filename='landice_grid.nc',
                        graph_filename='graph.info')

        _setup_dome_initial_conditions(config, logger,
                                       filename='landice_grid.nc')
Пример #5
0
    def run(self):
        """
        Create the ZISO channel mesh (periodic in x, non-periodic in y), then
        cull and convert it. The culled mesh is passed to
        ``_write_initial_state``, and the resulting ``yCell``/``zMid`` go to
        ``_write_forcing``.
        """
        config = self.config
        logger = self.logger
        ziso = config['ziso']

        base = make_planar_hex_mesh(nx=ziso.getint('nx'),
                                    ny=ziso.getint('ny'),
                                    dc=ziso.getfloat('dc'),
                                    nonperiodic_x=False,
                                    nonperiodic_y=True)
        write_netcdf(base, 'base_mesh.nc')

        culled = convert(cull(base, logger=logger),
                         graphInfoFileName='culled_graph.info',
                         logger=logger)
        write_netcdf(culled, 'culled_mesh.nc')

        ds_init = _write_initial_state(config, culled, self.with_frazil)
        _write_forcing(config, ds_init.yCell, ds_init.zMid)
Пример #6
0
    def run(self):
        """
        Write a doubly periodic, centered planar hex mesh and its graph file.

        The cell counts are the ``nx_1km``/``ny_1km`` options divided by
        ``self.resolution``, truncated to integers. The cell spacing is
        ``resolution * 1e3`` (presumably km converted to m).
        """
        config = self.config
        res = float(self.resolution)
        section = config['planar_convergence']

        nx = int(section.getint('nx_1km') / res)
        ny = int(section.getint('ny_1km') / res)

        ds_mesh = make_planar_hex_mesh(nx=nx, ny=ny, dc=res * 1e3,
                                       nonperiodic_x=False,
                                       nonperiodic_y=False)

        # the return value is unused, so this presumably shifts ds_mesh
        # in place
        center(ds_mesh)

        write_netcdf(ds_mesh, 'mesh.nc')
        make_graph_file('mesh.nc', 'graph.info')
Пример #7
0
    def run(self):
        """
        Run this step of the test case

        Sets the vertical-grid namelist options and builds a planar hex mesh
        that is periodic in x and non-periodic in y, then culls and converts
        it. Finally writes ``ocean.nc`` containing:

        * a linearly stratified temperature, minus a perturbation within
          ``perturbation_width`` of the middle of the domain in y;
        * uniform salinity;
        * zero normal velocity.
        """
        config = self.config
        logger = self.logger

        section = config['vertical_grid']
        vert_levels = section.getint('vert_levels')
        bottom_depth = section.getfloat('bottom_depth')

        # vert_levels is a level count, so it is passed as an int; reading
        # it with getfloat would instead give a float such as 20.0
        replacements = dict()
        replacements['config_periodic_planar_vert_levels'] = vert_levels
        replacements['config_periodic_planar_bottom_depth'] = bottom_depth
        self.update_namelist_at_runtime(options=replacements)

        section = config['internal_wave']
        nx = section.getint('nx')
        ny = section.getint('ny')
        dc = section.getfloat('dc')
        use_distances = section.getboolean('use_distances')
        amplitude_width_dist = section.getfloat('amplitude_width_dist')
        amplitude_width_frac = section.getfloat('amplitude_width_frac')
        bottom_temperature = section.getfloat('bottom_temperature')
        surface_temperature = section.getfloat('surface_temperature')
        temperature_difference = section.getfloat('temperature_difference')
        salinity = section.getfloat('salinity')

        logger.info(' * Make planar hex mesh')
        dsMesh = make_planar_hex_mesh(nx=nx,
                                      ny=ny,
                                      dc=dc,
                                      nonperiodic_x=False,
                                      nonperiodic_y=True)
        logger.info(' * Completed Make planar hex mesh')
        write_netcdf(dsMesh, 'base_mesh.nc')

        logger.info(' * Cull mesh')
        dsMesh = cull(dsMesh, logger=logger)
        logger.info(' * Convert mesh')
        dsMesh = convert(dsMesh,
                         graphInfoFileName='culled_graph.info',
                         logger=logger)
        logger.info(' * Completed Convert mesh')
        write_netcdf(dsMesh, 'culled_mesh.nc')

        ds = dsMesh.copy()
        yCell = ds.yCell

        ds['bottomDepth'] = bottom_depth * xarray.ones_like(yCell)
        ds['ssh'] = xarray.zeros_like(yCell)

        init_vertical_coord(config, ds)

        yMin = yCell.min().values
        yMax = yCell.max().values

        yMid = 0.5 * (yMin + yMax)

        if use_distances:
            perturbation_width = amplitude_width_dist
        else:
            perturbation_width = (yMax - yMin) * amplitude_width_frac

        # Set stratified temperature: bottom_temperature at depth
        # bottom_depth, rising linearly to surface_temperature at z = 0
        temp_vert = (bottom_temperature +
                     (surface_temperature - bottom_temperature) *
                     ((ds.refZMid + bottom_depth) / bottom_depth))

        # Fractional depth of the top of each level: 0 for the first level,
        # refBottomDepth[k - 1] / refBottomDepth[vert_levels - 1] below it
        depth_frac = xarray.zeros_like(temp_vert)
        refBottomDepth = ds['refBottomDepth']
        depth_frac[1:vert_levels] = (
            refBottomDepth[:vert_levels - 1].values /
            refBottomDepth[vert_levels - 1].values)

        # Within perturbation_width of yMid, subtract temperature_difference
        # scaled by cos(pi/2 * (y - yMid) / width) * sin(pi * depth_frac);
        # elsewhere the stratified profile is left unperturbed
        frac = xarray.where(
            numpy.abs(yCell - yMid) < perturbation_width,
            numpy.cos(0.5 * numpy.pi * (yCell - yMid) / perturbation_width) *
            numpy.sin(numpy.pi * depth_frac), 0.)

        temperature = temp_vert - temperature_difference * frac
        temperature = temperature.transpose('nCells', 'nVertLevels')
        temperature = temperature.expand_dims(dim='Time', axis=0)

        normalVelocity = xarray.zeros_like(ds.xEdge)
        normalVelocity, _ = xarray.broadcast(normalVelocity, ds.refBottomDepth)
        normalVelocity = normalVelocity.transpose('nEdges', 'nVertLevels')
        normalVelocity = normalVelocity.expand_dims(dim='Time', axis=0)

        ds['temperature'] = temperature
        ds['salinity'] = salinity * xarray.ones_like(temperature)
        ds['normalVelocity'] = normalVelocity

        write_netcdf(ds, 'ocean.nc')
Пример #8
0
    def run(self):
        """
        Build the culled channel mesh (periodic in x, non-periodic in y) for
        the 2D ice-shelf test case. Then write ``initial_state.nc`` with:

        * a piecewise-linear ice-shelf cavity and the land-ice
          pressure/draft derived from it;
        * constant temperature and linearly varying salinity;
        * zero velocity and zero Coriolis parameter.
        """
        config = self.config
        logger = self.logger
        section = config['ice_shelf_2d']

        dsMesh = make_planar_hex_mesh(nx=section.getint('nx'),
                                      ny=section.getint('ny'),
                                      dc=section.getfloat('dc'),
                                      nonperiodic_x=False,
                                      nonperiodic_y=True)
        write_netcdf(dsMesh, 'base_mesh.nc')

        dsMesh = convert(cull(dsMesh, logger=logger),
                         graphInfoFileName='culled_graph.info',
                         logger=logger)
        write_netcdf(dsMesh, 'culled_mesh.nc')

        bottom_depth = config.getfloat('vertical_grid', 'bottom_depth')
        temperature = section.getfloat('temperature')
        surface_salinity = section.getfloat('surface_salinity')
        bottom_salinity = section.getfloat('bottom_salinity')

        # Water-column thickness under the ice shelf is d1 south of y1. It
        # rises linearly to d2 at y2 and then to the full depth d3 at y3,
        # staying at d3 further north (points 1 and 2 are the ice-shelf
        # angles, point 3 is at the surface).
        y1 = section.getfloat('y1')
        y2 = section.getfloat('y2')
        y3 = y2 + section.getfloat('edge_width')
        d1 = section.getfloat('cavity_thickness')
        d2 = d1 + section.getfloat('slope_height')
        d3 = bottom_depth

        ds = dsMesh.copy()
        ds['bottomDepth'] = bottom_depth * xarray.ones_like(ds.xCell)

        y = ds.yCell
        thickness = xarray.where(y < y1, d1,
                                 d1 + (d2 - d1) * (y - y1) / (y2 - y1))
        thickness = xarray.where(y < y2, thickness,
                                 d2 + (d3 - d2) * (y - y2) / (y3 - y2))
        thickness = xarray.where(y < y3, thickness, d3)
        ds['ssh'] = -bottom_depth + thickness

        init_vertical_coord(config, ds)

        # cells under the ice shelf (y < y3), with a leading Time dimension
        under_shelf = xarray.where(y < y3, 1, 0).expand_dims(dim='Time',
                                                             axis=0)
        ref_density = constants['SHR_CONST_RHOSW']
        pressure, draft = compute_land_ice_pressure_and_draft(
            ssh=ds.ssh, modify_mask=under_shelf, ref_density=ref_density)

        salt = surface_salinity + ((bottom_salinity - surface_salinity) *
                                   (ds.zMid / (-bottom_depth)))
        salt = xarray.broadcast(salt, ds.layerThickness)[0]
        salt = salt.transpose('Time', 'nCells', 'nVertLevels')

        velocity = xarray.broadcast(xarray.zeros_like(ds.xEdge),
                                    ds.refBottomDepth)[0]
        velocity = velocity.transpose('nEdges', 'nVertLevels').expand_dims(
            dim='Time', axis=0)

        ds['temperature'] = temperature * xarray.ones_like(ds.layerThickness)
        ds['salinity'] = salt
        ds['normalVelocity'] = velocity
        for coord, var in (('xCell', 'fCell'), ('xEdge', 'fEdge'),
                           ('xVertex', 'fVertex')):
            ds[var] = xarray.zeros_like(ds[coord])
        ds['modifyLandIcePressureMask'] = under_shelf
        ds['landIceFraction'] = under_shelf.astype(float)
        ds['landIceMask'] = under_shelf.copy()
        ds['landIcePressure'] = pressure
        ds['landIceDraft'] = draft

        write_netcdf(ds, 'initial_state.nc')
Пример #9
0
    def run(self):
        """
        Run this step of the test case

        Builds a planar hex mesh that is periodic in x and non-periodic in y,
        culls and converts it, and writes ``ocean.nc`` containing:

        * a stratified temperature with a cold region south of a
          sinusoidally wavy front, plus a warm bump in one crest;
        * uniform salinity;
        * zero normal velocity;
        * a constant Coriolis parameter.
        """
        config = self.config
        logger = self.logger

        section = config['baroclinic_channel']
        nx = section.getint('nx')
        ny = section.getint('ny')
        dc = section.getfloat('dc')

        # periodic in x, bounded in y (a channel)
        dsMesh = make_planar_hex_mesh(nx=nx,
                                      ny=ny,
                                      dc=dc,
                                      nonperiodic_x=False,
                                      nonperiodic_y=True)
        write_netcdf(dsMesh, 'base_mesh.nc')

        dsMesh = cull(dsMesh, logger=logger)
        dsMesh = convert(dsMesh,
                         graphInfoFileName='culled_graph.info',
                         logger=logger)
        write_netcdf(dsMesh, 'culled_mesh.nc')

        # (same section as above; re-fetched here)
        section = config['baroclinic_channel']
        use_distances = section.getboolean('use_distances')
        gradient_width_dist = section.getfloat('gradient_width_dist')
        gradient_width_frac = section.getfloat('gradient_width_frac')
        bottom_temperature = section.getfloat('bottom_temperature')
        surface_temperature = section.getfloat('surface_temperature')
        temperature_difference = section.getfloat('temperature_difference')
        salinity = section.getfloat('salinity')
        coriolis_parameter = section.getfloat('coriolis_parameter')

        ds = dsMesh.copy()
        xCell = ds.xCell
        yCell = ds.yCell

        bottom_depth = config.getfloat('vertical_grid', 'bottom_depth')

        # flat bottom and zero sea-surface height before building the
        # vertical coordinate
        ds['bottomDepth'] = bottom_depth * xarray.ones_like(xCell)
        ds['ssh'] = xarray.zeros_like(xCell)

        init_vertical_coord(config, ds)

        xMin = xCell.min().values
        xMax = xCell.max().values
        yMin = yCell.min().values
        yMax = yCell.max().values

        # the extra warm perturbation is confined to the fifth sixth of the
        # domain in x
        yMid = 0.5 * (yMin + yMax)
        xPerturbMin = xMin + 4.0 * (xMax - xMin) / 6.0
        xPerturbMax = xMin + 5.0 * (xMax - xMin) / 6.0

        # width of the front: absolute distance, or fraction of the y extent
        if use_distances:
            perturbationWidth = gradient_width_dist
        else:
            perturbationWidth = (yMax - yMin) * gradient_width_frac

        # front displacement: three full sine periods across the x extent
        yOffset = perturbationWidth * numpy.sin(6.0 * numpy.pi *
                                                (xCell - xMin) / (xMax - xMin))

        # linear stratification from bottom_temperature at -bottom_depth to
        # surface_temperature at z = 0
        temp_vert = (bottom_temperature +
                     (surface_temperature - bottom_temperature) *
                     ((ds.refZMid + bottom_depth) / bottom_depth))

        # frac = 1 south of the front ...
        frac = xarray.where(yCell < yMid - yOffset, 1., 0.)

        # ... decreasing linearly to 0 across perturbationWidth north of it
        mask = numpy.logical_and(yCell >= yMid - yOffset,
                                 yCell < yMid - yOffset + perturbationWidth)
        frac = xarray.where(
            mask, 1. - (yCell - (yMid - yOffset)) / perturbationWidth, frac)

        temperature = temp_vert - temperature_difference * frac
        temperature = temperature.transpose('nCells', 'nVertLevels')

        # Determine yOffset for 3rd crest in sin wave
        yOffset = 0.5 * perturbationWidth * numpy.sin(
            numpy.pi * (xCell - xPerturbMin) / (xPerturbMax - xPerturbMin))

        # band of width perturbationWidth centered on yMid - yOffset, within
        # [xPerturbMin, xPerturbMax]
        mask = numpy.logical_and(
            numpy.logical_and(
                yCell >= yMid - yOffset - 0.5 * perturbationWidth,
                yCell <= yMid - yOffset + 0.5 * perturbationWidth),
            numpy.logical_and(xCell >= xPerturbMin, xCell <= xPerturbMax))

        # inside the band the added term runs from 0 (north edge) to 0.6
        # (south edge), since (y - (yMid - yOffset)) / (0.5 * width) spans
        # [-1, 1] there
        temperature = (temperature + mask * 0.3 *
                       (1. - ((yCell - (yMid - yOffset)) /
                              (0.5 * perturbationWidth))))

        temperature = temperature.expand_dims(dim='Time', axis=0)

        # zero velocity on every edge and level, with a leading Time dim
        normalVelocity = xarray.zeros_like(ds.xEdge)
        normalVelocity, _ = xarray.broadcast(normalVelocity, ds.refBottomDepth)
        normalVelocity = normalVelocity.transpose('nEdges', 'nVertLevels')
        normalVelocity = normalVelocity.expand_dims(dim='Time', axis=0)

        ds['temperature'] = temperature
        ds['salinity'] = salinity * xarray.ones_like(temperature)
        ds['normalVelocity'] = normalVelocity
        ds['fCell'] = coriolis_parameter * xarray.ones_like(xCell)
        ds['fEdge'] = coriolis_parameter * xarray.ones_like(ds.xEdge)
        ds['fVertex'] = coriolis_parameter * xarray.ones_like(ds.xVertex)

        write_netcdf(ds, 'ocean.nc')
Пример #10
0
    def run(self):
        """
        Run this step of the test case

        Processes the ISOMIP+ input geometry and builds a planar hex mesh
        culled by the interpolated ocean mask. From these it writes
        ``initial_state.nc``, which includes land-ice fields, depth-linear
        T/S, constant Coriolis and zero velocity. It also plots a few fields
        under ``<work_dir>/plots`` and writes the interior restoring and
        "evaporation" forcing to ``init_mode_forcing_data.nc``.
        """
        config = self.config
        logger = self.logger

        section = config['isomip_plus']
        nx = section.getint('nx')
        ny = section.getint('ny')
        dc = section.getfloat('dc')
        # smoothing length scales with the mesh resolution
        filter_sigma = section.getfloat('topo_smoothing') * self.resolution
        min_ice_thickness = section.getfloat('min_ice_thickness')
        min_land_ice_fraction = section.getfloat('min_land_ice_fraction')
        draft_scaling = section.getfloat('draft_scaling')

        process_input_geometry('input_geometry.nc',
                               'input_geometry_processed.nc',
                               filterSigma=filter_sigma,
                               minIceThickness=min_ice_thickness,
                               scale=draft_scaling)

        # two extra rows/columns of cells, shifted down by 2 * dc in y;
        # translate's return value is unused, so it presumably shifts dsMesh
        # in place
        dsMesh = make_planar_hex_mesh(nx=nx + 2,
                                      ny=ny + 2,
                                      dc=dc,
                                      nonperiodic_x=False,
                                      nonperiodic_y=False)
        translate(mesh=dsMesh, yOffset=-2 * dc)
        write_netcdf(dsMesh, 'base_mesh.nc')

        dsGeom = xarray.open_dataset('input_geometry_processed.nc')

        min_ocean_fraction = config.getfloat('isomip_plus',
                                             'min_ocean_fraction')

        # keep only cells selected by the ocean mask (passed as dsInverse)
        dsMask = interpolate_ocean_mask(dsMesh, dsGeom, min_ocean_fraction)
        dsMesh = cull(dsMesh, dsInverse=dsMask, logger=logger)
        # NOTE(review): presumably the culled mesh no longer wraps around,
        # hence is_periodic = 'NO' -- confirm
        dsMesh.attrs['is_periodic'] = 'NO'

        dsMesh = convert(dsMesh,
                         graphInfoFileName='culled_graph.info',
                         logger=logger)
        write_netcdf(dsMesh, 'culled_mesh.nc')

        ds = interpolate_geom(dsMesh, dsGeom, min_ocean_fraction)

        # add a leading Time dimension (currently only landIceFraction)
        for var in ['landIceFraction']:
            ds[var] = ds[var].expand_dims(dim='Time', axis=0)

        ds['landIceMask'] = \
            (ds.landIceFraction >= min_land_ice_fraction).astype(int)

        # land-ice pressure/draft are computed where ssh is below zero
        ref_density = constants['SHR_CONST_RHOSW']
        landIcePressure, landIceDraft = compute_land_ice_pressure_and_draft(
            ssh=ds.ssh, modify_mask=ds.ssh < 0., ref_density=ref_density)

        ds['landIcePressure'] = landIcePressure
        ds['landIceDraft'] = landIceDraft

        if self.time_varying_forcing:
            self._write_time_varying_forcing(ds_init=ds)

        ds['bottomDepth'] = -ds.bottomDepthObserved

        section = config['isomip_plus']

        min_column_thickness = section.getfloat('min_column_thickness')
        min_levels = section.getint('minimum_levels')

        interfaces = generate_1d_grid(config)

        # Deepen the bottom depth to maintain the minimum water-column
        # thickness
        # NOTE(review): interfaces[min_levels + 1] rather than
        # interfaces[min_levels]; confirm the intended level count
        min_depth = numpy.maximum(-ds.ssh + min_column_thickness,
                                  interfaces[min_levels + 1])
        ds['bottomDepth'] = numpy.maximum(ds.bottomDepth, min_depth)

        init_vertical_coord(config, ds)

        ds['modifyLandIcePressureMask'] = \
            (ds['landIceFraction'] > 0.01).astype(int)

        # frac = -zMid / bottom_depth: 0 at the surface, 1 at bottom_depth
        max_bottom_depth = -config.getfloat('vertical_grid', 'bottom_depth')
        frac = (0. - ds.zMid) / (0. - max_bottom_depth)

        # compute T, S
        init_top_temp = section.getfloat('init_top_temp')
        init_bot_temp = section.getfloat('init_bot_temp')
        init_top_sal = section.getfloat('init_top_sal')
        init_bot_sal = section.getfloat('init_bot_sal')
        ds['temperature'] = (1.0 - frac) * init_top_temp + frac * init_bot_temp
        ds['salinity'] = (1.0 - frac) * init_top_sal + frac * init_bot_sal

        # compute coriolis
        coriolis_parameter = section.getfloat('coriolis_parameter')

        ds['fCell'] = coriolis_parameter * xarray.ones_like(ds.xCell)
        ds['fEdge'] = coriolis_parameter * xarray.ones_like(ds.xEdge)
        ds['fVertex'] = coriolis_parameter * xarray.ones_like(ds.xVertex)

        # zero velocity on every edge and level, with a leading Time dim
        normalVelocity = xarray.zeros_like(ds.xEdge)
        normalVelocity = normalVelocity.broadcast_like(ds.refBottomDepth)
        normalVelocity = normalVelocity.transpose('nEdges', 'nVertLevels')
        ds['normalVelocity'] = normalVelocity.expand_dims(dim='Time', axis=0)

        write_netcdf(ds, 'initial_state.nc')

        # start from an empty plot folder
        plot_folder = '{}/plots'.format(self.work_dir)
        if os.path.exists(plot_folder):
            shutil.rmtree(plot_folder)

        # plot a few fields
        section_y = config.getfloat('isomip_plus_viz', 'section_y')

        # show progress only if we're not writing to a log file
        show_progress = self.log_filename is None

        plotter = MoviePlotter(inFolder=self.work_dir,
                               streamfunctionFolder=self.work_dir,
                               outFolder=plot_folder,
                               expt=self.experiment,
                               sectionY=section_y,
                               dsMesh=ds,
                               ds=ds,
                               showProgress=show_progress)

        plotter.plot_3d_field_top_bot_section(ds.zMid,
                                              nameInTitle='zMid',
                                              prefix='zmid',
                                              units='m',
                                              vmin=-720.,
                                              vmax=0.,
                                              cmap='cmo.deep_r')

        plotter.plot_3d_field_top_bot_section(ds.temperature,
                                              nameInTitle='temperature',
                                              prefix='temp',
                                              units='C',
                                              vmin=-2.,
                                              vmax=1.,
                                              cmap='cmo.thermal')

        plotter.plot_3d_field_top_bot_section(ds.salinity,
                                              nameInTitle='salinity',
                                              prefix='salin',
                                              units='PSU',
                                              vmin=33.8,
                                              vmax=34.7,
                                              cmap='cmo.haline')

        # compute restoring
        dsForcing = xarray.Dataset()

        # restoring values use the same depth fraction as the initial T/S
        restore_top_temp = section.getfloat('restore_top_temp')
        restore_bot_temp = section.getfloat('restore_bot_temp')
        restore_top_sal = section.getfloat('restore_top_sal')
        restore_bot_sal = section.getfloat('restore_bot_sal')
        dsForcing['temperatureInteriorRestoringValue'] = \
            (1.0 - frac) * restore_top_temp + frac * restore_bot_temp
        dsForcing['salinityInteriorRestoringValue'] = \
            (1.0 - frac) * restore_top_sal + frac * restore_bot_sal

        # frac is reused here as the restoring-rate fraction in x: 0 west of
        # restore_xmin, rising linearly toward restore_xmax
        # NOTE(review): not clipped above 1, so the rate exceeds
        # restore_rate for xCell > restore_xmax -- confirm intended
        restore_rate = section.getfloat('restore_rate')
        restore_xmin = section.getfloat('restore_xmin')
        restore_xmax = section.getfloat('restore_xmax')
        frac = numpy.maximum(
            (ds.xCell - restore_xmin) / (restore_xmax - restore_xmin), 0.)
        frac = frac.broadcast_like(dsForcing.temperatureInteriorRestoringValue)

        # convert from 1/days to 1/s
        dsForcing['temperatureInteriorRestoringRate'] = \
            frac * restore_rate / constants['SHR_CONST_CDAY']
        dsForcing['salinityInteriorRestoringRate'] = \
            dsForcing.temperatureInteriorRestoringRate

        # compute "evaporation"
        restore_evap_rate = section.getfloat('restore_evap_rate')

        # evaporation applies only for restore_xmin <= xCell <= restore_xmax
        mask = numpy.logical_and(ds.xCell >= restore_xmin,
                                 ds.xCell <= restore_xmax)
        mask = mask.expand_dims(dim='Time', axis=0)
        # convert to m/s, negative for evaporation rather than precipitation
        evap_rate = -restore_evap_rate / (constants['SHR_CONST_CDAY'] * 365)
        # PSU*m/s to kg/m^2/s
        sflux_factor = 1.
        # C*m/s to W/m^2
        hflux_factor = 1. / (ref_density * constants['SHR_CONST_CPSW'])
        dsForcing['evaporationFlux'] = mask * ref_density * evap_rate
        dsForcing['seaIceSalinityFlux'] = \
            mask*evap_rate*restore_top_sal/sflux_factor
        dsForcing['seaIceHeatFlux'] = \
            mask*evap_rate*restore_top_temp/hflux_factor

        write_netcdf(dsForcing, 'init_mode_forcing_data.nc')