Code example #1
0
def load(name, params):
    """Load merged snapshot data for run *name* into numpy arrays.

    Merges any distributed snapshot files, then reads every write of the
    dynamical variables ('uz', 'ux', 'rho', 'P') on an interpolated grid.
    Returns (sim_times, domain, state_vars), where state_vars maps each
    variable name to an array indexed by time, plus derived quantities:
    coefficient magnitudes ('*_c'), background-subtracted fields
    ('rho1', 'P1'), kinetic energy density ('E') and vertical flux ('F_z').
    """
    dyn_vars = ['uz', 'ux', 'rho', 'P']
    snapshots_dir = SNAPSHOTS_DIR % name
    filename = '{s}/{s}_s1.h5'.format(s=snapshots_dir)
    start_idx = 0

    post.merge_analysis(snapshots_dir, cleanup=False)

    solver, domain = get_solver(params)
    z = domain.grid(1, scales=params['INTERP_Z'])

    # Read the time axis; the file must be closed before solver.load_state
    # reopens it below.
    with h5py.File(filename, mode='r') as dat:
        sim_times = np.array(dat['scales']['sim_time'])[start_idx:]

    # Accumulate one grid-space array (and one coefficient-magnitude array)
    # per variable per write.
    state_vars = defaultdict(list)
    for time_idx, _ in enumerate(sim_times):
        solver.load_state(filename, time_idx + start_idx)
        for varname in dyn_vars:
            field = solver.state[varname]
            field.set_scales((params['INTERP_X'], params['INTERP_Z']),
                             keep_data=True)
            state_vars[varname].append(np.copy(field['g']))
            state_vars['%s_c' % varname].append(np.copy(np.abs(field['c'])))

    # Stack the per-time lists into single numpy arrays.
    for key in state_vars:
        state_vars[key] = np.array(state_vars[key])

    # Hydrostatic background profiles (isothermal atmosphere).
    rho0_profile = params['RHO0'] * np.exp(-z / params['H'])
    P0_profile = (params['RHO0'] * (np.exp(-z / params['H']) - 1)
                  * params['g'] * params['H'])

    # Snapshots store perturbations; add the background back in, and keep
    # the perturbation-only fields under new names.
    state_vars['rho'] += rho0_profile
    state_vars['P'] += P0_profile
    state_vars['rho1'] = state_vars['rho'] - rho0_profile
    state_vars['P1'] = state_vars['P'] - P0_profile

    state_vars['E'] = rho0_profile * (
        state_vars['ux']**2 + state_vars['uz']**2) / 2
    state_vars['F_z'] = state_vars['uz'] * (
        state_vars['rho'] * (state_vars['ux']**2 + state_vars['uz']**2)
        + state_vars['P'])
    return sim_times, domain, state_vars
コード例 #2
0
def merge(name):
    """Merge run *name*'s snapshots if any piece lacks its merged file."""
    snapshots_dir = SNAPSHOTS_DIR % name
    dir_expr = '{s}/{s}_s{idx}'

    # Walk consecutive snapshot piece directories (s1, s2, ...) until the
    # sequence ends; flag a merge if any piece has no merged .h5 yet.
    needs_merge = False
    idx = 1
    while True:
        piece_dir = dir_expr.format(s=snapshots_dir, idx=idx)
        if not os.path.exists(piece_dir):
            break
        merged_file = FILENAME_EXPR.format(s=snapshots_dir, idx=idx)
        if not os.path.exists(merged_file):
            needs_merge = True
        idx += 1

    if needs_merge:
        post.merge_analysis(snapshots_dir)
コード例 #3
0
File: hartmann.py  Project: jsoishi/hartmann_flow
# Magnetic energy time series (continues the analysis setup from above).
timeseries.add_task("0.5*integ(Bx**2 + By**2)", name='Emag')
analysis_tasks.append(timeseries)

# Flow properties
flow = flow_tools.GlobalFlowProperty(solver, cadence=1)
flow.add_property("0.5*(vx**2 + vy**2)", name='Ekin')

try:
    logger.info('Starting loop')
    start_run_time = time.time()

    while solver.ok:
        solver.step(dt)
        # NOTE(review): modulo 1 is always 0, so this logs every single
        # iteration; raise the divisor to reduce log volume.
        if (solver.iteration - 1) % 1 == 0:
            logger.info('Iteration: %i, Time: %e, dt: %e' %
                        (solver.iteration, solver.sim_time, dt))
            logger.info('Max E_kin = %17.12e' % flow.max('Ekin'))
except:
    logger.error('Exception raised, triggering end of main loop.')
    raise
finally:
    # Always report timing, even if the loop ended with an exception.
    end_run_time = time.time()
    logger.info('Iterations: %i' % solver.iteration)
    logger.info('Sim end time: %f' % solver.sim_time)
    logger.info('Run time: %f' % (end_run_time - start_run_time))

# Merge the per-process output files of each analysis handler.
logger.info('beginning join operation')
for task in analysis_tasks:
    logger.info(task.base_path)
    post.merge_analysis(task.base_path)
コード例 #4
0
def Rayleigh_Benard(Rayleigh=1e6,
                    Prandtl=1,
                    nz=64,
                    nx=None,
                    ny=None,
                    aspect=4,
                    fixed_flux=False,
                    fixed_T=False,
                    mixed_flux_T=True,
                    stress_free=False,
                    no_slip=True,
                    restart=None,
                    run_time=23.5,
                    run_time_buoyancy=None,
                    run_time_iter=np.inf,
                    run_time_therm=1,
                    max_writes=20,
                    max_slice_writes=20,
                    output_dt=0.2,
                    data_dir='./',
                    coeff_output=True,
                    verbose=False,
                    no_join=False,
                    do_bvp=False,
                    num_bvps=10,
                    bvp_convergence_factor=1e-3,
                    bvp_equil_time=10,
                    bvp_resolution_factor=1,
                    bvp_transient_time=30,
                    bvp_final_equil_time=None,
                    min_bvp_time=50,
                    first_bvp_time=20,
                    first_bvp_convergence_factor=1e-2,
                    threeD=False,
                    seed=42,
                    mesh=None,
                    overwrite=False):
    """Run a Rayleigh-Benard convection initial value problem.

    Builds Boussinesq equations (2D or 3D), applies the selected thermal
    and velocity boundary conditions, sets initial conditions (or restarts
    from a checkpoint), then advances the IVP with a CFL-limited RK443
    timestepper, optionally accelerating the transient with BVP solves.
    Output handlers are written under *data_dir* and merged at the end
    unless *no_join* is set.

    Fixes over the original: both `except` blocks now log/print *before*
    re-raising (those statements were previously unreachable), and
    `start_time` is initialized before the main loop so the `finally`
    block cannot fail with a NameError if the first step raises.
    """
    import os
    from dedalus.tools.config import config

    # Route the Dedalus log to a file inside the run directory.
    config['logging']['filename'] = os.path.join(data_dir, 'logs/dedalus_log')
    config['logging']['file_level'] = 'DEBUG'

    import mpi4py.MPI
    if mpi4py.MPI.COMM_WORLD.rank == 0:
        # Only rank 0 creates directories, to avoid an MPI race.
        if not os.path.exists('{:s}/'.format(data_dir)):
            os.makedirs('{:s}/'.format(data_dir))
        logdir = os.path.join(data_dir, 'logs')
        if not os.path.exists(logdir):
            os.mkdir(logdir)
    logger = logging.getLogger(__name__)
    logger.info("saving run in: {}".format(data_dir))

    import time
    from dedalus import public as de
    from dedalus.extras import flow_tools
    from dedalus.tools import post

    # input parameters
    logger.info("Ra = {}, Pr = {}".format(Rayleigh, Prandtl))

    # Parameters
    Lz = 1.
    Lx = aspect * Lz
    Ly = aspect * Lz
    if nx is None:
        nx = int(nz * aspect)
    if ny is None:
        ny = int(nz * aspect)

    if threeD:
        logger.info("resolution: [{}x{}x{}]".format(nx, ny, nz))
        equations = BoussinesqEquations3D(nx=nx,
                                          ny=ny,
                                          nz=nz,
                                          Lx=Lx,
                                          Ly=Ly,
                                          Lz=Lz,
                                          mesh=mesh)
    else:
        logger.info("resolution: [{}x{}]".format(nx, nz))
        equations = BoussinesqEquations2D(nx=nx, nz=nz, Lx=Lx, Lz=Lz)
    equations.set_IVP(Rayleigh, Prandtl)

    # Boundary-condition flags; the first matching thermal option and the
    # first matching velocity option win.
    bc_dict = {
        'fixed_flux': None,
        'fixed_temperature': None,
        'mixed_flux_temperature': None,
        'mixed_temperature_flux': None,
        'stress_free': None,
        'no_slip': None
    }
    if mixed_flux_T:
        bc_dict['mixed_flux_temperature'] = True
    elif fixed_T:
        bc_dict['fixed_temperature'] = True
    elif fixed_flux:
        bc_dict['fixed_flux'] = True

    if stress_free:
        bc_dict['stress_free'] = True
    elif no_slip:
        bc_dict['no_slip'] = True

    # RA_CRIT is a module-level constant (critical Rayleigh number).
    supercrit = Rayleigh / RA_CRIT

    equations.set_BC(**bc_dict)

    # Build solver
    ts = de.timesteppers.RK443
    cfl_safety = 0.8

    solver = equations.problem.build_solver(ts)
    logger.info('Solver built')

    checkpoint = Checkpoint(data_dir)
    if isinstance(restart, type(None)):
        equations.set_IC(solver, seed=seed)
        dt = None
        mode = 'overwrite'
    else:
        logger.info("restarting from {}".format(restart))
        checkpoint.restart(restart, solver)
        if overwrite:
            mode = 'overwrite'
        else:
            mode = 'append'
    checkpoint.set_checkpoint(solver, wall_dt=checkpoint_min * 60, mode=mode)

    # Integration parameters
    #    if not isinstance(run_time_therm, type(None)):
    #        solver.stop_sim_time = run_time_therm*equations.thermal_time + solver.sim_time
    #    elif not isinstance(run_time_buoyancy, type(None)):
    #        solver.stop_sim_time  = run_time_buoyancy + solver.sim_time
    #    else:
    # The sim-time stop is set later, once Re_avg > 1 (see the main loop).
    solver.stop_sim_time = np.inf
    solver.stop_wall_time = run_time * 3600.
    solver.stop_iteration = run_time_iter
    Hermitian_cadence = 100

    # Analysis
    max_dt = output_dt
    analysis_tasks = equations.initialize_output(solver,
                                                 data_dir,
                                                 coeff_output=coeff_output,
                                                 output_dt=output_dt,
                                                 mode=mode)

    # CFL
    CFL = flow_tools.CFL(solver,
                         initial_dt=0.1,
                         cadence=1,
                         safety=cfl_safety,
                         max_change=1.5,
                         min_change=0.5,
                         max_dt=max_dt,
                         threshold=0.1)
    if threeD:
        CFL.add_velocities(('u', 'v', 'w'))
    else:
        CFL.add_velocities(('u', 'w'))

    # Flow properties
    flow = flow_tools.GlobalFlowProperty(solver, cadence=1)
    flow.add_property("Re", name='Re')
    #    flow.add_property("interp(w, z=0.95)", name='w near top')

    #    u, v, w = solver.state['u'], solver.state['v'], solver.state['w']

    if do_bvp:
        if not threeD:
            ny = 0
            atmo_class = BoussinesqEquations2D
        else:
            atmo_class = BoussinesqEquations3D
        bvp_solver = BoussinesqBVPSolver(atmo_class, nx, ny, nz, \
                                   flow, equations.domain.dist.comm_cart, \
                                   solver, num_bvps, bvp_equil_time, \
                                   threeD=threeD,
                                   bvp_transient_time=bvp_transient_time, \
                                   bvp_run_threshold=bvp_convergence_factor, \
                                   bvp_l2_check_time=1, mesh=mesh,\
                                   first_bvp_time=first_bvp_time,
                                   first_run_threshold=first_bvp_convergence_factor,\
                                   plot_dir='{}/bvp_plots/'.format(data_dir),\
                                   min_avg_dt=1e-10, final_equil_time=bvp_final_equil_time,
                                   min_bvp_time=min_bvp_time)
        bc_dict.pop('stress_free')
        bc_dict.pop('no_slip')

#   print(equations.domain.grid(0), equations.domain.grid(1), equations.domain.grid(2))

    first_step = True
    # Initialize before the loop so the finally block cannot hit a
    # NameError if the solver raises before the first step completes.
    start_time = time.time()
    # Main loop
    try:
        logger.info('Starting loop')
        Re_avg = 0
        continue_bvps = True
        not_corrected_times = True
        init_time = solver.sim_time
        start_iter = solver.iteration
        while (solver.ok and np.isfinite(Re_avg)) and continue_bvps:
            dt = CFL.compute_dt()
            solver.step(dt)  #, trim=True)

            # Solve for blow-up over long timescales in 3D due to hermitian-ness
            effective_iter = solver.iteration - start_iter
            if threeD and effective_iter % Hermitian_cadence == 0:
                for field in solver.state.fields:
                    field.require_grid_space()

            Re_avg = flow.grid_average('Re')
            log_string = 'Iteration: {:5d}, '.format(solver.iteration)
            log_string += 'Time: {:8.3e} ({:8.3e} therm), dt: {:8.3e}, '.format(
                solver.sim_time, solver.sim_time / equations.thermal_time, dt)
            log_string += 'Re: {:8.3e}/{:8.3e}'.format(Re_avg, flow.max('Re'))
            logger.info(log_string)

            # Once convection sets in (Re_avg > 1), fix the sim-time stop.
            if not_corrected_times and Re_avg > 1:
                if not isinstance(run_time_therm, type(None)):
                    solver.stop_sim_time = run_time_therm * equations.thermal_time + solver.sim_time
                elif not isinstance(run_time_buoyancy, type(None)):
                    solver.stop_sim_time = run_time_buoyancy + solver.sim_time
                not_corrected_times = False

            if do_bvp:
                bvp_solver.update_avgs(dt, Re_avg, min_Re=np.sqrt(supercrit))
                if bvp_solver.check_if_solve():
                    atmo_kwargs = {'nz': nz * bvp_resolution_factor, 'Lz': Lz}
                    diff_args = [Rayleigh, Prandtl]
                    bvp_solver.solve_BVP(atmo_kwargs, diff_args, bc_dict)
                if bvp_solver.terminate_IVP():
                    continue_bvps = False

            if first_step:
                if verbose:
                    # One-off diagnostics: sparsity patterns of the LHS
                    # matrix and of its LU factors.
                    import matplotlib
                    matplotlib.use('Agg')
                    import matplotlib.pyplot as plt
                    fig = plt.figure()
                    ax = fig.add_subplot(1, 1, 1)
                    ax.spy(solver.pencils[0].L,
                           markersize=1,
                           markeredgewidth=0.0)
                    fig.savefig(data_dir + "sparsity_pattern.png", dpi=1200)

                    import scipy.sparse.linalg as sla
                    LU = sla.splu(solver.pencils[0].LHS.tocsc(),
                                  permc_spec='NATURAL')
                    fig = plt.figure()
                    ax = fig.add_subplot(1, 2, 1)
                    ax.spy(LU.L.A, markersize=1, markeredgewidth=0.0)
                    ax = fig.add_subplot(1, 2, 2)
                    ax.spy(LU.U.A, markersize=1, markeredgewidth=0.0)
                    fig.savefig(data_dir + "sparsity_pattern_LU.png", dpi=1200)

                    logger.info("{} nonzero entries in LU".format(LU.nnz))
                    logger.info("{} nonzero entries in LHS".format(
                        solver.pencils[0].LHS.tocsc().nnz))
                    logger.info("{} fill in factor".format(
                        LU.nnz / solver.pencils[0].LHS.tocsc().nnz))
                first_step = False
                start_time = time.time()
    except:
        # Log before re-raising; the original raised first, leaving the
        # log statement unreachable.
        logger.error('Exception raised, triggering end of main loop.')
        raise
    finally:
        end_time = time.time()
        main_loop_time = end_time - start_time
        n_iter_loop = solver.iteration - 1
        logger.info('Iterations: {:d}'.format(n_iter_loop))
        logger.info('Sim end time: {:f}'.format(solver.sim_time))
        logger.info('Run time: {:f} sec'.format(main_loop_time))
        logger.info('Run time: {:f} cpu-hr'.format(
            main_loop_time / 60 / 60 * equations.domain.dist.comm_cart.size))
        logger.info('iter/sec: {:f} (main loop only)'.format(n_iter_loop /
                                                             main_loop_time))
        try:
            final_checkpoint = Checkpoint(data_dir,
                                          checkpoint_name='final_checkpoint')
            final_checkpoint.set_checkpoint(solver, wall_dt=1, mode=mode)
            solver.step(dt)  #clean this up in the future...works for now.
            post.merge_process_files(data_dir + '/final_checkpoint/',
                                     cleanup=False)
        except:
            # Report before re-raising; the original raised first, so the
            # message could never be printed.
            print('cannot save final checkpoint')
            raise
        finally:
            if not no_join:
                logger.info('beginning join operation')
                post.merge_analysis(data_dir + 'checkpoints')

                for task in analysis_tasks:
                    logger.info(task.base_path)
                    post.merge_analysis(task.base_path)

            logger.info(40 * "=")
            logger.info('Iterations: {:d}'.format(n_iter_loop))
            logger.info('Sim end time: {:f}'.format(solver.sim_time))
            logger.info('Run time: {:f} sec'.format(main_loop_time))
            logger.info('Run time: {:f} cpu-hr'.format(
                main_loop_time / 60 / 60 *
                equations.domain.dist.comm_cart.size))
            logger.info('iter/sec: {:f} (main loop only)'.format(
                n_iter_loop / main_loop_time))
                                               domain.dist.comm_cart.size))
    logger.info('iter/sec: {:f} (main loop only)'.format(n_iter_loop /
                                                         main_loop_time))
    try:
        final_checkpoint = Checkpoint(data_dir,
                                      checkpoint_name='final_checkpoint')
        final_checkpoint.set_checkpoint(solver, wall_dt=1, mode=mode)
        solver.step(dt)  #clean this up in the future...works for now.
        post.merge_process_files(data_dir + '/final_checkpoint/',
                                 cleanup=False)
    except:
        raise
        print('cannot save final checkpoint')
    finally:
        if not args['--no_join']:
            logger.info('beginning join operation')
            post.merge_analysis(data_dir + 'checkpoint')

            for key, task in analysis_tasks.items():
                logger.info(task.base_path)
                post.merge_analysis(task.base_path)

        logger.info(40 * "=")
        logger.info('Iterations: {:d}'.format(n_iter_loop))
        logger.info('Sim end time: {:f}'.format(solver.sim_time))
        logger.info('Run time: {:f} sec'.format(main_loop_time))
        logger.info('Run time: {:f} cpu-hr'.format(main_loop_time / 60 / 60 *
                                                   domain.dist.comm_cart.size))
        logger.info('iter/sec: {:f} (main loop only)'.format(n_iter_loop /
                                                             main_loop_time))
コード例 #6
0
import os
import sys
import logging

logger = logging.getLogger(__name__)

import dedalus.public
from dedalus.tools import post

# Directory of the Dedalus run to join, given on the command line.
data_dir = sys.argv[1]
base_path = os.path.abspath(data_dir) + '/'

logger.info("joining data from Dedalus run {:s}".format(data_dir))

# Every FileHandler output type a run may have produced.
data_types = [
    'checkpoint', 'scalar', 'profiles', 'slices', 'slices_small', 'slices_T',
    'coeffs'
]

for data_type in data_types:
    logger.info("merging {}".format(data_type))
    try:
        post.merge_analysis(base_path + data_type)
    except Exception:
        # Best effort: a run need not have produced every data type.
        # (Narrowed from a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit.)
        logger.info("missing {}".format(data_type))

logger.info("done join operation for {:s}".format(data_dir))
コード例 #7
0
"""
Merge distributed analysis sets from a FileHandler.

Usage:
    merge.py <base_path> [--cleanup]

Options:
    --cleanup   Delete distributed files after merging

"""

if __name__ == "__main__":

    from docopt import docopt
    # Importing dedalus.tools.logging configures Dedalus' log handlers.
    from dedalus.tools import logging
    from dedalus.tools import post

    # docopt parses the module docstring above as the CLI specification,
    # so that text must not be edited casually.
    args = docopt(__doc__)
    post.merge_analysis(args['<base_path>'], cleanup=args['--cleanup'])

コード例 #8
0
    --cleanup   Delete distributed files after merging

"""

import h5py
import subprocess
from dedalus.tools import post
import pathlib
import sys
import os
import argparse
import glob
from docopt import docopt

# Parse <base_path> from the module docstring via docopt.
# (Removed a duplicated `import h5py` line from the original.)
args = docopt(__doc__)
base = args['<base_path>']
iter_paths = os.listdir(args['<base_path>'])

# First pass: merge each analysis folder's distributed process files.
for path in iter_paths:
    print(base + "/" + path)
    folder = base + "/" + path
    post.merge_analysis(folder)

# Second pass: merge each folder's set files into one .h5 per folder.
for path in iter_paths:
    print(base + "/" + path)
    set_paths = list(glob.glob(base + "/" + path + "/*h5"))
    print(set_paths)
    post.merge_sets(base + "/" + path + ".h5", set_paths, cleanup=False)
    #~ print(file_name)
コード例 #9
0
        imagestack.timestring.set_text(tstr)
        imagestack.write(output_path, output_name, i_fig)
        imagestack.close()


if __name__ == "__main__":

    import pathlib
    from docopt import docopt
    # Importing dedalus.tools.logging configures Dedalus' log handlers.
    from dedalus.tools import logging
    from dedalus.tools import post
    from dedalus.tools.parallel import Sync

    args = docopt(__doc__)
    if args['join']:
        # Join mode: only merge distributed analysis files, no plotting.
        post.merge_analysis(args['<base_path>'])
    else:
        if args['--output'] is not None:
            output_path = pathlib.Path(args['--output']).absolute()
        else:
            # Default output dir: top-level folder of the first input file.
            data_dir = args['<files>'][0].split('/')[0]
            data_dir += '/'
            output_path = pathlib.Path(data_dir).absolute()
        # Create output directory if needed
        with Sync() as sync:
            # Only MPI rank 0 creates the directory, to avoid a race.
            if sync.comm.rank == 0:
                if not output_path.exists():
                    output_path.mkdir()
        fields = args['--fields'].split(',')
        logger.info("output to {}".format(output_path))
コード例 #10
0
def Rayleigh_Benard(Rayleigh=1e6,
                    Prandtl=1,
                    nz=64,
                    nx=None,
                    aspect=4,
                    fixed_flux=False,
                    fixed_T=True,
                    viscous_heating=False,
                    restart=None,
                    run_time=23.5,
                    run_time_buoyancy=50,
                    run_time_iter=np.inf,
                    max_writes=10,
                    max_slice_writes=10,
                    data_dir='./',
                    coeff_output=True,
                    verbose=False,
                    no_join=False):
    """Run a 2D Boussinesq Rayleigh-Benard convection IVP.

    Builds the vorticity-form 2D Boussinesq problem on a Fourier x
    Chebyshev domain, sets boundary and initial conditions (or restarts),
    advances it with a CFL-limited RK443 timestepper, and writes slice,
    coefficient, profile, and scalar outputs under *data_dir*, merging
    them at the end unless *no_join* is set.

    Fix over the original: `start_time` is initialized before the main
    loop so the `finally` block cannot fail with a NameError if the
    solver raises before the first successful step.
    """
    import os
    from dedalus.tools.config import config

    # Route the Dedalus log to a file inside the run directory.
    config['logging']['filename'] = os.path.join(data_dir, 'logs/dedalus_log')
    config['logging']['file_level'] = 'DEBUG'

    import mpi4py.MPI
    if mpi4py.MPI.COMM_WORLD.rank == 0:
        # Only rank 0 creates directories, to avoid an MPI race.
        if not os.path.exists('{:s}/'.format(data_dir)):
            os.makedirs('{:s}/'.format(data_dir))
        logdir = os.path.join(data_dir, 'logs')
        if not os.path.exists(logdir):
            os.mkdir(logdir)
    logger = logging.getLogger(__name__)
    logger.info("saving run in: {}".format(data_dir))

    import time
    from dedalus import public as de
    from dedalus.extras import flow_tools
    from dedalus.tools import post

    # input parameters
    logger.info("Ra = {}, Pr = {}".format(Rayleigh, Prandtl))

    # Parameters
    Lz = 1.
    Lx = aspect * Lz

    if nx is None:
        nx = int(nz * aspect)

    logger.info("resolution: [{}x{}]".format(nx, nz))
    # Create bases and domain
    x_basis = de.Fourier('x', nx, interval=(0, Lx), dealias=3 / 2)
    z_basis = de.Chebyshev('z', nz, interval=(0, Lz), dealias=3 / 2)
    domain = de.Domain([x_basis, z_basis], grid_dtype=np.float64)

    # Which variable carries the Dirichlet thermal BC.
    if fixed_flux:
        T_bc_var = 'Tz'
    elif fixed_T:
        T_bc_var = 'T'

    # 2D Boussinesq hydrodynamics
    problem = de.IVP(domain, variables=['Tz', 'T', 'p', 'u', 'w', 'Oy'])
    problem.meta['p', T_bc_var, 'Oy', 'w']['z']['dirichlet'] = True

    # Background temperature profile: linear, T0_z = -1.
    T0_z = domain.new_field()
    T0_z.meta['x']['constant'] = True
    T0_z['g'] = -1
    T0 = domain.new_field()
    T0.meta['x']['constant'] = True
    T0['g'] = Lz / 2 - domain.grid(-1)
    problem.parameters['T0'] = T0
    problem.parameters['T0_z'] = T0_z

    # Non-dimensional diffusivities.
    problem.parameters['P'] = (Rayleigh * Prandtl)**(-1 / 2)
    problem.parameters['R'] = (Rayleigh / Prandtl)**(-1 / 2)
    problem.parameters['F'] = F = 1

    problem.parameters['Lx'] = Lx
    problem.parameters['Lz'] = Lz
    problem.substitutions['plane_avg(A)'] = 'integ(A, "x")/Lx'
    problem.substitutions['vol_avg(A)'] = 'integ(A)/Lx/Lz'

    # 2D flow: v and the out-of-plane derivatives vanish.
    problem.substitutions['v'] = '0'
    problem.substitutions['Ox'] = '0'
    problem.substitutions['Oz'] = '(dx(v) )'
    problem.substitutions['Kx'] = '( -dz(Oy))'
    problem.substitutions['Ky'] = '(dz(Ox) - dx(Oz))'
    problem.substitutions['Kz'] = '(dx(Oy) )'

    problem.substitutions['vorticity'] = 'Oy'
    problem.substitutions['enstrophy'] = 'Oy**2'

    problem.substitutions['u_fluc'] = '(u - plane_avg(u))'
    problem.substitutions['w_fluc'] = '(w - plane_avg(w))'
    problem.substitutions['KE'] = '(0.5*(u*u+w*w))'

    problem.substitutions['sigma_xz'] = '(dx(w) + Oy + dx(w))'
    problem.substitutions['sigma_xx'] = '(2*dx(u))'
    problem.substitutions['sigma_zz'] = '(2*dz(w))'

    if viscous_heating:
        problem.substitutions[
            'visc_heat'] = 'R*(sigma_xz**2 + sigma_xx*dx(u) + sigma_zz*dz(w))'
        problem.substitutions['visc_flux_z'] = 'R*(u*sigma_xz + w*sigma_zz)'
    else:
        problem.substitutions['visc_heat'] = '0'
        problem.substitutions['visc_flux_z'] = '0'

    problem.substitutions['conv_flux_z'] = '(w*T + visc_flux_z)/P'
    problem.substitutions['kappa_flux_z'] = '(-Tz)'

    problem.add_equation("Tz - dz(T) = 0")
    problem.add_equation(
        "dt(T) - P*(dx(dx(T)) + dz(Tz)) + w*T0_z    = -(u*dx(T) + w*Tz)  - visc_heat"
    )
    # O == omega = curl(u);  K = curl(O)
    problem.add_equation("dt(u)  + R*Kx  + dx(p)              =  v*Oz - w*Oy ")
    problem.add_equation("dt(w)  + R*Kz  + dz(p)    -T        =  u*Oy - v*Ox ")
    problem.add_equation("dx(u) + dz(w) = 0")
    problem.add_equation("Oy - dz(u) + dx(w) = 0")

    # Pressure gauge on the mean mode; no-penetration elsewhere.
    problem.add_bc("right(p) = 0", condition="(nx == 0)")
    if fixed_flux:
        problem.add_bc("left(Tz)  = 0")
        problem.add_bc("right(Tz) = 0")
    elif fixed_T:
        problem.add_bc("left(T)  = 0")
        problem.add_bc("right(T) = 0")
    problem.add_bc("left(Oy) = 0")
    problem.add_bc("right(Oy) = 0")
    problem.add_bc("left(w)  = 0")
    problem.add_bc("right(w) = 0", condition="(nx != 0)")

    # Build solver
    ts = de.timesteppers.RK443
    cfl_safety = 1

    solver = problem.build_solver(ts)
    logger.info('Solver built')

    checkpoint = solver.evaluator.add_file_handler(data_dir + 'checkpoints',
                                                   wall_dt=8 * 3600,
                                                   max_writes=1)
    checkpoint.add_system(solver.state, layout='c')

    # Initial conditions
    x = domain.grid(0)
    z = domain.grid(1)
    T = solver.state['T']
    Tz = solver.state['Tz']

    # Random perturbations, initialized globally for same results in parallel
    noise = global_noise(domain, scale=1, frac=0.25)

    if restart is None:
        # Linear background + perturbations damped at walls
        zb, zt = z_basis.interval
        pert = 1e-3 * noise * (zt - z) * (z - zb)
        T['g'] = pert
        T.differentiate('z', out=Tz)
    else:
        logger.info("restarting from {}".format(restart))
        # NOTE(review): `checkpoint` is a file handler; confirm it exposes
        # a restart() method before relying on this path.
        checkpoint.restart(restart, solver)

    # Integration parameters
    solver.stop_sim_time = run_time_buoyancy
    solver.stop_wall_time = run_time * 3600.
    solver.stop_iteration = run_time_iter

    # Analysis
    analysis_tasks = []
    snapshots = solver.evaluator.add_file_handler(data_dir + 'slices',
                                                  sim_dt=0.1,
                                                  max_writes=max_slice_writes)
    snapshots.add_task("T + T0", name='T')
    snapshots.add_task("enstrophy")
    snapshots.add_task("vorticity")
    analysis_tasks.append(snapshots)

    snapshots_small = solver.evaluator.add_file_handler(
        data_dir + 'slices_small', sim_dt=0.1, max_writes=max_slice_writes)
    snapshots_small.add_task("T + T0", name='T', scales=0.25)
    snapshots_small.add_task("enstrophy", scales=0.25)
    snapshots_small.add_task("vorticity", scales=0.25)
    analysis_tasks.append(snapshots_small)

    if coeff_output:
        coeffs = solver.evaluator.add_file_handler(data_dir + 'coeffs',
                                                   sim_dt=0.1,
                                                   max_writes=max_slice_writes)
        coeffs.add_task("T+T0", name="T", layout='c')
        coeffs.add_task("T - plane_avg(T)", name="T'", layout='c')
        coeffs.add_task("w", layout='c')
        coeffs.add_task("u", layout='c')
        coeffs.add_task("enstrophy", layout='c')
        coeffs.add_task("vorticity", layout='c')
        analysis_tasks.append(coeffs)

    profiles = solver.evaluator.add_file_handler(data_dir + 'profiles',
                                                 sim_dt=0.1,
                                                 max_writes=max_writes)
    profiles.add_task("plane_avg(T+T0)", name="T")
    profiles.add_task("plane_avg(T)", name="T'")
    profiles.add_task("plane_avg(u)", name="u")
    profiles.add_task("plane_avg(w)", name="w")
    profiles.add_task("plane_avg(enstrophy)", name="enstrophy")
    # This may have an error:
    profiles.add_task("plane_avg(conv_flux_z)/plane_avg(kappa_flux_z) + 1",
                      name="Nu")
    profiles.add_task("plane_avg(conv_flux_z) + plane_avg(kappa_flux_z)",
                      name="Nu_2")

    analysis_tasks.append(profiles)

    scalar = solver.evaluator.add_file_handler(data_dir + 'scalar',
                                               sim_dt=0.1,
                                               max_writes=max_writes)
    scalar.add_task("vol_avg(T)", name="IE")
    scalar.add_task("vol_avg(KE)", name="KE")
    scalar.add_task("vol_avg(-T*z)", name="PE_fluc")
    scalar.add_task("vol_avg(T) + vol_avg(KE) + vol_avg(-T*z)", name="TE")
    scalar.add_task("0.5*vol_avg(u_fluc*u_fluc+w_fluc*w_fluc)", name="KE_fluc")
    scalar.add_task("0.5*vol_avg(u*u)", name="KE_x")
    scalar.add_task("0.5*vol_avg(w*w)", name="KE_z")
    scalar.add_task("0.5*vol_avg(u_fluc*u_fluc)", name="KE_x_fluc")
    scalar.add_task("0.5*vol_avg(w_fluc*w_fluc)", name="KE_z_fluc")
    scalar.add_task("vol_avg(plane_avg(u)**2)", name="u_avg")
    scalar.add_task("vol_avg((u - plane_avg(u))**2)", name="u1")
    scalar.add_task("vol_avg(conv_flux_z) + 1.", name="Nu")
    analysis_tasks.append(scalar)

    # workaround for issue #29
    problem.namespace['enstrophy'].store_last = True

    # CFL
    CFL = flow_tools.CFL(solver,
                         initial_dt=0.1,
                         cadence=1,
                         safety=cfl_safety,
                         max_change=1.5,
                         min_change=0.5,
                         max_dt=0.1,
                         threshold=0.1)
    CFL.add_velocities(('u', 'w'))

    # Flow properties
    flow = flow_tools.GlobalFlowProperty(solver, cadence=1)
    flow.add_property("sqrt(u*u + w*w) / R", name='Re')

    first_step = True
    # Initialize before the loop so the finally block cannot hit a
    # NameError if the solver raises before the first step completes.
    start_time = time.time()
    # Main loop
    try:
        logger.info('Starting loop')
        Re_avg = 0
        while solver.ok and np.isfinite(Re_avg):
            dt = CFL.compute_dt()
            solver.step(dt)  #, trim=True)
            Re_avg = flow.grid_average('Re')
            log_string = 'Iteration: {:5d}, '.format(solver.iteration)
            log_string += 'Time: {:8.3e}, dt: {:8.3e}, '.format(
                solver.sim_time, dt)
            log_string += 'Re: {:8.3e}/{:8.3e}'.format(Re_avg, flow.max('Re'))
            logger.info(log_string)

            if first_step:
                if verbose:
                    # One-off diagnostics: sparsity patterns of the LHS
                    # matrix and of its LU factors.
                    import matplotlib
                    matplotlib.use('Agg')
                    import matplotlib.pyplot as plt
                    fig = plt.figure()
                    ax = fig.add_subplot(1, 1, 1)
                    ax.spy(solver.pencils[0].L,
                           markersize=1,
                           markeredgewidth=0.0)
                    fig.savefig(data_dir + "sparsity_pattern.png", dpi=1200)

                    import scipy.sparse.linalg as sla
                    LU = sla.splu(solver.pencils[0].LHS.tocsc(),
                                  permc_spec='NATURAL')
                    fig = plt.figure()
                    ax = fig.add_subplot(1, 2, 1)
                    ax.spy(LU.L.A, markersize=1, markeredgewidth=0.0)
                    ax = fig.add_subplot(1, 2, 2)
                    ax.spy(LU.U.A, markersize=1, markeredgewidth=0.0)
                    fig.savefig(data_dir + "sparsity_pattern_LU.png", dpi=1200)

                    logger.info("{} nonzero entries in LU".format(LU.nnz))
                    logger.info("{} nonzero entries in LHS".format(
                        solver.pencils[0].LHS.tocsc().nnz))
                    logger.info("{} fill in factor".format(
                        LU.nnz / solver.pencils[0].LHS.tocsc().nnz))
                first_step = False
                start_time = time.time()
    except:
        logger.error('Exception raised, triggering end of main loop.')
        raise
    finally:
        end_time = time.time()
        main_loop_time = end_time - start_time
        n_iter_loop = solver.iteration - 1
        logger.info('Iterations: {:d}'.format(n_iter_loop))
        logger.info('Sim end time: {:f}'.format(solver.sim_time))
        logger.info('Run time: {:f} sec'.format(main_loop_time))
        logger.info('Run time: {:f} cpu-hr'.format(main_loop_time / 60 / 60 *
                                                   domain.dist.comm_cart.size))
        logger.info('iter/sec: {:f} (main loop only)'.format(n_iter_loop /
                                                             main_loop_time))

        # One extra step flushes the final state into the checkpoint.
        final_checkpoint = solver.evaluator.add_file_handler(
            data_dir + 'final_checkpoint', iter=1)
        final_checkpoint.add_system(solver.state, layout='c')
        solver.step(dt)  #clean this up in the future...works for now.
        post.merge_analysis(data_dir + 'final_checkpoint')

        if not no_join:
            logger.info('beginning join operation')
            post.merge_analysis(data_dir + 'checkpoints')

            for task in analysis_tasks:
                logger.info(task.base_path)
                post.merge_analysis(task.base_path)

        logger.info(40 * "=")
        logger.info('Iterations: {:d}'.format(n_iter_loop))
        logger.info('Sim end time: {:f}'.format(solver.sim_time))
        logger.info('Run time: {:f} sec'.format(main_loop_time))
        logger.info('Run time: {:f} cpu-hr'.format(main_loop_time / 60 / 60 *
                                                   domain.dist.comm_cart.size))
        logger.info('iter/sec: {:f} (main loop only)'.format(n_iter_loop /
                                                             main_loop_time))
コード例 #11
0
# Vertical wavenumber of the driven mode.
# (assumes H is defined earlier in the file — TODO confirm)
KZ = -2 * np.pi / H

if __name__ == '__main__':
    name = 'lin'
    slice_suffix = '(x=0)'
    SAVE_FMT_STR = 't_%03d.png'
    snapshots_dir = 'snapshots_lin'
    path = '{s}/{s}_s1'.format(s=snapshots_dir)

    # Dynamical variables to read from the snapshot file.
    dyn_vars = ['uz', 'ux', 'W', 'U']

    n_cols = 2
    n_rows = 2
    filename = '{s}/{s}_s1.h5'.format(s=snapshots_dir)

    # Merge distributed process files before reading (keep the originals).
    post.merge_analysis(snapshots_dir, cleanup=False)

    # Rebuild the simulation domain so snapshot data can be interpreted.
    # (N_X, N_Z, XMAX, ZMAX are presumably module-level constants — verify.)
    x_basis = de.Fourier('x',
                         N_X,
                         interval=(0, XMAX),
                         dealias=3/2)
    z_basis = de.Chebyshev('z',
                           N_Z,
                           interval=(0, ZMAX),
                           dealias=3/2)
    domain = de.Domain([x_basis, z_basis], np.float64)
    z = domain.grid(1)

    state_vars = defaultdict(list)
    # Read the simulation time axis from the merged snapshot file.
    with h5py.File(filename, mode='r') as dat:
        sim_times = np.array(dat['scales']['sim_time'])
コード例 #12
0
def merge(name):
    """Join per-process snapshot files for run *name* unless already merged.

    The merged HDF5 file is expected at <dir>/<dir>_s1.h5; if that file is
    already on disk the merge step is skipped entirely.
    """
    target_dir = SNAPSHOTS_DIR % name
    merged_h5 = '{s}/{s}_s1.h5'.format(s=target_dir)
    if os.path.exists(merged_h5):
        return  # nothing to do: merge output already exists
    post.merge_analysis(target_dir)
コード例 #13
0
    def solve_IVP(self,
                  dt,
                  CFL,
                  data_dir,
                  analysis_tasks,
                  task_args=(),
                  pre_loop_args=(),
                  task_kwargs=None,
                  pre_loop_kwargs=None,
                  time_div=None,
                  track_fields=None,
                  threeD=False,
                  Hermitian_cadence=100,
                  no_join=False,
                  mode='append'):
        """Logic for a while-loop that solves an initial value problem.

        Parameters
        ----------
        dt                  : float
            The initial timestep of the simulation
        CFL                 : a Dedalus CFL object
            A CFL object that calculates the timestep of the simulation on the fly
        data_dir            : string
            The parent directory of output files
        analysis_tasks      : OrderedDict()
            An OrderedDict of dedalus FileHandler objects
        task_args, task_kwargs : list, dict, optional
            arguments & keyword arguments to the self.special_tasks() function
        pre_loop_args, pre_loop_kwargs: list, dict, optional
            arguments & keyword arguments to the self.pre_loop_setup() function
        time_div            : float, optional
            A simulation time to divide the normal time by for easier output tracking
        track_fields        : list of strings, optional
            Flow properties to track; defaults to ['Pe'].  The first entry is
            also the blow-up sentinel that ends the main loop when non-finite.
        threeD              : bool, optional
            If True, occasionally force the solution to grid space to remove Hermitian errors
        Hermitian_cadence   : int, optional
            The number of timesteps between grid space forcings in 3D.
        no_join             : bool, optional
            If True, do not join files at the end of the simulation run.
        mode                : string, optional
            Dedalus output mode for final checkpoint. "append" or "overwrite"
        """
        # Create mutable defaults per-call: the original signature used
        # {} / [] defaults directly, which are shared across calls.
        if task_kwargs is None:
            task_kwargs = {}
        if pre_loop_kwargs is None:
            pre_loop_kwargs = {}
        if track_fields is None:
            track_fields = ['Pe']

        # Flow properties
        self.flow = flow_tools.GlobalFlowProperty(self.solver, cadence=1)
        for f in track_fields:
            self.flow.add_property(f, name=f)

        self.pre_loop_setup(*pre_loop_args, **pre_loop_kwargs)

        start_time = time.time()
        # Main loop
        try:
            logger.info('Starting loop')
            start_iter = self.solver.iteration
            while (self.solver.ok):
                dt = CFL.compute_dt()
                self.solver.step(dt)  #, trim=True)

                # prevents blow-up over long timescales in 3D due to hermitian-ness
                effective_iter = self.solver.iteration - start_iter
                if threeD and effective_iter % Hermitian_cadence == 0:
                    for field in self.solver.state.fields:
                        field.require_grid_space()

                self.special_tasks(*task_args, **task_kwargs)

                #reporting string
                self.iteration_report(dt, track_fields, time_div=time_div)

                # Stop the run when the tracked sentinel field blows up.
                if not np.isfinite(self.flow.grid_average(track_fields[0])):
                    break
        except:
            # Log before re-raising; in the original the `raise` came first,
            # which made this message unreachable.
            logger.error('Exception raised, triggering end of main loop.')
            raise
        finally:
            end_time = time.time()
            main_loop_time = end_time - start_time
            n_iter_loop = self.solver.iteration - 1
            logger.info('Iterations: {:d}'.format(n_iter_loop))
            logger.info('Sim end time: {:f}'.format(self.solver.sim_time))
            logger.info('Run time: {:f} sec'.format(main_loop_time))
            logger.info('Run time: {:f} cpu-hr'.format(
                main_loop_time / 60 / 60 *
                self.de_domain.domain.dist.comm_cart.size))
            logger.info('iter/sec: {:f} (main loop only)'.format(
                n_iter_loop / main_loop_time))
            try:
                # One extra step after attaching the handler forces the final
                # state to be flushed to disk.
                final_checkpoint = Checkpoint(
                    data_dir, checkpoint_name='final_checkpoint')
                final_checkpoint.set_checkpoint(self.solver,
                                                wall_dt=1,
                                                mode=mode)
                self.solver.step(
                    dt)  #clean this up in the future...works for now.
                post.merge_process_files(data_dir + '/final_checkpoint/',
                                         cleanup=False)
            except:
                # Report before re-raising; the original raised first, so the
                # message below could never be printed.
                print('cannot save final checkpoint')
                raise
            finally:
                if not no_join:
                    logger.info('beginning join operation')
                    # NOTE(review): the final-checkpoint merge above uses an
                    # explicit '/' in the path; this call relies on data_dir
                    # ending with '/' -- confirm against callers.
                    post.merge_analysis(data_dir + 'checkpoint')

                    for key, task in analysis_tasks.items():
                        logger.info(task.base_path)
                        post.merge_analysis(task.base_path)

                logger.info(40 * "=")
                logger.info('Iterations: {:d}'.format(n_iter_loop))
                logger.info('Sim end time: {:f}'.format(self.solver.sim_time))
                logger.info('Run time: {:f} sec'.format(main_loop_time))
                logger.info('Run time: {:f} cpu-hr'.format(
                    main_loop_time / 60 / 60 *
                    self.de_domain.domain.dist.comm_cart.size))
                logger.info('iter/sec: {:f} (main loop only)'.format(
                    n_iter_loop / main_loop_time))
コード例 #14
0
ファイル: FC_multi_rxn.py プロジェクト: exoweather/polytrope
def FC_convection(Rayleigh=1e6, Prandtl=1,
                  ChemicalPrandtl=1, ChemicalReynolds=10, stiffness=1e4,
                  n_rho_cz=3.5, n_rho_rz=1,
                  nz_cz=128, nz_rz=128,
                  nx=None,
                  width=None,
                  single_chebyshev=False,
                  rk222=False,
                  superstep=False,
                  dense=False, nz_dense=64,
                  oz=False,
                  restart=None, data_dir='./', verbose=False):
    """Run a fully-compressible multitrope convection IVP with reactions.

    Builds the stratified multitrope atmosphere, configures timestepping,
    CFL control and output handlers, advances the solver until a stop
    criterion is reached, then joins the distributed output files and
    reports timing/scaling statistics.

    Parameters mirror the physical and numerical controls of the run
    (Rayleigh, Prandtl, chemical Prandtl/Reynolds, stiffness, resolution,
    domain options).  `restart` is a checkpoint path to resume from and
    `data_dir` is where all output is written.

    Raises
    ------
    RuntimeError
        If a restart is requested but checkpointing is unavailable.
    """
    import numpy as np
    import time
    from stratified_dynamics import multitropes
    import os
    from dedalus.core.future import FutureField

    initial_time = time.time()

    logger.info("Starting Dedalus script {:s}".format(sys.argv[0]))

    # 'oz' flips the multitrope configuration (stable layer on top with a
    # mixed temperature/flux BC) -- presumably an overshoot setup; confirm.
    if oz:
        constant_Prandtl = False
        stable_top = True
        mixed_temperature_flux = True
    else:
        constant_Prandtl = True
        stable_top = False
        mixed_temperature_flux = None

    # Set domain: default 4:1 horizontal aspect on the CZ resolution.
    if nx is None:
        nx = nz_cz * 4

    if single_chebyshev:
        nz = nz_cz
        nz_list = [nz_cz]
    else:
        if dense:
            nz = nz_rz + nz_dense + nz_cz
            #nz_list = [nz_rz, int(nz_dense/2), int(nz_dense/2), nz_cz]
            nz_list = [nz_rz, nz_dense, nz_cz]
        else:
            nz = nz_rz + nz_cz
            nz_list = [nz_rz, nz_cz]

    atmosphere = multitropes.FC_multitrope_rxn(nx=nx, nz=nz_list, stiffness=stiffness,
                                               n_rho_cz=n_rho_cz, n_rho_rz=n_rho_rz,
                                               verbose=verbose, width=width,
                                               constant_Prandtl=constant_Prandtl,
                                               stable_top=stable_top)

    atmosphere.set_IVP_problem(Rayleigh, Prandtl, ChemicalPrandtl, ChemicalReynolds)

    atmosphere.set_BC(mixed_temperature_flux=mixed_temperature_flux)
    problem = atmosphere.get_problem()

    #atmosphere.plot_atmosphere()
    #atmosphere.plot_scaled_atmosphere()

    # Only rank 0 creates the output directory.
    if atmosphere.domain.distributor.rank == 0:
        if not os.path.exists('{:s}/'.format(data_dir)):
            os.mkdir('{:s}/'.format(data_dir))

    # CFL safety scales with the timestepper's stage count.
    if rk222:
        logger.info("timestepping using RK222")
        ts = de.timesteppers.RK222
        cfl_safety_factor = 0.2 * 2
    else:
        logger.info("timestepping using RK443")
        ts = de.timesteppers.RK443
        cfl_safety_factor = 0.2 * 4

    # Build solver
    solver = problem.build_solver(ts)

    if do_checkpointing:
        checkpoint = Checkpoint(data_dir)
        checkpoint.set_checkpoint(solver, wall_dt=1800)

    # initial conditions
    if restart is None:
        atmosphere.set_IC(solver)
    else:
        if do_checkpointing:
            logger.info("restarting from {}".format(restart))
            checkpoint.restart(restart, solver)
        else:
            # The original used a bare `raise` with no active exception,
            # which itself dies with an opaque RuntimeError; raise an
            # explicit error instead.
            logger.error("No checkpointing capability in this branch of Dedalus.  Aborting.")
            raise RuntimeError(
                "No checkpointing capability in this branch of Dedalus.")

    logger.info("thermal_time = {:g}, top_thermal_time = {:g}".format(atmosphere.thermal_time, atmosphere.top_thermal_time))

    # Timestep cap: a quarter buoyancy time.  (The original also assigned
    # atmosphere.min_BV_time here first but immediately overwrote it.)
    max_dt = atmosphere.buoyancy_time * 0.25

    report_cadence = 1
    output_time_cadence = 0.1 * atmosphere.buoyancy_time
    solver.stop_sim_time = np.inf
    solver.stop_iteration = np.inf
    solver.stop_wall_time = 23.5 * 3600

    logger.info("output cadence = {:g}".format(output_time_cadence))

    analysis_tasks = atmosphere.initialize_output(solver, data_dir, sim_dt=output_time_cadence)

    cfl_cadence = 1
    CFL = flow_tools.CFL(solver, initial_dt=max_dt, cadence=cfl_cadence, safety=cfl_safety_factor,
                         max_change=1.5, min_change=0.5, max_dt=max_dt, threshold=0.1)

    if superstep:
        # Superstepping: constrain dt by per-direction advective frequencies;
        # the traditional velocity-based CFL is kept only for log comparison.
        CFL_traditional = flow_tools.CFL(solver, initial_dt=max_dt, cadence=cfl_cadence, safety=cfl_safety_factor,
                                         max_change=1.5, min_change=0.5, max_dt=max_dt, threshold=0.1)

        CFL_traditional.add_velocities(('u', 'w'))

        vel_u = FutureField.parse('u', CFL.solver.evaluator.vars, CFL.solver.domain)
        delta_x = atmosphere.Lx / nx
        CFL.add_frequency(vel_u / delta_x)
        vel_w = FutureField.parse('w', CFL.solver.evaluator.vars, CFL.solver.domain)
        mean_delta_z_cz = atmosphere.Lz_cz / nz_cz
        CFL.add_frequency(vel_w / mean_delta_z_cz)
    else:
        CFL.add_velocities(('u', 'w'))

    # Flow properties
    flow = flow_tools.GlobalFlowProperty(solver, cadence=1)
    flow.add_property("Re_rms", name='Re')

    try:
        start_time = time.time()
        while solver.ok:

            dt = CFL.compute_dt()
            # advance
            solver.step(dt)

            # update lists
            if solver.iteration % report_cadence == 0:
                log_string = 'Iteration: {:5d}, Time: {:8.3e} ({:8.3e}), '.format(solver.iteration, solver.sim_time, solver.sim_time/atmosphere.buoyancy_time)
                log_string += 'dt: {:8.3e}'.format(dt)
                if superstep:
                    dt_traditional = CFL_traditional.compute_dt()
                    log_string += ' (vs {:8.3e})'.format(dt_traditional)
                log_string += ', '
                log_string += 'Re: {:8.3e}/{:8.3e}'.format(flow.grid_average('Re'), flow.max('Re'))
                logger.info(log_string)
    except:
        logger.error('Exception raised, triggering end of main loop.')
        raise
    finally:
        end_time = time.time()

        # Print statistics
        elapsed_time = end_time - start_time
        elapsed_sim_time = solver.sim_time
        N_iterations = solver.iteration
        logger.info('main loop time: {:e}'.format(elapsed_time))
        logger.info('Iterations: {:d}'.format(N_iterations))
        logger.info('iter/sec: {:g}'.format(N_iterations/(elapsed_time)))
        # Guard against a zero-iteration run (matches the FC_polytrope
        # driver's behavior and avoids ZeroDivisionError in this finally).
        if N_iterations > 0:
            logger.info('Average timestep: {:e}'.format(elapsed_sim_time / N_iterations))

        logger.info('beginning join operation')
        if do_checkpointing:
            logger.info(data_dir+'/checkpoint/')
            post.merge_analysis(data_dir+'/checkpoint/')

        for task in analysis_tasks:
            logger.info(analysis_tasks[task].base_path)
            post.merge_analysis(analysis_tasks[task].base_path)

        if (atmosphere.domain.distributor.rank==0):

            logger.info('main loop time: {:e}'.format(elapsed_time))
            logger.info('Iterations: {:d}'.format(N_iterations))
            logger.info('iter/sec: {:g}'.format(N_iterations/(elapsed_time)))
            if N_iterations > 0:
                logger.info('Average timestep: {:e}'.format(elapsed_sim_time / N_iterations))

            N_TOTAL_CPU = atmosphere.domain.distributor.comm_cart.size

            # Print statistics
            print('-' * 40)
            total_time = end_time-initial_time
            main_loop_time = end_time - start_time
            startup_time = start_time-initial_time
            n_steps = solver.iteration-1
            print('  startup time:', startup_time)
            print('main loop time:', main_loop_time)
            print('    total time:', total_time)
            print('Iterations:', solver.iteration)
            if n_steps > 0:
                print('Average timestep:', solver.sim_time / n_steps)
                print("          N_cores, Nx, Nz, startup     main loop,   main loop/iter, main loop/iter/grid, n_cores*main loop/iter/grid")
                print('scaling:',
                      ' {:d} {:d} {:d}'.format(N_TOTAL_CPU,nx,nz),
                      ' {:8.3g} {:8.3g} {:8.3g} {:8.3g} {:8.3g}'.format(startup_time,
                                                                        main_loop_time,
                                                                        main_loop_time/n_steps,
                                                                        main_loop_time/n_steps/(nx*nz),
                                                                        N_TOTAL_CPU*main_loop_time/n_steps/(nx*nz)))
            print('-' * 40)
コード例 #15
0
def FC_polytrope(  Rayleigh=1e4, Prandtl=1, aspect_ratio=4,
                   Taylor=None, theta=0,
                        nz=128, nx=None, ny=None, threeD=False, mesh=None,
                        n_rho_cz=3, epsilon=1e-4,
                        run_time=23.5, run_time_buoyancies=None, run_time_iter=np.inf,
                        fixed_T=False, fixed_flux=False, mixed_flux_T=False,
                        const_mu=True, const_kappa=True,
                        dynamic_diffusivities=False, split_diffusivities=False,
                        restart=None, start_new_files=False,
                        rk222=False, safety_factor=0.2,
                        data_dir='./', out_cadence=0.1, no_coeffs=False, no_join=False,
                        verbose=False):
    """Run a fully-compressible polytrope convection IVP (2D or 3D).

    Builds the polytropic atmosphere, sets boundary conditions, timestepper
    and CFL control, advances the solver until a stop criterion is reached,
    then (optionally) joins the distributed output files and logs run
    statistics.  Parameters mirror the physical and numerical controls of
    the run; `restart` is a checkpoint path to resume from and `data_dir`
    is the output directory.
    """
    import time
    from stratified_dynamics import polytropes
    import os
    import sys
    
    initial_time = time.time()

    logger.info("Starting Dedalus script {:s}".format(sys.argv[0]))

    # Default horizontal resolution from the aspect ratio; 3D copies nx to ny.
    if nx is None:
        nx = int(np.round(nz*aspect_ratio))
    if threeD and ny is None:
        ny = nx

    # Choose the atmosphere class: 3D, 2D with dynamic kappa diffusivities,
    # or the standard 2D polytrope.
    if threeD:
        atmosphere = polytropes.FC_polytrope_3d(nx=nx, ny=ny, nz=nz, mesh=mesh, constant_kappa=const_kappa, constant_mu=const_mu,\
                                        epsilon=epsilon, n_rho_cz=n_rho_cz, aspect_ratio=aspect_ratio,\
                                        fig_dir=data_dir)
    else:
        if dynamic_diffusivities:
            atmosphere = polytropes.FC_polytrope_2d_kappa(nx=nx, nz=nz, constant_kappa=const_kappa, constant_mu=const_mu,\
                                        epsilon=epsilon, n_rho_cz=n_rho_cz, aspect_ratio=aspect_ratio,\
                                        fig_dir=data_dir)
        else:
            atmosphere = polytropes.FC_polytrope_2d(nx=nx, nz=nz, constant_kappa=const_kappa, constant_mu=const_mu,\
                                        epsilon=epsilon, n_rho_cz=n_rho_cz, aspect_ratio=aspect_ratio,\
                                        fig_dir=data_dir)
    # Spectral coefficient cutoff tightens as epsilon (superadiabaticity)
    # shrinks.
    if epsilon < 1e-4:
        ncc_cutoff = 1e-14
    elif epsilon > 1e-1:
        ncc_cutoff = 1e-6
    else:
        ncc_cutoff = 1e-10
    if threeD:
        atmosphere.set_IVP_problem(Rayleigh, Prandtl, Taylor=Taylor, theta=theta, ncc_cutoff=ncc_cutoff, split_diffusivities=split_diffusivities)
    else:
        atmosphere.set_IVP_problem(Rayleigh, Prandtl, ncc_cutoff=ncc_cutoff, split_diffusivities=split_diffusivities)
    
    # Boundary conditions: fixed flux, mixed flux/temperature, or fixed
    # temperature; all stress-free.
    if fixed_flux:
        atmosphere.set_BC(fixed_flux=True, stress_free=True)
    elif mixed_flux_T:
        atmosphere.set_BC(mixed_flux_temperature=True, stress_free=True)
    else:
        atmosphere.set_BC(fixed_temperature=True, stress_free=True)

    problem = atmosphere.get_problem()

    # Only rank 0 creates the output directory.
    if atmosphere.domain.distributor.rank == 0:
        if not os.path.exists('{:s}/'.format(data_dir)):
            os.mkdir('{:s}/'.format(data_dir))

    # CFL safety scales with the timestepper's stage count.
    if rk222:
        logger.info("timestepping using RK222")
        ts = de.timesteppers.RK222
        cfl_safety_factor = safety_factor*2
    else:
        logger.info("timestepping using RK443")
        ts = de.timesteppers.RK443
        cfl_safety_factor = safety_factor*4

    # Build solver
    solver = problem.build_solver(ts)

    #Check atmosphere
    logger.info("thermal_time = {:g}, top_thermal_time = {:g}".format(atmosphere.thermal_time,\
                                                                    atmosphere.top_thermal_time))
    logger.info("full atm HS check")
    atmosphere.check_atmosphere(make_plots = False, rho=atmosphere.get_full_rho(solver), T=atmosphere.get_full_T(solver))

    #Set up timestep defaults
    max_dt = atmosphere.buoyancy_time*0.05
    dt = max_dt/5
    if epsilon < 1e-5:
        max_dt = atmosphere.buoyancy_time*0.05
        dt     = max_dt

    # Fresh runs overwrite output; restarts append to existing files.
    if restart is None:
        mode = "overwrite"
    else:
        mode = "append"

    if do_checkpointing:
        logger.info('checkpointing in {}'.format(data_dir))
        checkpoint = Checkpoint(data_dir)

        if restart is None:
            atmosphere.set_IC(solver)
        else:
            logger.info("restarting from {}".format(restart))
            # restart() returns the timestep the checkpoint was saved with.
            dt = checkpoint.restart(restart, solver)

        checkpoint.set_checkpoint(solver, wall_dt=checkpoint_min*60, mode=mode)
    else:
        atmosphere.set_IC(solver)
    
    # Stop criteria: sim time (buoyancy times or thermal times), iteration
    # count, and wall-clock hours.
    if run_time_buoyancies != None:
        solver.stop_sim_time    = solver.sim_time + run_time_buoyancies*atmosphere.buoyancy_time
    else:
        solver.stop_sim_time    = 100*atmosphere.thermal_time
    
    solver.stop_iteration   = solver.iteration + run_time_iter
    solver.stop_wall_time   = run_time*3600
    report_cadence = 1
    output_time_cadence = out_cadence*atmosphere.buoyancy_time
    Hermitian_cadence = 100
    
    logger.info("stopping after {:g} time units".format(solver.stop_sim_time))
    logger.info("output cadence = {:g}".format(output_time_cadence))
    
    if no_coeffs:
        coeffs_output=False
    else:
        coeffs_output=True
    analysis_tasks = atmosphere.initialize_output(solver, data_dir, sim_dt=output_time_cadence, coeffs_output=coeffs_output,\
                                mode=mode)
    
    cfl_cadence = 1
    cfl_threshold=0.1
    CFL = flow_tools.CFL(solver, initial_dt=dt, cadence=cfl_cadence, safety=cfl_safety_factor,
                         max_change=1.5, min_change=0.5, max_dt=max_dt, threshold=cfl_threshold)
    if threeD:
        CFL.add_velocities(('u', 'v', 'w'))
    else:
        CFL.add_velocities(('u', 'w'))

    # Flow properties
    flow = flow_tools.GlobalFlowProperty(solver, cadence=1)
    flow.add_property("Re_rms", name='Re')
    if verbose:
        flow.add_property("Pe_rms", name='Pe')
        flow.add_property("Nusselt_AB17", name='Nusselt')
    
    start_iter=solver.iteration
    start_sim_time = solver.sim_time

    try:
        start_time = time.time()
        start_iter = solver.iteration
        logger.info('starting main loop')
        good_solution = True
        first_step = True
        while solver.ok and good_solution:
            dt = CFL.compute_dt()
            # advance
            solver.step(dt)

            effective_iter = solver.iteration - start_iter

            # prevents blow-up over long timescales in 3D due to hermitian-ness
            if threeD and effective_iter % Hermitian_cadence == 0:
                for field in solver.state.fields:
                    field.require_grid_space()

            # update lists
            if effective_iter % report_cadence == 0:
                Re_avg = flow.grid_average('Re')
                log_string = 'Iteration: {:5d}, Time: {:8.3e} ({:8.3e}), dt: {:8.3e}, '.format(solver.iteration-start_iter, solver.sim_time, (solver.sim_time-start_sim_time)/atmosphere.buoyancy_time, dt)
                if verbose:
                    log_string += '\n\t\tRe: {:8.5e}/{:8.5e}'.format(Re_avg, flow.max('Re'))
                    log_string += '; Pe: {:8.5e}/{:8.5e}'.format(flow.grid_average('Pe'), flow.max('Pe'))
                    log_string += '; Nu: {:8.5e}/{:8.5e}'.format(flow.grid_average('Nusselt'), flow.max('Nusselt'))
                else:
                    log_string += 'Re: {:8.3e}/{:8.3e}'.format(Re_avg, flow.max('Re'))
                logger.info(log_string)
                
            # NOTE(review): Re_avg is only bound inside the report branch
            # above; with report_cadence == 1 it is always set, but any other
            # cadence would make this an unbound-name error -- confirm.
            if not np.isfinite(Re_avg):
                good_solution = False
                logger.info("Terminating run.  Trapped on Reynolds = {}".format(Re_avg))
                    
            # On the first step (verbose only): dump LHS sparsity patterns and
            # LU fill-in diagnostics, then restart the wall-clock timer so the
            # diagnostic cost is excluded from the loop timing.
            if first_step:
                if verbose:
                    import matplotlib
                    matplotlib.use('Agg')
                    import matplotlib.pyplot as plt
                    fig = plt.figure()
                    ax = fig.add_subplot(1,1,1)
                    ax.spy(solver.pencils[0].L, markersize=1, markeredgewidth=0.0)
                    fig.savefig(data_dir+"sparsity_pattern.png", dpi=1200)

                    import scipy.sparse.linalg as sla
                    LU = sla.splu(solver.pencils[0].LHS.tocsc(), permc_spec='NATURAL')
                    fig = plt.figure()
                    ax = fig.add_subplot(1,2,1)
                    ax.spy(LU.L.A, markersize=1, markeredgewidth=0.0)
                    ax = fig.add_subplot(1,2,2)
                    ax.spy(LU.U.A, markersize=1, markeredgewidth=0.0)
                    fig.savefig(data_dir+"sparsity_pattern_LU.png", dpi=1200)

                    logger.info("{} nonzero entries in LU".format(LU.nnz))
                    logger.info("{} nonzero entries in LHS".format(solver.pencils[0].LHS.tocsc().nnz))
                    logger.info("{} fill in factor".format(LU.nnz/solver.pencils[0].LHS.tocsc().nnz))
                first_step = False
                start_time = time.time()
    except:
        logger.error('Exception raised, triggering end of main loop.')
        raise
    finally:
        end_time = time.time()

        # Print statistics
        elapsed_time = end_time - start_time
        elapsed_sim_time = solver.sim_time
        N_iterations = solver.iteration-1
        logger.info('main loop time: {:e}'.format(elapsed_time))
        logger.info('Iterations: {:d}'.format(N_iterations))
        logger.info('iter/sec: {:g}'.format(N_iterations/(elapsed_time)))
        if N_iterations > 0:
            logger.info('Average timestep: {:e}'.format(elapsed_sim_time / N_iterations))
        
        if not no_join:
            logger.info('beginning join operation')
        if do_checkpointing:
            try:
                # Attach a final-checkpoint handler and take one extra step to
                # flush the final state to disk.
                final_checkpoint = Checkpoint(data_dir, checkpoint_name='final_checkpoint')
                final_checkpoint.set_checkpoint(solver, wall_dt=1, mode="append")
                solver.step(dt) #clean this up in the future...works for now.
                post.merge_analysis(data_dir+'/final_checkpoint/')
            except:
                # NOTE(review): bare except silently swallows every failure
                # here (including KeyboardInterrupt) -- confirm intended.
                print('cannot save final checkpoint')
            if not no_join:
                logger.info(data_dir+'/checkpoint/')
                post.merge_analysis(data_dir+'/checkpoint/')
        if not no_join:
            for task in analysis_tasks.keys():
                logger.info(analysis_tasks[task].base_path)
                post.merge_analysis(analysis_tasks[task].base_path)

        if (atmosphere.domain.distributor.rank==0):

            logger.info('main loop time: {:e}'.format(elapsed_time))
            if start_iter > 1:
                # NOTE(review): both lines below log the same quantity;
                # 'total' presumably should be N_iterations -- confirm.
                logger.info('Iterations (this run): {:d}'.format(N_iterations - start_iter))
                logger.info('Iterations (total): {:d}'.format(N_iterations - start_iter))
            logger.info('iter/sec: {:g}'.format(N_iterations/(elapsed_time)))
            if N_iterations > 0:
                logger.info('Average timestep: {:e}'.format(elapsed_sim_time / N_iterations))
 
            N_TOTAL_CPU = atmosphere.domain.distributor.comm_cart.size

            # Print statistics
            print('-' * 40)
            total_time = end_time-initial_time
            main_loop_time = end_time - start_time
            startup_time = start_time-initial_time
            n_steps = solver.iteration-1
            print('  startup time:', startup_time)
            print('main loop time:', main_loop_time)
            print('    total time:', total_time)
            if n_steps > 0:
                print('    iterations:', n_steps)
                print(' loop sec/iter:', main_loop_time/n_steps)
                print('    average dt:', solver.sim_time/n_steps)
                print("          N_cores, Nx, Nz, startup     main loop,   main loop/iter, main loop/iter/grid, n_cores*main loop/iter/grid")
                print('scaling:',
                    ' {:d} {:d} {:d}'.format(N_TOTAL_CPU,nx,nz),
                    ' {:8.3g} {:8.3g} {:8.3g} {:8.3g} {:8.3g}'.format(startup_time,
                                                                    main_loop_time, 
                                                                    main_loop_time/n_steps, 
                                                                    main_loop_time/n_steps/(nx*nz), 
                                                                    N_TOTAL_CPU*main_loop_time/n_steps/(nx*nz)))
            print('-' * 40)