def get_grid(self, NFP, **kwargs):
    """Build a LinearGrid for plotting and report which coordinate axes vary.

    Parameters
    ----------
    NFP : int
        passed through to LinearGrid as ``NFP`` (presumably the number of
        field periods)
    **kwargs
        overrides for the grid arguments. Only the keys 'rho', 'L', 'theta',
        'M', 'zeta', 'N' and 'endpoint' are used; any other key is silently
        ignored. Defaults: rho=1.0, L=100, theta=0.0, M=1, zeta=0.0, N=1,
        endpoint=False.

    Returns
    -------
    grid : LinearGrid
        grid built from the merged arguments; 'rho', 'theta' and 'zeta' are
        first passed through ``self.__format_rtz__`` (defined elsewhere)
    plot_axes : tuple of int
        indices of the axes to plot (0 = rho, 1 = theta, 2 = zeta): an axis
        is included unless its resolution (L, M or N respectively) is 1
    """
    grid_args = {
        'rho': 1.0, 'L': 100,
        'theta': 0.0, 'M': 1,
        'zeta': 0.0, 'N': 1,
        'endpoint': False,
        'NFP': NFP,
    }
    # only recognized keys override the defaults; anything else is dropped
    grid_args.update({key: val for key, val in kwargs.items() if key in grid_args})

    varying = []
    coordinate_pairs = (('rho', 'L'), ('theta', 'M'), ('zeta', 'N'))
    for axis, (coord, resolution) in enumerate(coordinate_pairs):
        grid_args[coord] = self.__format_rtz__(grid_args[coord])
        # an axis sampled at a single node is held fixed and not plotted
        if grid_args[resolution] == 1:
            continue
        varying.append(axis)

    return LinearGrid(**grid_args), tuple(varying)
def test_linear_grid(self):
    """Check node placement and volume weights of a 3x3x3 LinearGrid."""
    L, M, N = 3, 3, 3
    NFP = 1
    grid = LinearGrid(L, M, N, NFP, sym=False, endpoint=False)

    rho = np.array([0, 0.5, 1])
    angles = np.array([0, 2*np.pi/3, 4*np.pi/3])
    # rho varies fastest, then theta, then zeta
    expected = np.column_stack([
        np.tile(rho, M * N),
        np.tile(np.repeat(angles, L), N),
        np.repeat(angles, L * M),
    ])
    np.testing.assert_allclose(grid.nodes, expected, atol=1e-8)

    # product of the three per-node volume factors, summed over all nodes
    total_volume = np.sum(np.prod(grid.volumes[:, :3], axis=1))
    self.assertAlmostEqual(total_volume, (2*np.pi)**2/NFP)
def test_set_grid(self):
    """Assigning a new grid to a Transform must match a Transform built on that grid."""
    basis = FourierZernikeBasis(M=1, N=1)
    grids = {L: LinearGrid(L=L, M=1, N=1) for L in (1, 3, 5)}
    transforms = {L: Transform(grid, basis) for L, grid in grids.items()}

    # start from the L=3 transform and swap its grid to L=5, then to L=1
    transf = transforms[3]
    for L in (5, 1):
        transf.grid = grids[L]
        self.assertTrue(transf == transforms[L])
def test_eq(self):
    """Transform equality depends on both the grid and the basis."""
    linear_grid = LinearGrid(L=11, endpoint=True)
    surface_grid = LinearGrid(M=5, N=5)
    concentric_grid = ConcentricGrid(M=2, N=2)
    fourier = DoubleFourierSeries(M=1, N=1)
    zernike = FourierZernikeBasis(M=1, N=1)

    transf_11 = Transform(linear_grid, fourier)
    transf_21 = Transform(surface_grid, fourier)
    transf_31 = Transform(concentric_grid, fourier)
    transf_32 = Transform(concentric_grid, zernike)
    transf_32b = Transform(concentric_grid, zernike)

    # same basis, different grids
    self.assertFalse(transf_11 == transf_21)
    # same grid, different bases
    self.assertTrue(transf_31 != transf_32)
    # built independently from identical inputs
    self.assertTrue(transf_32 == transf_32b)
def test_surface(self):
    """Transform of a double Fourier series (and its derivatives) on a flux surface."""
    grid = LinearGrid(M=5, N=5)
    basis = DoubleFourierSeries(M=1, N=1)
    transf = Transform(grid, basis, derivs=1)

    theta = grid.nodes[:, 1]
    zeta = grid.nodes[:, 2]
    phase = theta - zeta
    # expected values keyed by derivative orders (rho, theta, zeta)
    expected = {
        (0, 0, 0): np.sin(phase) + 2 * np.cos(phase),
        (0, 1, 0): np.cos(phase) - 2 * np.sin(phase),
        (0, 0, 1): -np.cos(phase) + 2 * np.sin(phase),
        (0, 1, 1): np.sin(phase) + 2 * np.cos(phase),
    }

    def mode_index(m, n):
        """Indices of basis modes with poloidal number m and toroidal number n."""
        mask = (basis.modes[:, 1] == m) & (basis.modes[:, 2] == n)
        return np.where(mask)[0]

    # coefficients chosen so the series equals sin(t - z) + 2 cos(t - z)
    coeffs = np.zeros((basis.modes.shape[0], ))
    coeffs[mode_index(-1, 1)] = 1
    coeffs[mode_index(1, -1)] = -1
    coeffs[mode_index(-1, -1)] = 2
    coeffs[mode_index(1, 1)] = 2

    for (dr, dt, dz), correct in expected.items():
        np.testing.assert_allclose(
            transf.transform(coeffs, dr, dt, dz), correct, atol=1e-8)
def test_power_series(self):
    """PowerSeries(L=2) evaluates to 1, r, r^2 and their first radial derivatives."""
    grid = LinearGrid(L=11, endpoint=True)
    r = grid.nodes[:, 0]
    basis = PowerSeries(L=2)

    cases = [
        (np.array([0, 0, 0]),
         np.column_stack([np.ones_like(r), r, r**2])),
        (np.array([1, 0, 0]),
         np.column_stack([np.zeros_like(r), np.ones_like(r), 2 * r])),
    ]
    for derivatives, correct in cases:
        evaluated = basis.evaluate(grid.nodes, derivatives=derivatives)
        np.testing.assert_allclose(evaluated, correct, atol=1e-8)
def test_obj_fxn_types(self):
    """test the correct objective function is returned for 'force', 'accel', and unimplemented"""
    RZ_grid = ConcentricGrid(M=2, N=0)
    L_grid = LinearGrid(M=2, N=1)
    RZ_basis = FourierZernikeBasis(M=2, N=0)
    L_basis = DoubleFourierSeries(M=2, N=0)
    PI_basis = PowerSeries(L=3)

    RZ_transform = Transform(RZ_grid, RZ_basis)
    RZ1_transform = Transform(L_grid, RZ_basis)
    L_transform = Transform(L_grid, L_basis)
    PI_transform = Transform(RZ_grid, PI_basis)
    transforms = dict(
        R_transform=RZ_transform, Z_transform=RZ_transform,
        R1_transform=RZ1_transform, Z1_transform=RZ1_transform,
        L_transform=L_transform,
        P_transform=PI_transform, I_transform=PI_transform,
    )

    for errr_mode, expected_cls in (('force', ForceErrorNodes),
                                    ('accel', AccelErrorSpectral)):
        obj_fun = ObjectiveFunctionFactory.get_equil_obj_fun(errr_mode, **transforms)
        self.assertIsInstance(obj_fun, expected_cls)

    # an unknown errr_mode must be rejected
    with self.assertRaises(ValueError):
        ObjectiveFunctionFactory.get_equil_obj_fun('not implemented', **transforms)
def test_transform_order_error(self):
    """Transform.transform rejects bad derivative orders and bad coefficient counts."""
    grid = LinearGrid(L=11, endpoint=True)
    basis = PowerSeries(L=2)
    transf = Transform(grid, basis, derivs=0)

    bad_calls = [
        # derivative order not covered by a Transform built with derivs=0
        (np.array([1, 2, 3]), (0, 0, 1)),
        # two coefficients for the three-mode PowerSeries(L=2) basis
        (np.array([1, 2]), (0, 0, 0)),
    ]
    for coeffs, orders in bad_calls:
        with self.assertRaises(ValueError):
            transf.transform(coeffs, *orders)
def test_profile(self):
    """Transform of a quadratic power-series radial profile and its derivative."""
    grid = LinearGrid(L=11, endpoint=True)
    basis = PowerSeries(L=2)
    transf = Transform(grid, basis, derivs=1)
    rho = grid.nodes[:, 0]

    coeffs = np.array([-1, 2, 1])
    c0, c1, c2 = coeffs
    expected_values = c0 + c1 * rho + c2 * rho**2
    expected_slope = c1 + c2 * 2 * rho

    np.testing.assert_allclose(
        transf.transform(coeffs, 0, 0, 0), expected_values, atol=1e-8)
    np.testing.assert_allclose(
        transf.transform(coeffs, 1, 0, 0), expected_slope, atol=1e-8)
def test_double_fourier(self):
    """DoubleFourierSeries(M=1, N=1) evaluates to products of {sin, 1, cos} in theta and zeta."""
    grid = LinearGrid(M=5, N=5)
    t = grid.nodes[:, 1]
    z = grid.nodes[:, 2]
    ones = np.ones_like(t)
    theta_factors = [np.sin(t), ones, np.cos(t)]
    zeta_factors = [np.sin(z), ones, np.cos(z)]
    # the zeta factor is the slow index and the theta factor the fast one
    correct_vals = np.column_stack(
        [tf * zf for zf in zeta_factors for tf in theta_factors])

    basis = DoubleFourierSeries(M=1, N=1)
    values = basis.evaluate(grid.nodes, derivatives=np.array([0, 0, 0]))
    np.testing.assert_allclose(values, correct_vals, atol=1e-8)
def plot_vmec_comparison(vmec_data, equil):
    """Plots comparison of VMEC and DESC solutions

    Draws VMEC flux surfaces (blue solid) and DESC flux surfaces (red dashed)
    in the R-Z plane, one subplot per toroidal cross-section.

    Parameters
    ----------
    vmec_data : dict
        dictionary of VMEC solution quantities; the keys 'rmnc', 'zmns',
        'xm' and 'xn' are read. The first axis of 'rmnc'/'zmns' indexes
        flux surfaces.
    equil : Equilibrium-like object
        DESC equilibrium; its attributes cR, cZ, NFP, R_basis and Z_basis
        are read.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        if vmec_data contains fewer than two flux surfaces.

    """
    cR = equil.cR
    cZ = equil.cZ
    NFP = equil.NFP
    R_basis = equil.R_basis
    Z_basis = equil.Z_basis

    Nr = 8    # number of flux surfaces to draw
    Nt = 360  # poloidal samples per surface
    # no toroidal modes in the R basis -> one cross-section, else 6 on 2 rows
    if np.max(R_basis.modes[:, 2]) == 0:
        Nz = 1
        rows = 1
    else:
        Nz = 6
        rows = 2

    Nr_vmec = vmec_data['rmnc'].shape[0] - 1
    if Nr_vmec < 1:
        raise ValueError("VMEC data must contain at least two flux surfaces")
    # cannot draw more surfaces than VMEC provides; without this cap the
    # stride below is 0 for small ns and the modulo yields NaN indices
    Nr = min(Nr, Nr_vmec + 1)
    # pick Nr evenly strided VMEC surface indices ending at the outermost one
    s_idx = Nr_vmec % np.floor(Nr_vmec/(Nr-1))
    idxes = np.linspace(s_idx, Nr_vmec, Nr).astype(int)
    if s_idx != 0:
        # the stride does not start at 0: prepend index 0 so the rho = 0
        # surface is always drawn
        idxes = np.pad(idxes, (1, 0), mode='constant')
        Nr += 1
    # assumes VMEC surfaces are evenly spaced in normalized flux s so that
    # rho = sqrt(s) -- TODO confirm
    rho = np.sqrt(idxes/Nr_vmec)

    grid = LinearGrid(L=Nr, M=Nt, N=Nz, NFP=NFP, rho=rho, endpoint=True)
    R_transf = Transform(grid, R_basis)
    Z_transf = Transform(grid, Z_basis)

    R_desc = R_transf.transform(cR).reshape((Nr, Nt, Nz), order='F')
    Z_desc = Z_transf.transform(cZ).reshape((Nr, Nt, Nz), order='F')

    R_vmec, Z_vmec = vmec_interpolate(
        vmec_data['rmnc'][idxes], vmec_data['zmns'][idxes], vmec_data['xm'],
        vmec_data['xn'], np.unique(grid.nodes[:, 1]), np.unique(grid.nodes[:, 2]))

    plt.figure()
    for k in range(Nz):
        ax = plt.subplot(rows, int(Nz/rows), k+1)
        # innermost (first) surface point is marked with a dot
        ax.plot(R_vmec[0, 0, k], Z_vmec[0, 0, k], 'bo')
        s_vmec = ax.plot(R_vmec[:, :, k].T, Z_vmec[:, :, k].T, 'b-')
        ax.plot(R_desc[0, 0, k], Z_desc[0, 0, k], 'ro')
        s_desc = ax.plot(R_desc[:, :, k].T, Z_desc[:, :, k].T, 'r--')
        ax.axis('equal')
        ax.set_xlabel('R')
        ax.set_ylabel('Z')
        if k == 0:
            s_vmec[0].set_label('VMEC')
            s_desc[0].set_label('DESC')
            ax.legend(fontsize='xx-small')
    plt.show()
def plot_comparison(equil0, equil1, label0='x0', label1='x1', **kwargs):
    """Plots the flux surfaces of two equilibria on top of each other

    For each toroidal cross-section, constant-rho surfaces are drawn as solid
    lines and constant-theta lines as dotted lines; equil0 is blue and
    equil1 is red.

    Parameters
    ----------
    equil0, equil1 : Equilibrium-like objects
        equilibria to compare; attributes cR, cZ, NFP, R_basis and Z_basis
        are read from each
    label0, label1 : str
        legend labels for equil0 and equil1
    **kwargs :
        'Nr' (default 8): number of constant-rho surfaces drawn;
        'Nt' (default 13): number of constant-theta lines drawn.
        Other keys are ignored.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        if the two equilibria have different NFP

    """
    cR0, cZ0, NFP0 = equil0.cR, equil0.cZ, equil0.NFP
    R_basis0, Z_basis0 = equil0.R_basis, equil0.Z_basis
    cR1, cZ1, NFP1 = equil1.cR, equil1.cZ, equil1.NFP
    R_basis1, Z_basis1 = equil1.R_basis, equil1.Z_basis

    if NFP0 != NFP1:
        raise ValueError(
            TextColors.FAIL + "NFP must be the same for both solutions" + TextColors.ENDC)
    NFP = NFP0

    # a single cross-section when neither R basis has toroidal modes
    max_n = max(np.max(R_basis0.modes[:, 2]), np.max(R_basis1.modes[:, 2]))
    Nz, rows = (1, 1) if max_n == 0 else (6, 2)

    Nr = kwargs.get('Nr', 8)    # number of constant-rho surfaces
    Nt = kwargs.get('Nt', 13)   # number of constant-theta lines
    NNr = 100                   # radial samples along each theta line
    NNt = 360                   # poloidal samples along each rho surface

    solutions = ((cR0, cZ0, R_basis0, Z_basis0),
                 (cR1, cZ1, R_basis1, Z_basis1))

    def evaluate_on(grid, shape):
        """Return [(R, Z), (R, Z)] for both solutions on grid, reshaped to shape."""
        coords = []
        for cR, cZ, R_basis, Z_basis in solutions:
            R = Transform(grid, R_basis).transform(cR).reshape(shape, order='F')
            Z = Transform(grid, Z_basis).transform(cZ).reshape(shape, order='F')
            coords.append((R, Z))
        return coords

    grid_r = LinearGrid(L=Nr, M=NNt, N=Nz, NFP=NFP, endpoint=True)
    rho_surfaces = evaluate_on(grid_r, (Nr, NNt, Nz))
    grid_t = LinearGrid(L=NNr, M=Nt, N=Nz, NFP=NFP, endpoint=True)
    theta_lines = evaluate_on(grid_t, (NNr, Nt, Nz))

    # (first-point marker, surface line, theta line, legend label)
    styles = (('bo', 'b-', 'b:', label0), ('ro', 'r-', 'r:', label1))

    plt.figure()
    for k in range(Nz):
        ax = plt.subplot(rows, int(Nz/rows), k+1)
        surface_handles = []
        for (Rr, Zr), (Rv, Zv), style in zip(rho_surfaces, theta_lines, styles):
            marker_fmt, surface_fmt, line_fmt, _ = style
            ax.plot(Rr[0, 0, k], Zr[0, 0, k], marker_fmt)
            surface_handles.append(ax.plot(Rr[:, :, k].T, Zr[:, :, k].T, surface_fmt))
            ax.plot(Rv[:, :, k], Zv[:, :, k], line_fmt)
        ax.axis('equal')
        ax.set_xlabel('R')
        ax.set_ylabel('Z')
        if k == 0:
            for handles, style in zip(surface_handles, styles):
                handles[0].set_label(style[3])
            ax.legend(fontsize='xx-small')
    plt.show()
def vmec_error(equil, vmec_data, Nt=8, Nz=4):
    """Computes error in SFL coordinates compared to VMEC solution

    Parameters
    ----------
    equil : Equilibrium-like object
        DESC equilibrium; attributes NFP, R_basis, Z_basis, cR and cZ are read
    vmec_data : dict
        dictionary of VMEC equilibrium parameters; keys 'psi', 'rmnc',
        'zmns', 'xm', 'xn' and 'sym' are read, plus 'rmns' and 'zmnc' when
        'sym' is false
    Nt : int
        number of poloidal angles to sample (Default value = 8)
    Nz : int
        number of toroidal angles to sample (Default value = 4)

    Returns
    -------
    err : float
        average Euclidean distance in the (R, Z) plane between VMEC and DESC
        sample points

    """
    ns = np.size(vmec_data['psi'])
    # assumes 'psi' is normalized toroidal flux in [0, 1] -- TODO confirm
    rho = np.sqrt(vmec_data['psi'])
    grid = LinearGrid(L=ns, M=Nt, N=Nz, NFP=equil.NFP, rho=rho)
    R_basis = equil.R_basis
    Z_basis = equil.Z_basis

    R_transf = Transform(grid, R_basis)
    Z_transf = Transform(grid, Z_basis)

    vartheta = np.unique(grid.nodes[:, 1])
    phi = np.unique(grid.nodes[:, 2])

    R_desc = R_transf.transform(equil.cR).reshape((ns, Nt, Nz), order='F')
    Z_desc = Z_transf.transform(equil.cZ).reshape((ns, Nt, Nz), order='F')

    print('Interpolating VMEC solution to sfl coordinates')
    R_vmec = np.zeros((ns, Nt, Nz))
    Z_vmec = np.zeros((ns, Nt, Nz))
    xm = vmec_data['xm']
    xn = vmec_data['xn']
    for k in range(Nz):  # toroidal angle
        for i in range(ns):  # flux surface
            theta = np.zeros((Nt, ))
            for j in range(Nt):  # poloidal angle
                f0 = sfl_err(np.array([0]), vartheta[j], phi[k], vmec_data, i)
                f2pi = sfl_err(np.array([2 * np.pi]), vartheta[j], phi[k], vmec_data, i)
                # flag is nonzero when sfl_err has the same sign at 0 and 2*pi;
                # it is forwarded to sfl_err and the root is then shifted by pi
                # (NOTE(review): intent depends on sfl_err, not shown -- confirm)
                flag = (sign(f0) + sign(f2pi)) / 2
                args = (vartheta[j], phi[k], vmec_data, i, flag)
                t = fsolve(sfl_err, vartheta[j], args=args)
                if flag != 0:
                    t = np.remainder(t + np.pi, 2 * np.pi)
                # fsolve returns a length-1 array; store its scalar explicitly
                # (implicit array-to-scalar conversion is deprecated in NumPy)
                theta[j] = t[0]  # theta angle that corresponds to vartheta[j]
            R_vmec[i, :, k] = vmec_transf(
                vmec_data['rmnc'][i, :], xm, xn, theta, phi[k], trig='cos').flatten()
            Z_vmec[i, :, k] = vmec_transf(
                vmec_data['zmns'][i, :], xm, xn, theta, phi[k], trig='sin').flatten()
            if not vmec_data['sym']:
                R_vmec[i, :, k] += vmec_transf(
                    vmec_data['rmns'][i, :], xm, xn, theta, phi[k], trig='sin').flatten()
                Z_vmec[i, :, k] += vmec_transf(
                    vmec_data['zmnc'][i, :], xm, xn, theta, phi[k], trig='cos').flatten()
        print('{}%'.format((k + 1) / Nz * 100))

    return np.mean(np.sqrt((R_vmec - R_desc)**2 + (Z_vmec - Z_desc)**2))
def solve_eq_continuation(inputs, checkpoint_filename=None, device=None):
    """Solves for an equilibrium by continuation method

    Follows this procedure to solve the equilibrium:
        1. Creates an initial guess from the given inputs
        2. Optimizes the equilibrium's flux surfaces by minimizing
            the given objective function.
        3. Step up to higher resolution and perturb the previous solution
        4. Repeat 2 and 3 until at desired resolution

    Parameters
    ----------
    inputs : dict
        dictionary with input parameters defining problem setup and solver
        options. Entries marked "# arr" below are indexed by continuation
        step; the number of steps is the size of inputs['Mpol'].
    checkpoint_filename : str or path-like
        file to save checkpoint data (Default value = None)
    device : jax.device or None
        device handle to JIT compile to (Default value = None)

    Returns
    -------
    equil_fam : EquilibriaFamily
        Container object that contains a list of the intermediate solutions,
            as well as the final solution, stored as Equilibrium objects
    timer : Timer
        Timer object containing timing data for individual iterations

    Raises
    ------
    NotImplementedError
        if inputs['optim_method'] is not one of 'bfgs', 'trf', 'lm', 'dogleg'

    """
    timer = Timer()
    timer.start("Total time")

    stell_sym = inputs['stell_sym']
    NFP = inputs['NFP']
    Psi_lcfs = inputs['Psi_lcfs']
    M = inputs['Mpol']                 # arr
    N = inputs['Ntor']                 # arr
    delta_lm = inputs['delta_lm']      # arr
    Mnodes = inputs['Mnodes']          # arr
    Nnodes = inputs['Nnodes']          # arr
    bdry_ratio = inputs['bdry_ratio']  # arr
    pres_ratio = inputs['pres_ratio']  # arr
    zeta_ratio = inputs['zeta_ratio']  # arr
    errr_ratio = inputs['errr_ratio']  # arr
    pert_order = inputs['pert_order']  # arr
    ftol = inputs['ftol']              # arr
    xtol = inputs['xtol']              # arr
    gtol = inputs['gtol']              # arr
    nfev = inputs['nfev']              # arr
    optim_method = inputs['optim_method']
    errr_mode = inputs['errr_mode']
    bdry_mode = inputs['bdry_mode']
    zern_mode = inputs['zern_mode']
    node_mode = inputs['node_mode']
    cP = inputs['cP']
    cI = inputs['cI']
    axis = inputs['axis']
    bdry = inputs['bdry']
    verbose = inputs['verbose']

    checkpoint = checkpoint_filename is not None
    if checkpoint:
        # NOTE(review): checkpoint_file is not referenced again here (saving
        # goes through equil_fam.save()); kept because constructing the
        # Checkpoint may have side effects such as creating the file.
        checkpoint_file = Checkpoint(checkpoint_filename, write_ascii=True)

    arr_len = M.size
    for ii in range(arr_len):

        if verbose > 0:
            print("================")
            print("Step {}/{}".format(ii + 1, arr_len))
            print("================")
            print("Spectral resolution (M,N,delta_lm)=({},{},{})".format(
                M[ii], N[ii], delta_lm[ii]))
            print("Node resolution (M,N)=({},{})".format(
                Mnodes[ii], Nnodes[ii]))
            print("Boundary ratio = {}".format(bdry_ratio[ii]))
            print("Pressure ratio = {}".format(pres_ratio[ii]))
            print("Zeta ratio = {}".format(zeta_ratio[ii]))
            print("Error ratio = {}".format(errr_ratio[ii]))
            print("Perturbation Order = {}".format(pert_order[ii]))
            print("Function tolerance = {}".format(ftol[ii]))
            print("Gradient tolerance = {}".format(gtol[ii]))
            print("State vector tolerance = {}".format(xtol[ii]))
            print("Max function evaluations = {}".format(nfev[ii]))
            print("================")

        # initial solution
        # at initial soln, must: create bases, create grids, create transforms
        if ii == 0:
            timer.start("Iteration {} total".format(ii + 1))

            inputs_ii = {
                'L': delta_lm[ii],
                'M': M[ii],
                'N': N[ii],
                # NOTE(review): cP is pre-scaled by pres_ratio[0] here and
                # pres_ratio is also passed to the objective in args below; if
                # the objective applies pres_ratio to cP, pressure is scaled
                # twice (and stays zero when pres_ratio[0] == 0) -- confirm
                'cP': cP * pres_ratio[ii],
                'cI': cI,
                'Psi': Psi_lcfs,
                'NFP': NFP,
                'bdry': bdry,
                'sym': stell_sym,
                'index': zern_mode,
                'bdry_mode': bdry_mode,
                'bdry_ratio': bdry_ratio[ii],
                'axis': axis,
                'output_path': checkpoint_filename
            }

            timer.start("Transform precomputation")
            if verbose > 0:
                print("Precomputing Transforms")

            equil_fam = EquilibriaFamily(inputs=inputs_ii)
            # Get initial Equilibrium from equil_fam
            equil = equil_fam[ii]
            x = equil.x  # initial state vector

            # bases (extracted from Equilibrium)
            R_basis, Z_basis, L_basis, P_basis, I_basis = (
                equil.R_basis, equil.Z_basis, equil.L_basis,
                equil.P_basis, equil.I_basis)

            # grids
            RZ_grid = ConcentricGrid(Mnodes[ii], Nnodes[ii], NFP=NFP, sym=stell_sym,
                                     axis=False, index=zern_mode, surfs=node_mode)
            L_grid = LinearGrid(M=Mnodes[ii], N=2 * Nnodes[ii] + 1, NFP=NFP,
                                sym=stell_sym)

            # transforms
            R_transform = Transform(RZ_grid, R_basis, derivs=3)
            Z_transform = Transform(RZ_grid, Z_basis, derivs=3)
            R1_transform = Transform(L_grid, R_basis)
            Z1_transform = Transform(L_grid, Z_basis)
            L_transform = Transform(L_grid, L_basis, derivs=0)
            P_transform = Transform(RZ_grid, P_basis, derivs=1)
            I_transform = Transform(RZ_grid, I_basis, derivs=1)

            timer.stop("Transform precomputation")
            if verbose > 1:
                timer.disp("Transform precomputation")

        # continuing from previous solution
        else:
            grid_changed = (Mnodes[ii] != Mnodes[ii - 1]
                            or Nnodes[ii] != Nnodes[ii - 1])
            basis_changed = (M[ii] != M[ii - 1] or N[ii] != N[ii - 1]
                             or delta_lm[ii] != delta_lm[ii - 1])

            # change grids
            if grid_changed:
                RZ_grid = ConcentricGrid(Mnodes[ii], Nnodes[ii], NFP=NFP, sym=stell_sym,
                                         axis=False, index=zern_mode, surfs=node_mode)
                L_grid = LinearGrid(M=Mnodes[ii], N=2 * Nnodes[ii] + 1, NFP=NFP,
                                    sym=stell_sym)

            # change bases
            if basis_changed:
                equil.change_resolution(L=delta_lm[ii], M=M[ii], N=N[ii])
                # update equilibrium bases to the new resolutions
                R_basis, Z_basis, L_basis = equil.R_basis, equil.Z_basis, equil.L_basis
                x = equil.x

            # change transform matrices whenever the grids or the bases changed
            # (a node-only change must reach the transforms too)
            if grid_changed or basis_changed:
                timer.start("Iteration {} changing resolution".format(ii + 1))
                if verbose > 0:
                    print("Changing node resolution from (Mnodes,Nnodes) = ({},{}) to ({},{})"
                          .format(Mnodes[ii - 1], Nnodes[ii - 1], Mnodes[ii], Nnodes[ii]))
                    print("Changing spectral resolution from (L,M,N) = ({},{},{}) to ({},{},{})"
                          .format(delta_lm[ii - 1], M[ii - 1], N[ii - 1],
                                  delta_lm[ii], M[ii], N[ii]))

                R_transform.change_resolution(grid=RZ_grid, basis=R_basis)
                Z_transform.change_resolution(grid=RZ_grid, basis=Z_basis)
                R1_transform.change_resolution(grid=L_grid, basis=R_basis)
                Z1_transform.change_resolution(grid=L_grid, basis=Z_basis)
                L_transform.change_resolution(grid=L_grid, basis=L_basis)
                P_transform.change_resolution(grid=RZ_grid)
                I_transform.change_resolution(grid=RZ_grid)

                timer.stop("Iteration {} changing resolution".format(ii + 1))
                if verbose > 1:
                    timer.disp("Iteration {} changing resolution".format(ii + 1))

            # continuation parameters
            delta_bdry = bdry_ratio[ii] - bdry_ratio[ii - 1]
            delta_pres = pres_ratio[ii] - pres_ratio[ii - 1]
            delta_zeta = zeta_ratio[ii] - zeta_ratio[ii - 1]
            deltas = np.array([delta_bdry, delta_pres, delta_zeta])

            # need a non-scalar objective function to do the perturbations
            obj_fun = ObjectiveFunctionFactory.get_equil_obj_fun(
                errr_mode, scalar=False, R_transform=R_transform,
                Z_transform=Z_transform, R1_transform=R1_transform,
                Z1_transform=Z1_transform, L_transform=L_transform,
                P_transform=P_transform, I_transform=I_transform)
            equil_obj = obj_fun.compute
            callback = obj_fun.callback
            # the perturbation starts from the previous step's parameters
            args = (equil.cRb, equil.cZb, equil.cP, equil.cI, equil.Psi,
                    bdry_ratio[ii - 1], pres_ratio[ii - 1],
                    zeta_ratio[ii - 1], errr_ratio[ii - 1])

            # TODO: should probably perturb before expanding resolution
            # perturbations
            if np.any(deltas):
                if verbose > 1:
                    print("Perturbing equilibrium")
                x, timer = perturb_continuation_params(
                    x, equil_obj, deltas, args, pert_order[ii], verbose, timer)

        # equilibrium objective function
        scalar = optim_method in ['bfgs']
        obj_fun = ObjectiveFunctionFactory.get_equil_obj_fun(
            errr_mode, scalar=scalar, R_transform=R_transform,
            Z_transform=Z_transform, R1_transform=R1_transform,
            Z1_transform=Z1_transform, L_transform=L_transform,
            P_transform=P_transform, I_transform=I_transform)
        equil_obj = obj_fun.compute
        callback = obj_fun.callback
        # solve at THIS step's continuation parameters: the previous step's
        # values would undo the perturbation, and at ii == 0 index [ii - 1]
        # would wrap around to the final step's values
        args = (equil.cRb, equil.cZb, equil.cP, equil.cI, equil.Psi,
                bdry_ratio[ii], pres_ratio[ii], zeta_ratio[ii], errr_ratio[ii])

        if use_jax:
            if optim_method in ['bfgs']:
                jac = AutoDiffJacobian(equil_obj, argnum=0, mode='grad')
            else:
                jac = AutoDiffJacobian(equil_obj, argnum=0, mode='fwd')
            if verbose > 0:
                print("Compiling objective function")
            if device is None:
                import jax
                device = jax.devices()[0]
            equil_obj_jit = jit(equil_obj, static_argnums=(), device=device)
            jac_obj_jit = jit(jac.compute, device=device)
            timer.start("Iteration {} compilation".format(ii + 1))
            # evaluate once so JIT compilation happens inside the timed region
            equil_obj_jit(x, *args)
            jac_obj_jit(x, *args)
            timer.stop("Iteration {} compilation".format(ii + 1))
            if verbose > 1:
                timer.disp("Iteration {} compilation".format(ii + 1))
        else:
            equil_obj_jit = equil_obj
            jac_obj_jit = '2-point'

        if verbose > 0:
            print("Starting optimization")

        x_init = x
        timer.start("Iteration {} solution".format(ii + 1))
        if optim_method in ['bfgs']:
            out = scipy.optimize.minimize(equil_obj_jit,
                                          x0=x_init,
                                          args=args,
                                          method=optim_method,
                                          jac=jac_obj_jit,
                                          tol=gtol[ii],
                                          options={
                                              'maxiter': nfev[ii],
                                              'disp': verbose
                                          })
        elif optim_method in ['trf', 'lm', 'dogleg']:
            out = scipy.optimize.least_squares(equil_obj_jit,
                                               x0=x_init,
                                               args=args,
                                               jac=jac_obj_jit,
                                               method=optim_method,
                                               x_scale='jac',
                                               ftol=ftol[ii],
                                               xtol=xtol[ii],
                                               gtol=gtol[ii],
                                               max_nfev=nfev[ii],
                                               verbose=verbose)
        else:
            raise NotImplementedError(
                TextColors.FAIL +
                "optim_method must be one of 'bfgs', 'trf', 'lm', 'dogleg'" +
                TextColors.ENDC)

        timer.stop("Iteration {} solution".format(ii + 1))

        # carry the optimized state forward: without this, x stays at the
        # pre-optimization guess (the next step would restart from it when
        # the spectral resolution is unchanged, and "End of Step" below would
        # report the starting state)
        x = out['x']
        equil.x = x
        equil_fam.append(copy.deepcopy(equil))

        if verbose > 1:
            timer.disp("Iteration {} solution".format(ii + 1))
            timer.pretty_print(
                "Iteration {} avg time per step".format(ii + 1),
                timer["Iteration {} solution".format(ii + 1)] / out['nfev'])
        if verbose > 0:
            print("Start of Step {}:".format(ii + 1))
            callback(x_init, *args)
            print("End of Step {}:".format(ii + 1))
            callback(x, *args)

        if checkpoint:
            if verbose > 0:
                print('Saving latest iteration')
            equil_fam.save()

        if not is_nested(equil.cR, equil.cZ, equil.R_basis, equil.Z_basis):
            warnings.warn(
                TextColors.WARNING +
                'WARNING: Flux surfaces are no longer nested, exiting early. ' +
                'Consider increasing errr_ratio or taking smaller perturbation steps' +
                TextColors.ENDC)
            break

    timer.stop("Total time")
    print('====================')
    print('Done')
    if verbose > 1:
        timer.disp("Total time")
    if checkpoint_filename is not None:
        print('Output written to {}'.format(checkpoint_filename))
    print('====================')

    return equil_fam, timer