Exemplo n.º 1
0
    def get_grid(self, NFP, **kwargs):
        """Get grid for plotting.

        Parameters
        ----------
        NFP : int
            passed to LinearGrid as ``NFP`` (presumably the number of field
            periods -- TODO confirm)
        kwargs
            any of the LinearGrid arguments ``rho``, ``L``, ``theta``, ``M``,
            ``zeta``, ``N`` or ``endpoint`` (defaults: rho=1.0, L=100,
            theta=0.0, M=1, zeta=0.0, N=1, endpoint=False). Unrecognised keys
            are silently ignored.

        Returns
        -------
        grid : LinearGrid
            grid built from the merged arguments
        plot_axes : tuple of int
            indices (0=rho, 1=theta, 2=zeta) of the coordinates whose node
            count (L, M or N respectively) is not 1, i.e. the axes to plot along

        """
        grid_args = {'rho': 1.0, 'L': 100, 'theta': 0.0, 'M': 1, 'zeta': 0.0,
                     'N': 1, 'endpoint': False, 'NFP': NFP}
        # only recognised LinearGrid arguments may override the defaults
        grid_args.update({key: val for key, val in kwargs.items()
                          if key in grid_args})

        plot_axes = []
        coords = (('rho', 'L'), ('theta', 'M'), ('zeta', 'N'))
        for axis, (coord, count) in enumerate(coords):
            grid_args[coord] = self.__format_rtz__(grid_args[coord])
            # a coordinate with a single node is held fixed and not plotted
            if grid_args[count] != 1:
                plot_axes.append(axis)
        return LinearGrid(**grid_args), tuple(plot_axes)
Exemplo n.º 2
0
    def test_linear_grid(self):
        """Tests LinearGrid node placement and volume weights."""
        L, M, N, NFP = 3, 3, 3, 1

        grid = LinearGrid(L, M, N, NFP, sym=False, endpoint=False)

        # expected ordering: rho varies fastest, then theta, then zeta
        rho_vals = np.array([0, 0.5, 1])
        angle_vals = np.array([0, 2, 4]) * np.pi / 3
        expected_nodes = np.stack([
            np.tile(rho_vals, 9),
            np.tile(np.repeat(angle_vals, 3), 3),
            np.repeat(angle_vals, 9),
        ]).T

        np.testing.assert_allclose(grid.nodes, expected_nodes, atol=1e-8)

        vols = grid.volumes
        total_volume = np.sum(vols[:, 0] * vols[:, 1] * vols[:, 2])
        self.assertAlmostEqual(total_volume, (2*np.pi)**2/NFP)
Exemplo n.º 3
0
    def test_set_grid(self):
        """Tests the grid setter method
        """
        basis = FourierZernikeBasis(M=1, N=1)

        grids = {res: LinearGrid(L=res, M=1, N=1) for res in (1, 3, 5)}
        transforms = {res: Transform(g, basis) for res, g in grids.items()}

        # swapping in a finer grid must match a transform built on it directly
        transforms[3].grid = grids[5]
        self.assertTrue(transforms[3] == transforms[5])

        # and likewise for a coarser grid
        transforms[3].grid = grids[1]
        self.assertTrue(transforms[3] == transforms[1])
Exemplo n.º 4
0
    def test_eq(self):
        """Tests equals operator overload method
        """
        radial_grid = LinearGrid(L=11, endpoint=True)
        surface_grid = LinearGrid(M=5, N=5)
        concentric_grid = ConcentricGrid(M=2, N=2)

        fourier = DoubleFourierSeries(M=1, N=1)
        zernike = FourierZernikeBasis(M=1, N=1)

        radial_fourier = Transform(radial_grid, fourier)
        surface_fourier = Transform(surface_grid, fourier)
        concentric_fourier = Transform(concentric_grid, fourier)
        concentric_zernike = Transform(concentric_grid, zernike)
        concentric_zernike_copy = Transform(concentric_grid, zernike)

        # same basis, different grids
        self.assertFalse(radial_fourier == surface_fourier)
        # same grid, different bases
        self.assertTrue(concentric_fourier != concentric_zernike)
        # built independently from identical inputs
        self.assertTrue(concentric_zernike == concentric_zernike_copy)
Exemplo n.º 5
0
    def test_surface(self):
        """Tests transform of double Fourier series on a flux surface
        """
        grid = LinearGrid(M=5, N=5)
        basis = DoubleFourierSeries(M=1, N=1)
        transf = Transform(grid, basis, derivs=1)

        phase = grid.nodes[:, 1] - grid.nodes[:, 2]  # theta - zeta

        def mode_index(m, n):
            """Indices of basis modes whose columns 1 and 2 equal (m, n)."""
            hits = (basis.modes[:, 1] == m) & (basis.modes[:, 2] == n)
            return np.where(hits)[0]

        # coefficients representing sin(theta - zeta) + 2 cos(theta - zeta)
        coef_by_mode = {(-1, 1): 1, (1, -1): -1, (-1, -1): 2, (1, 1): 2}
        c = np.zeros((basis.modes.shape[0], ))
        for (m, n), value in coef_by_mode.items():
            c[mode_index(m, n)] = value

        # keyed by the derivative orders passed to transform (rho, theta, zeta)
        expected = {
            (0, 0, 0): np.sin(phase) + 2 * np.cos(phase),
            (0, 1, 0): np.cos(phase) - 2 * np.sin(phase),
            (0, 0, 1): -np.cos(phase) + 2 * np.sin(phase),
            (0, 1, 1): np.sin(phase) + 2 * np.cos(phase),
        }
        computed = {orders: transf.transform(c, *orders) for orders in expected}

        for orders, correct in expected.items():
            np.testing.assert_allclose(computed[orders], correct, atol=1e-8)
Exemplo n.º 6
0
    def test_power_series(self):
        """Tests PowerSeries evaluation
        """
        grid = LinearGrid(L=11, endpoint=True)
        rho = grid.nodes[:, 0]

        # columns: rho**0, rho**1, rho**2 and their first radial derivatives
        expected_values = np.column_stack([rho**k for k in range(3)])
        expected_derivs = np.column_stack(
            [np.zeros_like(rho), np.ones_like(rho), 2 * rho])

        basis = PowerSeries(L=2)
        evaluated = basis.evaluate(grid.nodes, derivatives=np.array([0, 0, 0]))
        differentiated = basis.evaluate(grid.nodes, derivatives=np.array([1, 0, 0]))

        np.testing.assert_allclose(evaluated, expected_values, atol=1e-8)
        np.testing.assert_allclose(differentiated, expected_derivs, atol=1e-8)
Exemplo n.º 7
0
    def test_obj_fxn_types(self):
        """test the correct objective function is returned for 'force', 'accel', and unimplemented"""
        RZ_grid = ConcentricGrid(M=2, N=0)
        L_grid = LinearGrid(M=2, N=1)
        RZ_basis = FourierZernikeBasis(M=2, N=0)
        L_basis = DoubleFourierSeries(M=2, N=0)
        PI_basis = PowerSeries(L=3)
        RZ_transform = Transform(RZ_grid, RZ_basis)
        RZ1_transform = Transform(L_grid, RZ_basis)
        L_transform = Transform(L_grid, L_basis)
        PI_transform = Transform(RZ_grid, PI_basis)

        # every error mode is queried with the same set of transforms
        transforms = dict(
            R_transform=RZ_transform,
            Z_transform=RZ_transform,
            R1_transform=RZ1_transform,
            Z1_transform=RZ1_transform,
            L_transform=L_transform,
            P_transform=PI_transform,
            I_transform=PI_transform)

        for errr_mode, expected_cls in (('force', ForceErrorNodes),
                                        ('accel', AccelErrorSpectral)):
            with self.subTest(errr_mode=errr_mode):
                obj_fun = ObjectiveFunctionFactory.get_equil_obj_fun(
                    errr_mode, **transforms)
                self.assertIsInstance(obj_fun, expected_cls)

        # test unimplemented errr_mode; only the factory call belongs inside
        # the assertRaises context so setup errors are not mistaken for it
        errr_mode = 'not implemented'
        with self.assertRaises(ValueError):
            ObjectiveFunctionFactory.get_equil_obj_fun(errr_mode, **transforms)
Exemplo n.º 8
0
    def test_transform_order_error(self):
        """Tests error handling with transform method
        """
        transf = Transform(LinearGrid(L=11, endpoint=True), PowerSeries(L=2),
                           derivs=0)

        # a zeta derivative is requested although transf was built with derivs=0
        full_coeffs = np.array([1, 2, 3])
        with self.assertRaises(ValueError):
            transf.transform(full_coeffs, 0, 0, 1)

        # too few coefficients for the L=2 power series
        short_coeffs = np.array([1, 2])
        with self.assertRaises(ValueError):
            transf.transform(short_coeffs, 0, 0, 0)
Exemplo n.º 9
0
    def test_profile(self):
        """Tests transform of power series on a radial profile
        """
        grid = LinearGrid(L=11, endpoint=True)
        transf = Transform(grid, PowerSeries(L=2), derivs=1)

        rho = grid.nodes[:, 0]
        coeffs = np.array([-1, 2, 1])  # ascending powers of rho

        # np.polyval wants descending powers, hence the reversal
        descending = coeffs[::-1]
        expected_vals = np.polyval(descending, rho)
        expected_ders = np.polyval(np.polyder(descending), rho)

        np.testing.assert_allclose(transf.transform(coeffs, 0, 0, 0),
                                   expected_vals, atol=1e-8)
        np.testing.assert_allclose(transf.transform(coeffs, 1, 0, 0),
                                   expected_ders, atol=1e-8)
Exemplo n.º 10
0
    def test_double_fourier(self):
        """Tests DoubleFourierSeries evaluation
        """
        grid = LinearGrid(M=5, N=5)
        theta = grid.nodes[:, 1]
        zeta = grid.nodes[:, 2]

        # expected column order: zeta factor outer (sin, 1, cos),
        # theta factor inner (sin, 1, cos)
        theta_factors = [np.sin(theta), np.ones_like(theta), np.cos(theta)]
        zeta_factors = [np.sin(zeta), np.ones_like(zeta), np.cos(zeta)]
        expected = np.array([ft * fz for fz in zeta_factors
                             for ft in theta_factors]).T

        basis = DoubleFourierSeries(M=1, N=1)
        evaluated = basis.evaluate(grid.nodes, derivatives=np.array([0, 0, 0]))

        np.testing.assert_allclose(evaluated, expected, atol=1e-8)
Exemplo n.º 11
0
def plot_vmec_comparison(vmec_data, equil):
    """Plots comparison of VMEC and DESC solutions

    Draws constant-rho flux surfaces of the VMEC solution (blue, solid) and of
    the DESC solution (red, dashed) in one or six toroidal cross sections.

    Parameters
    ----------
    vmec_data : dict
        dictionary of VMEC solution quantities; 'rmnc', 'zmns', 'xm' and 'xn'
        are used. The first axis of 'rmnc'/'zmns' indexes VMEC surfaces.
    equil : Equilibrium
        DESC equilibrium; its cR, cZ, NFP, R_basis and Z_basis attributes
        are used.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        if vmec_data contains fewer than two surfaces

    """

    cR = equil.cR
    cZ = equil.cZ
    NFP = equil.NFP
    R_basis = equil.R_basis
    Z_basis = equil.Z_basis

    Nr = 8
    Nt = 360
    # no positive zeta mode numbers in R: a single cross section is enough
    if np.max(R_basis.modes[:, 2]) == 0:
        Nz = 1
        rows = 1
    else:
        Nz = 6
        rows = 2

    Nr_vmec = vmec_data['rmnc'].shape[0] - 1
    if Nr_vmec < 1:
        raise ValueError("vmec_data must contain at least two surfaces")
    # Never request more surfaces than VMEC provides: otherwise the stride
    # below is floor(<1) == 0 and the modulo produces NaN indices.
    Nr = min(Nr, Nr_vmec + 1)
    s_idx = Nr_vmec % np.floor(Nr_vmec / (Nr - 1))
    idxes = np.linspace(s_idx, Nr_vmec, Nr).astype(int)
    if s_idx != 0:
        # also include VMEC surface 0
        idxes = np.pad(idxes, (1, 0), mode='constant')
        Nr += 1
    # presumably VMEC surfaces are uniform in normalized flux s, rho = sqrt(s)
    rho = np.sqrt(idxes / Nr_vmec)
    grid = LinearGrid(L=Nr, M=Nt, N=Nz, NFP=NFP, rho=rho, endpoint=True)
    R_transf = Transform(grid, R_basis)
    Z_transf = Transform(grid, Z_basis)

    # axes: (rho, theta, zeta), matching LinearGrid(L=Nr, M=Nt, N=Nz)
    R_desc = R_transf.transform(cR).reshape((Nr, Nt, Nz), order='F')
    Z_desc = Z_transf.transform(cZ).reshape((Nr, Nt, Nz), order='F')

    R_vmec, Z_vmec = vmec_interpolate(
        vmec_data['rmnc'][idxes], vmec_data['zmns'][idxes], vmec_data['xm'], vmec_data['xn'],
        np.unique(grid.nodes[:, 1]), np.unique(grid.nodes[:, 2]))

    plt.figure()
    for k in range(Nz):
        ax = plt.subplot(rows, Nz // rows, k + 1)
        ax.plot(R_vmec[0, 0, k], Z_vmec[0, 0, k], 'bo')
        s_vmec = ax.plot(R_vmec[:, :, k].T, Z_vmec[:, :, k].T, 'b-')
        ax.plot(R_desc[0, 0, k], Z_desc[0, 0, k], 'ro')
        s_desc = ax.plot(R_desc[:, :, k].T, Z_desc[:, :, k].T, 'r--')
        ax.axis('equal')
        ax.set_xlabel('R')
        ax.set_ylabel('Z')
        if k == 0:
            s_vmec[0].set_label('VMEC')
            s_desc[0].set_label('DESC')
            ax.legend(fontsize='xx-small')
    plt.show()
Exemplo n.º 12
0
def plot_comparison(equil0, equil1, label0='x0', label1='x1', **kwargs):
    """Plots the flux surfaces of two equilibria on top of each other

    For each toroidal cross section, draws constant-rho surfaces (solid) and
    constant-theta lines (dotted) of ``equil0`` in blue and ``equil1`` in red.

    Parameters
    ----------
    equil0, equil1 : Equilibrium
        equilibria to compare; their cR, cZ, NFP, R_basis and Z_basis
        attributes are used. Both must have the same NFP.
    label0, label1 : str
        legend labels for each equilibrium
    **kwargs :
        ``Nr`` (int, default 8): number of constant-rho surfaces to draw;
        ``Nt`` (int, default 13): number of constant-theta lines to draw.
        Other keys are ignored.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        if the two equilibria have different NFP

    """
    if equil0.NFP != equil1.NFP:
        raise ValueError(
            TextColors.FAIL + "NFP must be the same for both solutions" + TextColors.ENDC)
    NFP = equil0.NFP

    # no positive zeta mode numbers in either R: one cross section is enough
    if max(np.max(equil0.R_basis.modes[:, 2]),
           np.max(equil1.R_basis.modes[:, 2])) == 0:
        Nz = 1
        rows = 1
    else:
        Nz = 6
        rows = 2

    Nr = kwargs.get('Nr', 8)
    Nt = kwargs.get('Nt', 13)

    NNr = 100  # radial samples along each constant-theta line
    NNt = 360  # poloidal samples along each constant-rho surface

    def _evaluate(grid, basis, coeffs, n_rho, n_theta):
        """Evaluate a spectral series on grid, reshaped to (rho, theta, zeta)."""
        return Transform(grid, basis).transform(coeffs).reshape(
            (n_rho, n_theta, Nz), order='F')

    # constant rho surfaces
    grid_r = LinearGrid(L=Nr, M=NNt, N=Nz, NFP=NFP, endpoint=True)
    R0r = _evaluate(grid_r, equil0.R_basis, equil0.cR, Nr, NNt)
    Z0r = _evaluate(grid_r, equil0.Z_basis, equil0.cZ, Nr, NNt)
    R1r = _evaluate(grid_r, equil1.R_basis, equil1.cR, Nr, NNt)
    Z1r = _evaluate(grid_r, equil1.Z_basis, equil1.cZ, Nr, NNt)

    # constant theta surfaces
    grid_t = LinearGrid(L=NNr, M=Nt, N=Nz, NFP=NFP, endpoint=True)
    R0v = _evaluate(grid_t, equil0.R_basis, equil0.cR, NNr, Nt)
    Z0v = _evaluate(grid_t, equil0.Z_basis, equil0.cZ, NNr, Nt)
    R1v = _evaluate(grid_t, equil1.R_basis, equil1.cR, NNr, Nt)
    Z1v = _evaluate(grid_t, equil1.Z_basis, equil1.cZ, NNr, Nt)

    plt.figure()
    for k in range(Nz):
        ax = plt.subplot(rows, Nz // rows, k + 1)

        ax.plot(R0r[0, 0, k], Z0r[0, 0, k], 'bo')
        s0 = ax.plot(R0r[:, :, k].T, Z0r[:, :, k].T, 'b-')
        ax.plot(R0v[:, :, k], Z0v[:, :, k], 'b:')

        ax.plot(R1r[0, 0, k], Z1r[0, 0, k], 'ro')
        s1 = ax.plot(R1r[:, :, k].T, Z1r[:, :, k].T, 'r-')
        ax.plot(R1v[:, :, k], Z1v[:, :, k], 'r:')

        ax.axis('equal')
        ax.set_xlabel('R')
        ax.set_ylabel('Z')
        if k == 0:
            s0[0].set_label(label0)
            s1[0].set_label(label1)
            ax.legend(fontsize='xx-small')
    plt.show()
Exemplo n.º 13
0
def vmec_error(equil, vmec_data, Nt=8, Nz=4):
    """Computes error in SFL coordinates compared to VMEC solution

    Parameters
    ----------
    equil : Equilibrium
        DESC equilibrium; its NFP, R_basis, Z_basis, cR and cZ attributes
        are used
    vmec_data : dict
        dictionary of VMEC equilibrium parameters; 'psi', 'rmnc', 'zmns',
        'xm', 'xn' and 'sym' are used, plus 'rmns' and 'zmnc' when 'sym'
        is False
    Nt : int
        number of poloidal angles to sample (Default value = 8)
    Nz : int
        number of toroidal angles to sample (Default value = 4)

    Returns
    -------
    err : float
        average Euclidean distance between VMEC and DESC sample points

    """
    ns = np.size(vmec_data['psi'])
    rho = np.sqrt(vmec_data['psi'])
    grid = LinearGrid(L=ns, M=Nt, N=Nz, NFP=equil.NFP, rho=rho)
    R_transf = Transform(grid, equil.R_basis)
    Z_transf = Transform(grid, equil.Z_basis)
    vartheta = np.unique(grid.nodes[:, 1])
    phi = np.unique(grid.nodes[:, 2])

    # axes: (surface, poloidal, toroidal), matching LinearGrid(L=ns, M=Nt, N=Nz)
    R_desc = R_transf.transform(equil.cR).reshape((ns, Nt, Nz), order='F')
    Z_desc = Z_transf.transform(equil.cZ).reshape((ns, Nt, Nz), order='F')

    print('Interpolating VMEC solution to sfl coordinates')
    R_vmec = np.zeros((ns, Nt, Nz))
    Z_vmec = np.zeros((ns, Nt, Nz))
    for k in range(Nz):  # toroidal angle
        for i in range(ns):  # flux surface
            theta = np.zeros((Nt, ))
            for j in range(Nt):  # poloidal angle
                f0 = sfl_err(np.array([0]), vartheta[j], phi[k], vmec_data, i)
                f2pi = sfl_err(np.array([2 * np.pi]), vartheta[j], phi[k],
                               vmec_data, i)
                # flag is nonzero when the residual has the same sign at 0
                # and 2*pi; the fsolve result is then shifted by pi (mod 2*pi)
                flag = (sign(f0) + sign(f2pi)) / 2
                args = (vartheta[j], phi[k], vmec_data, i, flag)
                t = fsolve(sfl_err, vartheta[j], args=args)
                if flag != 0:
                    t = np.remainder(t + np.pi, 2 * np.pi)
                # fsolve returns a length-1 array; store the scalar so NumPy
                # does not need a (deprecated) array-to-scalar conversion
                theta[j] = t[0]  # theta angle that corresponds to vartheta[j]
            R_vmec[i, :, k] = vmec_transf(vmec_data['rmnc'][i, :],
                                          vmec_data['xm'],
                                          vmec_data['xn'],
                                          theta,
                                          phi[k],
                                          trig='cos').flatten()
            Z_vmec[i, :, k] = vmec_transf(vmec_data['zmns'][i, :],
                                          vmec_data['xm'],
                                          vmec_data['xn'],
                                          theta,
                                          phi[k],
                                          trig='sin').flatten()
            if not vmec_data['sym']:
                R_vmec[i, :, k] += vmec_transf(vmec_data['rmns'][i, :],
                                               vmec_data['xm'],
                                               vmec_data['xn'],
                                               theta,
                                               phi[k],
                                               trig='sin').flatten()
                Z_vmec[i, :, k] += vmec_transf(vmec_data['zmnc'][i, :],
                                               vmec_data['xm'],
                                               vmec_data['xn'],
                                               theta,
                                               phi[k],
                                               trig='cos').flatten()
        print('{}%'.format((k + 1) / Nz * 100))

    return np.mean(np.sqrt((R_vmec - R_desc)**2 + (Z_vmec - Z_desc)**2))
Exemplo n.º 14
0
def solve_eq_continuation(inputs, checkpoint_filename=None, device=None):
    """Solves for an equilibrium by continuation method

    Follows this procedure to solve the equilibrium:
        1. Creates an initial guess from the given inputs
        2. Optimizes the equilibrium's flux surfaces by minimizing
            the given objective function.
        3. Step up to higher resolution and perturb the previous solution
        4. Repeat 2 and 3 until at desired resolution

    Parameters
    ----------
    inputs : dict
        dictionary with input parameters defining problem setup and solver
        options. The entries 'Mpol', 'Ntor', 'delta_lm', 'Mnodes', 'Nnodes',
        'bdry_ratio', 'pres_ratio', 'zeta_ratio', 'errr_ratio', 'pert_order',
        'ftol', 'xtol', 'gtol' and 'nfev' are indexed per continuation step;
        the number of steps is ``inputs['Mpol'].size``.
    checkpoint_filename : str or path-like
        file to save checkpoint data (Default value = None)
    device : jax.device or None
        device handle to JIT compile to (Default value = None). Only used
        when JAX is in use; if None, the first JAX device is taken.

    Returns
    -------
    equil_fam : EquilibriaFamily
        Container object that contains a list of the intermediate solutions,
            as well as the final solution, stored as Equilibrium objects
    timer : Timer
        Timer object containing timing data for individual iterations

    Raises
    ------
    NotImplementedError
        if inputs['optim_method'] is not one of 'bfgs', 'trf', 'lm', 'dogleg'

    """
    timer = Timer()
    timer.start("Total time")

    stell_sym = inputs['stell_sym']
    NFP = inputs['NFP']
    Psi_lcfs = inputs['Psi_lcfs']
    M = inputs['Mpol']  # arr
    N = inputs['Ntor']  # arr
    delta_lm = inputs['delta_lm']  # arr
    Mnodes = inputs['Mnodes']  # arr
    Nnodes = inputs['Nnodes']  # arr
    bdry_ratio = inputs['bdry_ratio']  # arr
    pres_ratio = inputs['pres_ratio']  # arr
    zeta_ratio = inputs['zeta_ratio']  # arr
    errr_ratio = inputs['errr_ratio']  # arr
    pert_order = inputs['pert_order']  # arr
    ftol = inputs['ftol']  # arr
    xtol = inputs['xtol']  # arr
    gtol = inputs['gtol']  # arr
    nfev = inputs['nfev']  # arr
    optim_method = inputs['optim_method']
    errr_mode = inputs['errr_mode']
    bdry_mode = inputs['bdry_mode']
    zern_mode = inputs['zern_mode']
    node_mode = inputs['node_mode']
    cP = inputs['cP']
    cI = inputs['cI']
    axis = inputs['axis']
    bdry = inputs['bdry']
    verbose = inputs['verbose']

    checkpoint = checkpoint_filename is not None
    if checkpoint:
        # NOTE(review): checkpoint_file is never used below (saving goes
        # through equil_fam.save()); kept in case Checkpoint() has side effects
        checkpoint_file = Checkpoint(checkpoint_filename, write_ascii=True)

    def _make_grids(step):
        """Build the ConcentricGrid and LinearGrid used at continuation `step`."""
        RZ_grid = ConcentricGrid(Mnodes[step],
                                 Nnodes[step],
                                 NFP=NFP,
                                 sym=stell_sym,
                                 axis=False,
                                 index=zern_mode,
                                 surfs=node_mode)
        L_grid = LinearGrid(M=Mnodes[step],
                            N=2 * Nnodes[step] + 1,
                            NFP=NFP,
                            sym=stell_sym)
        return RZ_grid, L_grid

    def _objective_args(step):
        """Extra objective arguments using the continuation ratios of `step`."""
        return (equil.cRb, equil.cZb, equil.cP, equil.cI, equil.Psi,
                bdry_ratio[step], pres_ratio[step], zeta_ratio[step],
                errr_ratio[step])

    arr_len = M.size
    for ii in range(arr_len):

        if verbose > 0:
            print("================")
            print("Step {}/{}".format(ii + 1, arr_len))
            print("================")
            print("Spectral resolution (M,N,delta_lm)=({},{},{})".format(
                M[ii], N[ii], delta_lm[ii]))
            print("Node resolution (M,N)=({},{})".format(
                Mnodes[ii], Nnodes[ii]))
            print("Boundary ratio = {}".format(bdry_ratio[ii]))
            print("Pressure ratio = {}".format(pres_ratio[ii]))
            print("Zeta ratio = {}".format(zeta_ratio[ii]))
            print("Error ratio = {}".format(errr_ratio[ii]))
            print("Perturbation Order = {}".format(pert_order[ii]))
            print("Function tolerance = {}".format(ftol[ii]))
            print("Gradient tolerance = {}".format(gtol[ii]))
            print("State vector tolerance = {}".format(xtol[ii]))
            print("Max function evaluations = {}".format(nfev[ii]))
            print("================")

        # initial solution
        # at initial soln, must: create bases, create grids, create transforms
        if ii == 0:
            # NOTE(review): this timer is started but never stopped
            timer.start("Iteration {} total".format(ii + 1))

            inputs_ii = {
                'L': delta_lm[ii],
                'M': M[ii],
                'N': N[ii],
                'cP': cP * pres_ratio[ii],
                'cI': cI,
                'Psi': Psi_lcfs,
                'NFP': NFP,
                'bdry': bdry,
                'sym': stell_sym,
                'index': zern_mode,
                'bdry_mode': bdry_mode,
                'bdry_ratio': bdry_ratio[ii],
                'axis': axis,
                'output_path': checkpoint_filename
            }
            timer.start("Transform precomputation")
            if verbose > 0:
                print("Precomputing Transforms")
            equil_fam = EquilibriaFamily(inputs=inputs_ii)
            # Get initial Equilibrium from equil_fam
            equil = equil_fam[ii]

            x = equil.x  # initial state vector
            # bases (extracted from Equilibrium)
            R_basis = equil.R_basis
            Z_basis = equil.Z_basis
            L_basis = equil.L_basis
            P_basis = equil.P_basis
            I_basis = equil.I_basis

            # grids
            RZ_grid, L_grid = _make_grids(ii)

            # transforms
            R_transform = Transform(RZ_grid, R_basis, derivs=3)
            Z_transform = Transform(RZ_grid, Z_basis, derivs=3)
            R1_transform = Transform(L_grid, R_basis)
            Z1_transform = Transform(L_grid, Z_basis)
            L_transform = Transform(L_grid, L_basis, derivs=0)
            P_transform = Transform(RZ_grid, P_basis, derivs=1)
            I_transform = Transform(RZ_grid, I_basis, derivs=1)
            # these transform objects are updated in place (change_resolution)
            # at later steps, so this mapping stays valid for every step
            transforms = dict(R_transform=R_transform,
                              Z_transform=Z_transform,
                              R1_transform=R1_transform,
                              Z1_transform=Z1_transform,
                              L_transform=L_transform,
                              P_transform=P_transform,
                              I_transform=I_transform)

            timer.stop("Transform precomputation")
            if verbose > 1:
                timer.disp("Transform precomputation")

        # continuing from previous solution
        else:
            # change grids
            if Mnodes[ii] != Mnodes[ii - 1] or Nnodes[ii] != Nnodes[ii - 1]:
                RZ_grid, L_grid = _make_grids(ii)

            # change bases
            if (M[ii] != M[ii - 1] or N[ii] != N[ii - 1]
                    or delta_lm[ii] != delta_lm[ii - 1]):
                # update equilibrium bases to the new resolutions
                equil.change_resolution(L=delta_lm[ii], M=M[ii], N=N[ii])
                R_basis = equil.R_basis
                Z_basis = equil.Z_basis
                L_basis = equil.L_basis
                x = equil.x

            # change transform matrices
            timer.start("Iteration {} changing resolution".format(ii + 1))
            if verbose > 0:
                print(
                    "Changing node resolution from (Mnodes,Nnodes) = ({},{}) to ({},{})"
                    .format(Mnodes[ii - 1], Nnodes[ii - 1], Mnodes[ii],
                            Nnodes[ii]))
                print(
                    "Changing spectral resolution from (L,M,N) = ({},{},{}) to ({},{},{})"
                    .format(delta_lm[ii - 1], M[ii - 1], N[ii - 1],
                            delta_lm[ii], M[ii], N[ii]))

            R_transform.change_resolution(grid=RZ_grid, basis=R_basis)
            Z_transform.change_resolution(grid=RZ_grid, basis=Z_basis)
            R1_transform.change_resolution(grid=L_grid, basis=R_basis)
            Z1_transform.change_resolution(grid=L_grid, basis=Z_basis)
            L_transform.change_resolution(grid=L_grid, basis=L_basis)
            P_transform.change_resolution(grid=RZ_grid)
            I_transform.change_resolution(grid=RZ_grid)
            timer.stop("Iteration {} changing resolution".format(ii + 1))
            if verbose > 1:
                timer.disp("Iteration {} changing resolution".format(ii + 1))

            # continuation parameters
            delta_bdry = bdry_ratio[ii] - bdry_ratio[ii - 1]
            delta_pres = pres_ratio[ii] - pres_ratio[ii - 1]
            delta_zeta = zeta_ratio[ii] - zeta_ratio[ii - 1]
            deltas = np.array([delta_bdry, delta_pres, delta_zeta])

            # need a non-scalar objective function to do the perturbations
            obj_fun = ObjectiveFunctionFactory.get_equil_obj_fun(
                errr_mode, scalar=False, **transforms)
            equil_obj = obj_fun.compute
            # perturb starting from the previous step's continuation parameters
            args = _objective_args(ii - 1)

            # TODO: should probably perturb before expanding resolution
            # perturbations
            if np.any(deltas):
                if verbose > 1:
                    print("Perturbing equilibrium")
                x, timer = perturb_continuation_params(x, equil_obj, deltas,
                                                       args, pert_order[ii],
                                                       verbose, timer)

        # equilibrium objective function, evaluated at THIS step's
        # continuation parameters (the perturbation above already moved x
        # there; using ii - 1 would undo it, and at ii == 0 would wrap to the
        # last step's ratios)
        scalar = optim_method in ['bfgs']
        obj_fun = ObjectiveFunctionFactory.get_equil_obj_fun(
            errr_mode, scalar=scalar, **transforms)
        equil_obj = obj_fun.compute
        callback = obj_fun.callback
        args = _objective_args(ii)

        if use_jax:
            if optim_method in ['bfgs']:
                jac = AutoDiffJacobian(equil_obj, argnum=0, mode='grad')
            else:
                jac = AutoDiffJacobian(equil_obj, argnum=0, mode='fwd')
            if verbose > 0:
                print("Compiling objective function")
            if device is None:
                import jax
                device = jax.devices()[0]
            equil_obj_jit = jit(equil_obj, static_argnums=(), device=device)
            jac_obj_jit = jit(jac.compute, device=device)
            timer.start("Iteration {} compilation".format(ii + 1))
            # evaluated once only to trigger JIT compilation
            f0 = equil_obj_jit(x, *args)
            J0 = jac_obj_jit(x, *args)
            timer.stop("Iteration {} compilation".format(ii + 1))
            if verbose > 1:
                timer.disp("Iteration {} compilation".format(ii + 1))
        else:
            equil_obj_jit = equil_obj
            jac_obj_jit = '2-point'
        if verbose > 0:
            print("Starting optimization")

        x_init = x
        timer.start("Iteration {} solution".format(ii + 1))
        if optim_method in ['bfgs']:
            out = scipy.optimize.minimize(equil_obj_jit,
                                          x0=x_init,
                                          args=args,
                                          method=optim_method,
                                          jac=jac_obj_jit,
                                          tol=gtol[ii],
                                          options={
                                              'maxiter': nfev[ii],
                                              'disp': verbose
                                          })

        elif optim_method in ['trf', 'lm', 'dogleg']:
            out = scipy.optimize.least_squares(equil_obj_jit,
                                               x0=x_init,
                                               args=args,
                                               jac=jac_obj_jit,
                                               method=optim_method,
                                               x_scale='jac',
                                               ftol=ftol[ii],
                                               xtol=xtol[ii],
                                               gtol=gtol[ii],
                                               max_nfev=nfev[ii],
                                               verbose=verbose)
        else:
            raise NotImplementedError(
                TextColors.FAIL +
                "optim_method must be one of 'bfgs', 'trf', 'lm', 'dogleg'" +
                TextColors.ENDC)

        timer.stop("Iteration {} solution".format(ii + 1))

        # carry the solution forward: it is the starting point of the next
        # step (previously x kept the pre-optimization guess)
        x = out['x']
        equil.x = x

        equil_fam.append(copy.deepcopy(equil))

        if verbose > 1:
            timer.disp("Iteration {} solution".format(ii + 1))
            timer.pretty_print(
                "Iteration {} avg time per step".format(ii + 1),
                timer["Iteration {} solution".format(ii + 1)] / out['nfev'])
        if verbose > 0:
            print("Start of Step {}:".format(ii + 1))
            callback(x_init, *args)
            print("End of Step {}:".format(ii + 1))
            callback(x, *args)

        if checkpoint:
            if verbose > 0:
                print('Saving latest iteration')
            equil_fam.save()

        if not is_nested(equil.cR, equil.cZ, equil.R_basis, equil.Z_basis):
            warnings.warn(
                TextColors.WARNING +
                'WARNING: Flux surfaces are no longer nested, exiting early.' +
                'Consider increasing errr_ratio or taking smaller perturbation steps'
                + TextColors.ENDC)
            break

    timer.stop("Total time")
    print('====================')
    print('Done')
    if verbose > 1:
        timer.disp("Total time")
    if checkpoint_filename is not None:
        print('Output written to {}'.format(checkpoint_filename))
    print('====================')

    return equil_fam, timer