def set_topography(self, vs):
        """Derive the bottom-level index field ``vs.kbot`` from bathymetry data.

        The ``x``/``y``/``z`` variables of the topography file are read and
        transposed, heights above zero are clipped to zero, and the field is
        Gaussian-smoothed with a width tied to the ratio of data points to
        model grid points. It is then shifted in longitude via
        ``self._shift_longitude_array`` and nearest-neighbour interpolated onto
        the model T-grid interior.

        Each ocean point (interpolated depth < 0) gets the 1-based index of the
        closest ``vs.zt`` level; land and columns whose index reaches
        ``vs.nz`` get 0. Enclosed marginal seas are removed at the end.
        """
        with h5netcdf.File(DATA_FILES['topography'], 'r') as nc_file:
            lon_raw, lat_raw, elev = [
                np.array(nc_file.variables[name], dtype='float').T
                for name in ('x', 'y', 'z')
            ]

        # nothing may stick out above sea level
        elev[elev > 0] = 0.

        # smooth the bathymetry so its length scale matches the model grid
        sigma = (0.5 * len(lon_raw) / vs.nx, 0.5 * len(lat_raw) / vs.ny)
        elev_smooth = scipy.ndimage.gaussian_filter(elev, sigma=sigma)
        # points that were (almost) dry before smoothing stay dry
        elev_smooth[elev >= -1] = 0

        lon_shifted, elev_shifted = self._shift_longitude_array(
            vs, lon_raw, elev_smooth)

        z_interp = allocate(vs, ('xt', 'yt'), local=False)
        z_interp[2:-2, 2:-2] = veros.tools.interpolate(
            (lon_shifted, lat_raw),
            elev_shifted,
            (vs.xt[2:-2], vs.yt[2:-2]),
            kind='nearest',
            fill=False,
        )

        # 1-based index of the vertical T-level nearest to each depth value
        level_distance = np.abs(
            z_interp[:, :, np.newaxis] - vs.zt[np.newaxis, np.newaxis, :])
        nearest_level = np.argmin(level_distance, axis=2) + 1

        is_ocean = z_interp < 0.
        vs.kbot[2:-2, 2:-2] = np.where(is_ocean, nearest_level, 0)[2:-2, 2:-2]
        # columns whose bottom index is nz or larger are turned into land
        vs.kbot *= vs.kbot < vs.nz

        enforce_boundaries(vs, vs.kbot)

        # remove marginal seas: dilate land to close 1-cell passages, fill the
        # resulting enclosed water bodies, then erode to undo the dilation
        land = vs.kbot == 0
        closed_land = scipy.ndimage.binary_fill_holes(
            scipy.ndimage.binary_dilation(land))
        marginal = scipy.ndimage.binary_erosion(closed_land)

        vs.kbot[marginal] = 0
    def set_initial_conditions(self, vs):
        """Set initial tracer fields and surface/bottom forcing.

        Data are read through ``self._get_data`` for the part of the forcing
        grid covering the model domain, extended by one point on each side.
        They are interpolated onto the model T-grid interior as follows:

        * ``temperature`` / ``salinity`` -> time levels 0 and 1 of
          ``vs.temp`` / ``vs.salt``, masked with ``vs.maskT``;
        * ``tau_x`` / ``tau_y`` -> ``vs.taux`` / ``vs.tauy``, divided by
          ``vs.rho_0``;
        * ``q_net`` (negated), ``dqdt``, ``swf`` (negated), ``sst`` and
          ``sss`` -> ``vs.qnet``, ``vs.qnec``, ``vs.qsol``, ``vs.t_star`` and
          ``vs.s_star``, masked with the surface ocean mask;
        * ``tidal_energy`` -> ``vs.forc_iw_bottom`` (only if IDEMIX is
          enabled).

        Finally the shortwave penetration divergence ``vs.divpen_shortwave``
        is computed.

        Raises:
            ValueError: if the forcing ``xt`` or ``yt`` coordinates are not
                strictly increasing.
        """
        rpart_shortwave = 0.58
        efold1_shortwave = 0.35
        efold2_shortwave = 23.0

        t_grid = (vs.xt[2:-2], vs.yt[2:-2], vs.zt)
        xt_forc, yt_forc, zt_forc = (self._get_data(vs, k)
                                     for k in ('xt', 'yt', 'zt'))
        # flip the forcing depth axis; depth-dependent fields are flipped
        # along their last axis accordingly below
        zt_forc = zt_forc[::-1]

        # The subset selection below needs strictly increasing coordinates.
        # Note: ``np.diff(c).all() > 0`` would only test for non-zero steps
        # and silently accept decreasing axes.
        for name, coord in (('xt', xt_forc), ('yt', yt_forc)):
            if not np.all(np.diff(coord) > 0):
                raise ValueError(
                    'forcing coordinate {} must be strictly increasing'.format(name))

        def covering_slice(coord, lower, upper):
            """Slice of ``coord`` spanning [lower, upper] plus one point each side."""
            start = max(0, int(np.argmax(coord >= lower)) - 1)
            stop = len(coord) - max(0, int(np.argmax(coord[::-1] <= upper)) - 1)
            return slice(start, stop)

        # determine slice to read from forcing file
        data_subset = (
            covering_slice(xt_forc, vs.xt.min(), vs.xt.max()),
            covering_slice(yt_forc, vs.yt.min(), vs.yt.max()),
            Ellipsis,
        )

        xt_forc = xt_forc[data_subset[0]]
        yt_forc = yt_forc[data_subset[1]]

        def read_subset(name):
            """Read forcing variable ``name`` restricted to ``data_subset``."""
            return self._get_data(vs, name, idx=data_subset)

        ocean_mask = vs.maskT[2:-2, 2:-2, :]
        surface_mask = vs.maskT[2:-2, 2:-2, -1, np.newaxis]

        # initial conditions (depth axis flipped to match zt_forc)
        temp_data = veros.tools.interpolate((xt_forc, yt_forc, zt_forc),
                                            read_subset('temperature')[..., ::-1],
                                            t_grid)
        vs.temp[2:-2, 2:-2, :, 0] = temp_data * ocean_mask
        vs.temp[2:-2, 2:-2, :, 1] = temp_data * ocean_mask

        salt_data = veros.tools.interpolate((xt_forc, yt_forc, zt_forc),
                                            read_subset('salinity')[..., ::-1],
                                            t_grid)
        vs.salt[2:-2, 2:-2, :, 0] = salt_data * ocean_mask
        vs.salt[2:-2, 2:-2, :, 1] = salt_data * ocean_mask

        # monthly (12-entry) 2-D forcing fields
        months = np.arange(12)
        forc_time_grid = (xt_forc, yt_forc, months)
        time_grid = (vs.xt[2:-2], vs.yt[2:-2], months)

        def interp_monthly(name):
            """Interpolate monthly forcing ``name`` onto the model grid."""
            return veros.tools.interpolate(forc_time_grid, read_subset(name),
                                           time_grid)

        # wind stress on MIT grid
        vs.taux[2:-2, 2:-2, :] = interp_monthly('tau_x') / vs.rho_0
        vs.tauy[2:-2, 2:-2, :] = interp_monthly('tau_y') / vs.rho_0

        enforce_boundaries(vs, vs.taux)
        enforce_boundaries(vs, vs.tauy)

        # Qnet and dQ/dT and Qsol (Qnet and Qsol are stored with flipped sign)
        vs.qnet[2:-2, 2:-2, :] = -interp_monthly('q_net') * surface_mask
        vs.qnec[2:-2, 2:-2, :] = interp_monthly('dqdt') * surface_mask
        vs.qsol[2:-2, 2:-2, :] = -interp_monthly('swf') * surface_mask

        # SST and SSS
        vs.t_star[2:-2, 2:-2, :] = interp_monthly('sst') * surface_mask
        vs.s_star[2:-2, 2:-2, :] = interp_monthly('sss') * surface_mask

        if vs.enable_idemix:
            tidal_energy_data = veros.tools.interpolate(
                (xt_forc, yt_forc), read_subset('tidal_energy'), t_grid[:-1])
            mask_x, mask_y = (i + 2 for i in np.indices((vs.nx, vs.ny)))
            # bottom W-level of each column (kbot - 1, floored at 0)
            mask_z = np.maximum(0, vs.kbot[2:-2, 2:-2] - 1)
            tidal_energy_data[:, :] *= vs.maskW[mask_x, mask_y,
                                                mask_z] / vs.rho_0
            vs.forc_iw_bottom[2:-2, 2:-2] = tidal_energy_data

        # Initialize penetration profile for solar radiation and store the
        # divergence in divpen. Note that pen is set to 0.0 at the surface
        # instead of 1.0 to compensate for the shortwave part of the total
        # surface flux.
        swarg1 = vs.zw / efold1_shortwave
        swarg2 = vs.zw / efold2_shortwave
        pen = rpart_shortwave * np.exp(swarg1) + (
            1.0 - rpart_shortwave) * np.exp(swarg2)
        pen[-1] = 0.
        vs.divpen_shortwave[1:] = (pen[1:] - pen[:-1]) / vs.dzt[1:]
        vs.divpen_shortwave[0] = pen[0] / vs.dzt[0]
Beispiel #3
0
def npzd(vs):
    r"""
    Main driving function for NPZD functionality

    Computes transport terms and biological activity separately

    \begin{equation}
        \dfrac{\partial C_i}{\partial t} = T + S
    \end{equation}

    with :math:`T` the transport terms and :math:`S` the biogeochemical
    changes returned by ``biogeochemistry(vs)``.

    Each tracer in ``vs.npzd_transported_tracers`` has the ``taup1`` level of
    ``vs.npzd_tracers[tracer]`` updated in place by:

    * an Adams-Bashforth advection step;
    * the enabled horizontal, biharmonic and isoneutral diffusion schemes;
    * implicit vertical mixing.

    The biogeochemical changes are then added to every tracer, tracers are
    clipped from below at ``vs.trcmin * vs.maskT`` and boundary conditions are
    enforced. Returns immediately if ``vs.enable_npzd`` is false.
    """
    if not vs.enable_npzd:
        return

    # TODO: Refactor transportation code to be defined only once and also used by thermodynamics
    # TODO: Dissipation on W-grid if necessary

    npzd_changes = biogeochemistry(vs)

    # Tridiagonal coefficients for implicit vertical mixing. They depend only
    # on the grid, dt_tracer and kappaH, so they are shared by all tracers.
    a_tri = allocate(vs, ('xt', 'yt', 'zt'), include_ghosts=False)
    b_tri = allocate(vs, ('xt', 'yt', 'zt'), include_ghosts=False)
    c_tri = allocate(vs, ('xt', 'yt', 'zt'), include_ghosts=False)
    d_tri = allocate(vs, ('xt', 'yt', 'zt'), include_ghosts=False)
    delta = allocate(vs, ('xt', 'yt', 'zt'), include_ghosts=False)

    ks = vs.kbot[2:-2, 2:-2] - 1
    dzt = vs.dzt[np.newaxis, np.newaxis, :]
    delta[:, :, :-1] = vs.dt_tracer / vs.dzw[np.newaxis, np.newaxis, :-1] \
        * vs.kappaH[2:-2, 2:-2, :-1]
    delta[:, :, -1] = 0
    a_tri[:, :, 1:] = -delta[:, :, :-1] / dzt[:, :, 1:]
    b_tri[:, :, 1:] = 1 + (delta[:, :, 1:] + delta[:, :, :-1]) / dzt[:, :, 1:]
    b_tri_edge = 1 + delta / dzt
    c_tri[:, :, :-1] = -delta[:, :, :-1] / dzt[:, :, :-1]

    for tracer in vs.npzd_transported_tracers:
        tracer_data = vs.npzd_tracers[tracer]
        advection = vs.npzd_advection_derivatives[tracer]

        # Advection of tracers: tendency at tau ...
        thermodynamics.advect_tracer(vs, tracer_data[:, :, :, vs.tau],
                                     advection[:, :, :, vs.tau])

        # ... then Adams-Bashforth timestepping into taup1
        tracer_data[:, :, :, vs.taup1] = tracer_data[:, :, :, vs.tau] + vs.dt_tracer \
            * ((1.5 + vs.AB_eps) * advection[:, :, :, vs.tau]
               - (0.5 + vs.AB_eps) * advection[:, :, :, vs.taum1]) \
            * vs.maskT

        # Diffusion of tracers
        if vs.enable_hor_diffusion:
            horizontal_diffusion_change = np.zeros_like(tracer_data[:, :, :, 0])
            diffusion.horizontal_diffusion(vs, tracer_data[:, :, :, vs.tau],
                                           horizontal_diffusion_change)

            tracer_data[:, :, :, vs.taup1] += vs.dt_tracer * horizontal_diffusion_change

        if vs.enable_biharmonic_mixing:
            # zeros_like, not empty_like: any cell the routine does not write
            # must add nothing to the tracer instead of uninitialised memory
            biharmonic_diffusion_change = np.zeros_like(tracer_data[:, :, :, 0])
            diffusion.biharmonic(vs, tracer_data[:, :, :, vs.tau],
                                 np.sqrt(abs(vs.K_hbi)),
                                 biharmonic_diffusion_change)

            tracer_data[:, :, :, vs.taup1] += vs.dt_tracer * biharmonic_diffusion_change

        # Restoring zones
        # TODO add restoring zones to general tracers

        # Isopycnal diffusion
        # NOTE(review): dtracer_iso / dtracer_skew are not used after the
        # calls; presumably isoneutral_diffusion_tracer updates tracer_data in
        # place and fills these buffers for diagnostics only — confirm.
        if vs.enable_neutral_diffusion:
            dtracer_iso = np.zeros_like(tracer_data[..., 0])

            isoneutral.isoneutral_diffusion_tracer(vs,
                                                   tracer_data,
                                                   dtracer_iso,
                                                   iso=True,
                                                   skew=False)

            if vs.enable_skew_diffusion:
                dtracer_skew = np.zeros_like(tracer_data[..., 0])
                isoneutral.isoneutral_diffusion_tracer(vs,
                                                       tracer_data,
                                                       dtracer_skew,
                                                       iso=False,
                                                       skew=True)

        # Vertical mixing of tracers (implicit, on the interior columns)
        d_tri[:, :, :] = tracer_data[2:-2, 2:-2, :, vs.taup1]
        # TODO: surface flux?
        # d_tri[:, :, -1] += surface_forcing
        sol, mask = utilities.solve_implicit(vs,
                                             ks,
                                             a_tri,
                                             b_tri,
                                             c_tri,
                                             d_tri,
                                             b_edge=b_tri_edge)

        tracer_data[2:-2, 2:-2, :, vs.taup1] = utilities.where(
            vs, mask, sol, tracer_data[2:-2, 2:-2, :, vs.taup1])

    # update by biogeochemical changes
    for tracer, change in npzd_changes.items():
        vs.npzd_tracers[tracer][:, :, :, vs.taup1] += change

    # prepare next timestep with minimum tracer values
    for tracer in vs.npzd_tracers.values():
        tracer[:, :, :, vs.taup1] = np.maximum(tracer[:, :, :, vs.taup1],
                                               vs.trcmin * vs.maskT)

    for tracer in vs.npzd_tracers.values():
        utilities.enforce_boundaries(vs, tracer)
Beispiel #4
0
    def run(self, show_progress_bar=None):
        """Main routine of the simulation.

        Integrates from the current ``vs.time`` until ``vs.runlen`` has
        elapsed, advancing by ``vs.dt_tracer`` per iteration. Each iteration:

        * sets the forcing;
        * computes mixing coefficients;
        * integrates momentum, thermodynamics and the enabled EKE/IDEMIX/TKE
          closures;
        * runs the plugins;
        * performs diagnostics;
        * rotates the ``taum1``/``tau``/``taup1`` time-level indices.

        A restart file is always written and a timing summary logged when
        integration ends, whether it succeeded or not.

        Note:
            Make sure to call :meth:`setup` prior to this function.

        Arguments:
            show_progress_bar (:obj:`bool`, optional): Whether to show fancy progress bar via tqdm.
                By default, only show if stdout is a terminal and Veros is running on a single process.

        Raises:
            RuntimeError: if the diagnostic sanity check fails (solution diverged).
            Any other exception raised during integration is logged and re-raised.
        """
        vs = self.state

        logger.info('\nStarting integration for {0[0]:.1f} {0[1]}'.format(time.format_time(vs.runlen)))

        start_time, start_iteration = vs.time, vs.itt
        profiler = None

        pbar = progress.get_progress_bar(vs, use_tqdm=show_progress_bar)

        with handlers.signals_to_exception():
            try:
                with pbar:
                    while vs.time - start_time < vs.runlen:
                        with vs.timers['diagnostics']:
                            diagnostics.write_restart(vs)

                        if vs.itt - start_iteration == 3 and rs.profile_mode and rst.proc_rank == 0:
                            # when using bohrium, most kernels should be pre-compiled by now
                            profiler = diagnostics.start_profiler()

                        with vs.timers['main']:
                            self.set_forcing(vs)

                            if vs.enable_idemix:
                                idemix.set_idemix_parameter(vs)

                            with vs.timers['eke']:
                                eke.set_eke_diffusivities(vs)

                            with vs.timers['tke']:
                                tke.set_tke_diffusivities(vs)

                            with vs.timers['momentum']:
                                momentum.momentum(vs)

                            with vs.timers['temperature']:
                                thermodynamics.thermodynamics(vs)

                            if vs.enable_eke or vs.enable_tke or vs.enable_idemix:
                                advection.calculate_velocity_on_wgrid(vs)

                            with vs.timers['eke']:
                                if vs.enable_eke:
                                    eke.integrate_eke(vs)

                            with vs.timers['idemix']:
                                if vs.enable_idemix:
                                    idemix.integrate_idemix(vs)

                            with vs.timers['tke']:
                                if vs.enable_tke:
                                    tke.integrate_tke(vs)

                            utilities.enforce_boundaries(vs, vs.u[:, :, :, vs.taup1])
                            utilities.enforce_boundaries(vs, vs.v[:, :, :, vs.taup1])
                            if vs.enable_tke:
                                utilities.enforce_boundaries(vs, vs.tke[:, :, :, vs.taup1])
                            if vs.enable_eke:
                                utilities.enforce_boundaries(vs, vs.eke[:, :, :, vs.taup1])
                            if vs.enable_idemix:
                                utilities.enforce_boundaries(vs, vs.E_iw[:, :, :, vs.taup1])

                            momentum.vertical_velocity(vs)

                        with vs.timers['plugins']:
                            for plugin in self._plugin_interfaces:
                                with vs.timers[plugin.name]:
                                    plugin.run_entrypoint(vs)

                        vs.itt += 1
                        vs.time += vs.dt_tracer
                        pbar.advance_time(vs.dt_tracer)

                        self.after_timestep(vs)

                        with vs.timers['diagnostics']:
                            if not diagnostics.sanity_check(vs):
                                raise RuntimeError('solution diverged at iteration {}'.format(vs.itt))

                            if vs.enable_neutral_diffusion and vs.enable_skew_diffusion:
                                isoneutral.isoneutral_diag_streamfunction(vs)

                            diagnostics.diagnose(vs)
                            diagnostics.output(vs)

                        # NOTE: benchmarks parse this, do not change / remove
                        logger.debug(' Time step took {:.2f}s', vs.timers['main'].get_last_time())

                        # permutate time indices
                        vs.taum1, vs.tau, vs.taup1 = vs.tau, vs.taup1, vs.taum1

            # BaseException on purpose: also log interruptions (e.g. signals
            # turned into exceptions by signals_to_exception) before re-raising
            except BaseException:
                logger.critical('Stopping integration at iteration {}', vs.itt)
                raise

            else:
                logger.success('Integration done\n')

            finally:
                diagnostics.write_restart(vs, force=True)

                timing_summary = [
                    '',
                    'Timing summary:',
                    ' setup time               = {:.2f}s'.format(vs.timers['setup'].get_time()),
                    ' main loop time           = {:.2f}s'.format(vs.timers['main'].get_time()),
                    '   momentum               = {:.2f}s'.format(vs.timers['momentum'].get_time()),
                    '     pressure             = {:.2f}s'.format(vs.timers['pressure'].get_time()),
                    '     friction             = {:.2f}s'.format(vs.timers['friction'].get_time()),
                    '   thermodynamics         = {:.2f}s'.format(vs.timers['temperature'].get_time()),
                    '     lateral mixing       = {:.2f}s'.format(vs.timers['isoneutral'].get_time()),
                    '     vertical mixing      = {:.2f}s'.format(vs.timers['vmix'].get_time()),
                    '     equation of state    = {:.2f}s'.format(vs.timers['eq_of_state'].get_time()),
                    '   EKE                    = {:.2f}s'.format(vs.timers['eke'].get_time()),
                    '   IDEMIX                 = {:.2f}s'.format(vs.timers['idemix'].get_time()),
                    '   TKE                    = {:.2f}s'.format(vs.timers['tke'].get_time()),
                    ' diagnostics and I/O      = {:.2f}s'.format(vs.timers['diagnostics'].get_time()),
                    ' plugins                  = {:.2f}s'.format(vs.timers['plugins'].get_time()),
                ]

                # report exactly the plugins the main loop ran and timed
                timing_summary.extend([
                    '   {:<22} = {:.2f}s'.format(plugin.name, vs.timers[plugin.name].get_time())
                    for plugin in self._plugin_interfaces
                ])

                logger.debug('\n'.join(timing_summary))

                if profiler is not None:
                    diagnostics.stop_profiler(profiler)
Beispiel #5
0
    def set_initial_conditions(self, vs):
        """Initialise tracers and surface forcing from the forcing data set.

        Every field is interpolated from the forcing grid onto the model grid
        with ``missing_value=0.``. Temperature and salinity fill both time
        levels 0 and 1. Surface fields are masked with the surface ocean mask;
        ``q_net`` and ``swf`` are stored negated.

        Zonal wind stress is scaled inside the latitude band
        ``self.so_wind_interval`` by ``1 + (so_wind_factor - 1) * sin(...)``.
        This factor equals ``self.so_wind_factor`` at the band centre and 1 at
        its edges.

        ``taux``, ``tauy``, ``qnet``, ``qnec``, ``qsol``, ``t_star``,
        ``s_star``, ``salt`` and ``temp`` are then post-processed with
        ``self._fix_north_atlantic``. Finally, the shortwave penetration
        divergence ``vs.divpen_shortwave`` is computed.
        """
        shortwave_fraction = 0.58
        efold_short = 0.35
        efold_long = 23.0

        depth_grid = (vs.xt[2:-2], vs.yt[2:-2], vs.zt)
        forc_x = self._get_data(vs, 'xt')
        forc_y = self._get_data(vs, 'yt')
        forc_z = self._get_data(vs, 'zt')[::-1]

        def to_model_grid(source_grid, values, target_grid):
            # all fields use the same fill value outside the forcing data
            return veros.tools.interpolate(source_grid, values, target_grid,
                                           missing_value=0.)

        # initial conditions (depth axis flipped to match forc_z)
        for field, name in ((vs.temp, 'temperature'), (vs.salt, 'salinity')):
            interior = to_model_grid((forc_x, forc_y, forc_z),
                                     self._get_data(vs, name)[:, :, ::-1],
                                     depth_grid)
            masked = interior * vs.maskT[2:-2, 2:-2, :]
            field[2:-2, 2:-2, :, 0] = masked
            field[2:-2, 2:-2, :, 1] = masked

        months = np.arange(12)
        monthly_source = (forc_x, forc_y, months)
        monthly_target = (vs.xt[2:-2], vs.yt[2:-2], months)

        def monthly(name):
            return to_model_grid(monthly_source, self._get_data(vs, name),
                                 monthly_target)

        # wind stress on MIT grid, enhanced inside the Southern Ocean band
        vs.taux[2:-2, 2:-2, :] = monthly('tau_x')
        band_south = self.so_wind_interval[0]
        band_north = self.so_wind_interval[1]
        in_band = np.logical_and(vs.yt > band_south, vs.yt < band_north)[..., np.newaxis]
        phase = np.pi * (vs.yt[np.newaxis, :, np.newaxis] - band_south) / (band_north - band_south)
        vs.taux *= 1. + in_band * (self.so_wind_factor - 1.) * np.sin(phase)

        vs.tauy[2:-2, 2:-2, :] = monthly('tau_y')

        enforce_boundaries(vs, vs.taux)
        enforce_boundaries(vs, vs.tauy)

        surface = vs.maskT[2:-2, 2:-2, -1, np.newaxis]

        # Qnet and dQ/dT and Qsol
        vs.qnet[2:-2, 2:-2, :] = -monthly('q_net') * surface
        vs.qnec[2:-2, 2:-2, :] = monthly('dqdt') * surface
        vs.qsol[2:-2, 2:-2, :] = -monthly('swf') * surface

        # SST and SSS
        vs.t_star[2:-2, 2:-2, :] = monthly('sst') * surface
        vs.s_star[2:-2, 2:-2, :] = monthly('sss') * surface

        if vs.enable_idemix:
            tidal = to_model_grid((forc_x, forc_y),
                                  self._get_data(vs, 'tidal_energy'),
                                  depth_grid[:-1])
            idx_x, idx_y = (i + 2 for i in np.indices((vs.nx, vs.ny)))
            # bottom W-level of each column (kbot - 1, floored at 0)
            idx_z = np.maximum(0, vs.kbot[2:-2, 2:-2] - 1)
            tidal[:, :] *= vs.maskW[idx_x, idx_y, idx_z] / vs.rho_0
            vs.forc_iw_bottom[2:-2, 2:-2] = tidal

        # average variables in North Atlantic
        for arr in (vs.taux, vs.tauy, vs.qnet, vs.qnec, vs.qsol,
                    vs.t_star, vs.s_star, vs.salt, vs.temp):
            arr[2:-2, 2:-2, ...] = self._fix_north_atlantic(vs, arr[2:-2, 2:-2, ...])

        # Penetration profile for solar radiation; its divergence is stored in
        # divpen_shortwave. pen is 0.0 at the surface instead of 1.0 to
        # compensate for the shortwave part of the total surface flux.
        pen = (shortwave_fraction * np.exp(vs.zw / efold_short)
               + (1.0 - shortwave_fraction) * np.exp(vs.zw / efold_long))
        pen[-1] = 0.
        vs.divpen_shortwave[1:] = (pen[1:] - pen[:-1]) / vs.dzt[1:]
        vs.divpen_shortwave[0] = pen[0] / vs.dzt[0]
def carbon_flux(vs):
    """Calculates flux of CO2 over the ocean-atmosphere boundary

    This is an adaptation of co2_calc_SWS from UVic ESCM

    Note
    ----
    This was written without an atmosphere component in veros.
    Therefore an atmospheric pressure of 1 atm is assumed.
    The concentration of CO2 (in units of ppmv) may be set in vs.atmospheric_co2

    Note
    ----
    This was written without an explicit sea ice component. Therefore a full
    ice cover is assumed when temperature is below -1.8C and temperature forcing is negative

    Side effects
    ------------
    Sets ``vs.wind_speed``, ``vs.dco2star`` and ``vs.cflux``.

    Returns
    -------
    ``vs.cflux``: array over the horizontal extent of ``vs.temp`` (including
    ghost cells) with flux in units of :math:`mmol / m^2 / s`.
    Positive indicates a flux into the ocean
    """
    sst = vs.temp[:, :, -1, vs.tau]

    # sea-ice proxy: cold surface water that is additionally being cooled
    ice_covered = np.logical_and(sst * vs.maskT[:, :, -1] < -1.8,
                                 vs.forc_temp_surface < 0.0)
    open_water = np.logical_not(ice_covered)

    # atm; no atmosphere component yet, hence constant pressure
    atmospheric_pressure = 1

    # TODO get actual wind speed rather than deriving from wind stress
    stress_x = np.abs(vs.surface_taux / vs.rho_0)
    stress_y = np.abs(vs.surface_tauy / vs.rho_0)
    wind_speed = np.sqrt(stress_x + stress_y) * 500
    vs.wind_speed = wind_speed

    # xconv converts piston_vel from cm/hr; UVic uses 100.*a*xconv
    # (100 => m to cm, a=0.337, xconv=1/3.6e+05). The extra 0.75 is an
    # approximation of unknown origin.
    xconv = 0.337 / 3.6e5 * 0.75

    vs.dco2star = co2calc_SWS(
        vs,
        sst,                                    # [degree C]
        vs.salt[:, :, -1, vs.tau],              # [g/kg]
        vs.dic[:, :, -1, vs.tau] * 1e-3,        # [mmol -> mol]
        vs.alkalinity[:, :, -1, vs.tau] * 1e-3,  # [mmol -> mol]
        vs.atmospheric_co2,                     # [ppmv]
        atmospheric_pressure,                   # atm
    )

    # Schmidt number for CO2 (Wanninkhof, 1992, table A)
    schmidt = (2073.1 - 125.62 * sst + 3.6276 * sst ** 2
               - 0.043219 * sst ** 3)

    # NOTE: https://www.soest.hawaii.edu/oceanography/courses/OCN623/Spring%202015/Gas_Exchange_2015_one-lecture1.pdf
    # distinguishes 3 regimes; here the wavy-surface regime is used.
    # NOTE: https://aslopubs.onlinelibrary.wiley.com/doi/pdf/10.4319/lom.2014.12.351
    # uses the correct form for the Schmidt number (Sweeney et al. 2007)
    piston_vel = open_water * xconv * wind_speed ** 2 * (schmidt / 660.0) ** (-0.5)

    # factor 1e3 converts to mmol / m^2 / s
    co2_flux = piston_vel * vs.dco2star * vs.maskT[:, :, -1] * 1e3

    utilities.enforce_boundaries(vs, co2_flux)

    vs.cflux[...] = co2_flux
    return vs.cflux