def main():
    """Run an ML-assisted Langevin MD simulation and save the collected data.

    The structure is read from ``myplum.xyz`` and the LAMMPS input lines from
    ``data.input``.  Each MD step is split into two half-steps around a call to
    ``get_all_forces`` (which, per the ``mixing`` comment, weights the "real"
    and ML forces).  Every ``update_interval`` steps ``do_update=True`` is
    passed to ``get_all_forces`` and the diagnostic plots are redrawn.

    Files written to the working directory:

    * ``atomstraj.xyz`` -- initial frame (extxyz); replaced on every run.
    * ``results.csv`` -- one row ``[phi, psi, pot_energy]`` per MD step.
    * ``mlmodel.dual_coef_.csv``, ``mlmodel.X_fit_.csv``, ``mlmodel.y.csv``
      -- final model state, for whichever of those attributes exist.

    Returns
    -------
    mlmodel : KernelRidge
        The ML model in its state at the end of the run.
    """
    # The NumPy aliases in the top-level scipy namespace (sp.arange, ...) are
    # deprecated, so use numpy directly.
    import numpy as np

    T = 300.0  # Simulation temperature
    dt = 1 * units.fs  # MD timestep
    nsteps = 100000  # MD number of steps
    mixing = [1, -1, 0]  # [1.0, -1.0, 0.3] # mixing weights for "real" and ML forces
    lengthscale = 0.5  # KRR Gaussian width.
    gamma = 1 / (2 * lengthscale**2)
    grid_spacing = 0.05
    update_interval = 60  # steps between ML updates / plot refreshes
    flush_interval = 1000  # steps between flushes of the cached Cholesky factor

    # Alternative model:
    # mlmodel = GaussianProcess(corr='squared_exponential',
    #     # theta0=1e-1, thetaL=1e-4, thetaU=1e+2,
    #     theta0=1.,
    #     random_start=100, normalize=False, nugget=1.0e-2)
    mlmodel = KernelRidge(kernel='rbf',
                          gamma=gamma, gammaL=gamma / 4, gammaU=2 * gamma,
                          alpha=1.0e-2, variable_noise=False, max_lhood=False)
    # Regular 2-D angle grid over [0, 2*pi]: each row is [x, y], with x varying
    # fastest (same ordering as the former nested list comprehension).
    anglerange = np.arange(0, 2 * np.pi + grid_spacing, grid_spacing)
    X_grid = np.stack(np.meshgrid(anglerange, anglerange), axis=-1).reshape(-1, 2)
    ext_field = None  # IgnoranceField(X_grid, y_threshold=-0.075, cutoff = 3.)

    # Bootstrap from initial database? uncomment
    # data = np.loadtxt('phi_psi_F.csv')
    # # data[:,:2] -= 0.025 # fix because of old round_vector routine
    # mlmodel.fit(data[:,:2], data[:,2])
    # ext_field.update_cost(mlmodel.X_fit_, mlmodel.y)

    # mlmodel.fit(X_grid, np.zeros(len(X_grid)))
    # mlmodel.fit(np.load('X_fitD.npy'), np.load('y_fitD.npy'))

    # Prepare diagnostic visual effects.
    plt.close('all')
    plt.ion()
    fig, ax = plt.subplots(1, 2, figsize=(24, 13))

    atoms = ase.io.read('myplum.xyz')
    with open('data.input', 'r') as fh:
        lammpsdata = fh.readlines()

    # Set temperature
    # NOTE(review): if the second argument is kT in energy units, this seeds
    # velocities at T/2; rescale_velocities(T) below then rescales them.
    # Confirm the 0.5 factor is intended.
    MaxwellBoltzmannDistribution(atoms, 0.5 * units.kB * T, force_temp=True)
    # Set total momentum to zero
    p = atoms.get_momenta()
    p -= p.sum(axis=0) / len(atoms)
    atoms.set_momenta(p)
    atoms.rescale_velocities(T)

    # Select MD propagator
    mdpropagator = Langevin(atoms, dt, T * units.kB, 1.0e-2, fixcm=True)
    # mdpropagator = MLVerlet(atoms, dt, T)

    # Zero-timestep evaluation and data files setup.
    print("START")
    pot_energy, f = calc_lammps(atoms, preloaded_data=lammpsdata)
    mlmodel.accumulate_data(round_vector(atoms.colvars(), precision=grid_spacing), 0.)
    # mlmodel.accumulate_data(round_vector(atoms.colvars(), precision=grid_spacing), pot_energy)
    # NOTE(review): this prints the total potential energy, while the MD loop
    # prints it divided by the number of atoms -- confirm which printenergy expects.
    printenergy(atoms, pot_energy)

    # When in the simulation to update the ML fit -- optional, currently unused.
    # linspace requires an integer sample count, hence the floor division.
    teaching_points = np.unique(
        (np.linspace(0, nsteps**(1 / 3), nsteps // 20)**3).astype('int') + 1)

    results = []
    datasetplot = None  # created on the first plot refresh
    my2dplot = None  # created once the model exposes dual_coef_

    # Mode 'w' replaces any trajectory from a previous run; the context manager
    # closes the file even if the MD loop raises or is interrupted.
    with open("atomstraj.xyz", 'w') as traj:
        atoms.write(traj, format='extxyz')

        # MD Loop
        for istep in range(nsteps):
            do_update = (istep % update_interval == update_interval - 1)
            # Flush Cholesky decomposition of K
            if istep % flush_interval == 0:
                mlmodel.Cho_L = None
                mlmodel.max_lhood = False
            print("Dihedral angles | phi = %.3f, psi = %.3f " % (atoms.phi(), atoms.psi()))
            t = get_time()
            mdpropagator.halfstep_1of2(f)
            print("TIMER 001 | %.3f" % (get_time() - t))
            t = get_time()
            f, pot_energy, _ = get_all_forces(
                atoms, mlmodel, grid_spacing, T, extfield=ext_field,
                mixing=mixing, lammpsdata=lammpsdata, do_update=do_update)
            if do_update and mlmodel.max_lhood:
                mlmodel.max_lhood = False
            mdpropagator.halfstep_2of2(f)
            print("TIMER 002 | %.3f" % (get_time() - t))

            n_atoms = len(atoms)
            # Manual cooldown: rescale when the kinetic temperature
            # E_kin / (1.5 * kB * N) drifts more than 100 K from the target.
            if abs(atoms.get_kinetic_energy() / (1.5 * units.kB * n_atoms) - T) > 100:
                atoms.rescale_velocities(T)

            printenergy(atoms, pot_energy / n_atoms, step=istep)

            if do_update:
                t = get_time()
                if datasetplot is None:
                    datasetplot = pl.Plot_datapts(ax[0], mlmodel)
                else:
                    datasetplot.update()
                # The energy-surface plot needs a fitted model.
                if hasattr(mlmodel, 'dual_coef_'):
                    if my2dplot is None:
                        my2dplot = pl.Plot_energy_n_point(ax[1], mlmodel, atoms.colvars().ravel())
                    else:
                        my2dplot.update_prediction()
                        my2dplot.update_current_point(atoms.colvars().ravel())
                print("TIMER 003 | %.03f" % (get_time() - t))
                t = get_time()
                fig.canvas.draw()
                print("TIMER 004 | %.03f" % (get_time() - t))
                # fig.canvas.print_figure('current.png')
            t = get_time()
            # To record the full trajectory, set traj_buffer = [] before the
            # loop and uncomment:
            # traj_buffer.append(atoms.copy())
            # if istep % 100 == 0:
            #     for at in traj_buffer:
            #         at.write(traj, format='extxyz')
            #     traj_buffer = []
            results.append(np.array([atoms.phi(), atoms.psi(), pot_energy]))
            print("TIMER 005 | %.03f" % (get_time() - t))
    print("FINISHED")
    np.savetxt('results.csv', np.array(results))
    # Guarded: an unfitted model has no dual_coef_ (see the plot check above),
    # and an AttributeError here would be raised only after the whole run.
    for attr in ('dual_coef_', 'X_fit_', 'y'):
        if hasattr(mlmodel, attr):
            np.savetxt('mlmodel.%s.csv' % attr, getattr(mlmodel, attr))
        else:
            print("mlmodel has no attribute %s; mlmodel.%s.csv not written" % (attr, attr))

    return mlmodel