from tudatpy.kernel import constants
from tudatpy.kernel.interface import spice_interface
from tudatpy.kernel.simulation import environment_setup
from tudatpy.kernel.simulation import propagation_setup

# Module-level setup of a "Sun" + "Apollo" body environment.
# NOTE(review): main() builds its own local `bodies` via
# create_simulation_bodies(); the `bodies` created here are not used by it.

# Define simulation time and step size
simulation_start_epoch = 0.0
simulation_end_epoch = constants.JULIAN_DAY
fixed_step_size = 100.0

# Load spice kernels.
spice_interface.load_standard_kernels()

# Create body objects
bodies_to_create = ["Sun"]

body_settings = environment_setup.get_default_body_settings(
    bodies_to_create, simulation_start_epoch, simulation_end_epoch,
    fixed_step_size)

bodies = environment_setup.create_bodies(body_settings)

# Create vehicle object
bodies["Apollo"] = environment_setup.Body()

# Set mass of vehicle
bodies["Apollo"].set_constant_body_mass(2000.0)

# Global frame in which the body ephemerides are expressed
global_frame_origin = "SSB"
global_frame_orientation = "ECLIPJ2000"
# NOTE(review): arguments assumed to be (bodies, frame origin, frame
# orientation) -- confirm against the installed tudatpy version.
environment_setup.set_global_frame_body_ephemerides(
    bodies, global_frame_origin, global_frame_orientation)
def main():
    """
    The problem describes the orbit design around a small body (asteroid Itokawa).

    DYNAMICAL MODEL
    Itokawa spherical harmonics, cannonball radiation pressure from Sun, point-mass third-body
    from Sun, Jupiter, Saturn, Earth, Mars

    PROPAGATION TIME
    5 days

    INTEGRATOR
    RKF7(8) with tolerances 1E-8

    TERMINATION CONDITIONS
    In addition to 5 day time, minimum distance from Itokawa's center of mass:
    Itokawa radius + 150 m (no crashing), maximum distance from center of mass:
    Itokawa radius + 5 km (no escaping)

    DESIGN VARIABLES
    Initial values of semi-major axis, eccentricity, inclination, and longitude of node

    OBJECTIVES
    1. good coverage: the mean value of the absolute longitude w.r.t. Itokawa over the full propagation should be
       maximized;
    2. close orbit: the mean value of the distance should be minimized.

    OUTPUT
    Evolves a 48-individual NSGA-II population for 50 generations, re-propagates
    the first and last stored generations, and saves the following figures to the
    current working directory:
    - pareto_generation_<index>.png
    - histograms_final_generation.png
    - orbit_bundles_initial_final_gen.png
    - orbit_bundle_final_gen_by_variable.png
    Raw .dat files are additionally written to ./PygmoExampleSimulationOutput/
    only when ``dump_results_to_file`` is set to True.
    """
    # NOTE(review): this function relies on np, pg, os, matplotlib, plt and on
    # create_simulation_bodies, get_acceleration_models,
    # get_dependent_variables_to_save, get_termination_settings,
    # AsteroidOrbitProblem and save2txt, none of which appear in the visible
    # import block -- confirm they are imported/defined elsewhere in the module.

    ###########################################################################
    # CREATE SIMULATION SETTINGS ##############################################
    ###########################################################################

    # Load spice kernels (the module-level code also does this; kept here so
    # main() does not depend on that side effect)
    spice_interface.load_standard_kernels()

    # Define Itokawa radius [m]
    itokawa_radius = 161.915

    # Set simulation start epoch and duration [s]
    mission_initial_time = 0.0
    mission_duration = 5.0 * 86400.0

    # Set boundaries on the design variables: semi-major axis [m],
    # eccentricity [-], inclination [deg], longitude of the node [deg]
    design_variable_lb = (300, 0.0, 0.0, 0.0)
    design_variable_ub = (2000, 0.3, 180, 360)

    # Set termination conditions. Distances are measured from Itokawa's
    # center of mass, so the radius is added to the altitude margins.
    minimum_distance_from_com = 150.0 + itokawa_radius
    maximum_distance_from_com = 5.0E3 + itokawa_radius

    # Create simulation bodies
    bodies = create_simulation_bodies(itokawa_radius)

    ###########################################################################
    # CREATE ACCELERATIONS ####################################################
    ###########################################################################

    bodies_to_propagate = ["Spacecraft"]
    central_bodies = ["Itokawa"]

    # Create acceleration models.
    acceleration_models = get_acceleration_models(bodies_to_propagate,
                                                  central_bodies, bodies)

    # Create numerical integrator settings (RKF7(8), variable step).
    # NOTE(review): positional arguments presumably are (initial time,
    # initial step, coefficient set, min step, max step, relative tolerance,
    # absolute tolerance) -- confirm against the tudatpy version in use.
    integrator_settings = propagation_setup.integrator.runge_kutta_variable_step_size(
        mission_initial_time, 1.0,
        propagation_setup.integrator.RKCoefficientSets.rkf_78, 1.0E-6, 86400.0,
        1.0E-8, 1.0E-8)

    ###########################################################################
    # CREATE PROPAGATION SETTINGS #############################################
    ###########################################################################

    # Define list of dependent variables to save
    dependent_variables_to_save = get_dependent_variables_to_save()

    # Create termination settings (time limit plus min/max distance)
    termination_settings = get_termination_settings(mission_initial_time,
                                                    mission_duration,
                                                    minimum_distance_from_com,
                                                    maximum_distance_from_com)

    # Define (Cowell) propagator settings with a mock all-zero initial state;
    # presumably AsteroidOrbitProblem sets the real state per candidate orbit
    propagator_settings = propagation_setup.propagator.translational(
        central_bodies,
        acceleration_models,
        bodies_to_propagate,
        np.zeros(6),
        termination_settings,
        output_variables=dependent_variables_to_save)

    ###########################################################################
    # OPTIMIZE ORBIT WITH PYGMO ###############################################
    ###########################################################################

    # Fix seed for reproducibility
    fixed_seed = 17031861
    # Instantiate orbit problem
    orbitProblem = AsteroidOrbitProblem(bodies, integrator_settings,
                                        propagator_settings,
                                        mission_initial_time, mission_duration,
                                        design_variable_lb, design_variable_ub)

    # Select the NSGA-II algorithm from pygmo. It runs a single generation per
    # evolve() call, so every generation can be stored in the loop below.
    algo = pg.algorithm(pg.nsga2(gen=1, seed=fixed_seed))
    # Create pygmo problem using the UDP instantiated above
    prob = pg.problem(orbitProblem)
    # Initialize pygmo population with 48 individuals
    population_size = 48
    pop = pg.population(prob, size=population_size, seed=fixed_seed)
    # Set the number of evolutions
    number_of_evolutions = 50
    # Fitness values and design variables of every generation, indexed by
    # generation number
    fitness_list = []
    population_list = []
    # Evolve the population one generation at a time
    for gen in range(number_of_evolutions):
        print('Evolving population; at generation ' + str(gen))
        pop = algo.evolve(pop)
        fitness_list.append(pop.get_f())
        population_list.append(pop.get_x())

    ###########################################################################
    # ANALYZE FIRST AND LAST GENERATIONS ######################################
    ###########################################################################

    dump_results_to_file = False

    # Get output path
    output_path = os.getcwd() + '/PygmoExampleSimulationOutput/'

    # Retrieve first and last stored generations for further analysis.
    # Index 0 is the population *after* the first evolve() call.
    pops_to_analyze = {0: 'initial', number_of_evolutions - 1: 'final'}
    # generation index -> [per-individual outputs, fitness array, population array]
    simulation_output = dict()
    for population_index, population_name in pops_to_analyze.items():
        current_population = population_list[population_index]
        # Save fitness and population members
        if dump_results_to_file:
            # Create directory (no error if it already exists)
            os.makedirs(output_path, exist_ok=True)
            np.savetxt(output_path + 'Fitness_' + population_name + '.dat',
                       fitness_list[population_index])
            np.savetxt(output_path + 'Population_' + population_name + '.dat',
                       population_list[population_index])
        # individual index -> [state history, dependent variable history]
        generation_output = dict()
        for individual, current_orbit_parameters in enumerate(current_population):
            # Propagate orbit and compute fitness
            orbitProblem.fitness(current_orbit_parameters)
            # Retrieve state and dependent variable history of that run
            dynamics_simulator = orbitProblem.get_last_run_dynamics_simulator()
            current_states = dynamics_simulator.state_history
            current_dependent_variables = dynamics_simulator.dependent_variable_history
            generation_output[individual] = [
                current_states, current_dependent_variables
            ]
            # Write data to files
            if dump_results_to_file:
                save2txt(
                    current_dependent_variables, population_name +
                    '_dependent_variables' + str(individual) + '.dat',
                    output_path)
                save2txt(
                    current_states,
                    population_name + '_states' + str(individual) + '.dat',
                    output_path)
        simulation_output[population_index] = [
            generation_output, fitness_list[population_index],
            population_list[population_index]
        ]

    ###########################################################################
    # ANALYZE RESULTS #########################################################
    ###########################################################################

    # Set font size for plots
    font = {'size': 12}
    matplotlib.rc('font', **font)

    # Plot metadata per design variable (index matches the design vector)
    decision_variable_names = {
        0: 'Semi-major axis [m]',
        1: 'Eccentricity',
        2: 'Inclination [deg]',
        3: 'Longitude of the node [deg]'
    }
    # Sub-ranges used to colour the final orbit bundle: red below, blue
    # inside (bounds inclusive), green above
    decision_variable_range = {
        0: [800.0, 1300.0],
        1: [0.10, 0.17],
        2: [90.0, 95.0],
        3: [250.0, 270.0]
    }
    decision_variable_symbols = {
        0: r'$a$',
        1: r'$e$',
        2: r'$i$',
        3: r'$\Omega$'
    }
    decision_variable_units = {0: r' m', 1: r' ', 2: r' deg', 3: r' deg'}

    # Objective space of each analyzed generation, coloured by each design variable
    for population_index, current_generation in simulation_output.items():
        fig, axs = plt.subplots(2, 2, figsize=(14, 8))
        fig.suptitle('Generation ' + str(population_index),
                     fontweight='bold',
                     y=0.95)
        current_fitness = current_generation[1]
        current_population = current_generation[2]
        for ax_index, ax in enumerate(axs.flatten()):
            # NOTE(review): np.deg2rad maps a rad^-1 value to deg^-1, which
            # matches the axis label only if objective 1 is returned in
            # rad^-1 by AsteroidOrbitProblem.fitness -- confirm.
            cs = ax.scatter(np.deg2rad(current_fitness[:, 0]),
                            current_fitness[:, 1],
                            40,
                            current_population[:, ax_index],
                            marker='.')
            cbar = fig.colorbar(cs, ax=ax)
            cbar.ax.set_ylabel(decision_variable_names[ax_index])
            # The first positional argument of grid() is `visible`; the grid
            # type must go in `which`
            ax.grid(True, which='major')
            # Label only the outer axes of the 2x2 grid
            if ax_index > 1:
                ax.set_xlabel(r'Objective 1: coverage [$deg^{-1}$] ')
            if ax_index % 2 == 0:
                ax.set_ylabel(r'Objective 2: proximity [$m$]')
        fig.savefig('pareto_generation_' + str(population_index) + '.png',
                    bbox_inches='tight')

    # Histograms of all four design variables in the last generation
    fig, axs = plt.subplots(2, 2, figsize=(12, 8))
    fig.suptitle('Final orbits by decision variable',
                 fontweight='bold',
                 y=0.95)
    last_pop = simulation_output[number_of_evolutions - 1][2]
    for ax_index, ax in enumerate(axs.flatten()):
        ax.hist(last_pop[:, ax_index], bins=30)
        ax.set_xlabel(decision_variable_names[ax_index])
        if ax_index % 2 == 0:
            ax.set_ylabel('Occurrences in the population')
    fig.savefig('histograms_final_generation.png', bbox_inches='tight')

    # 3D orbits of the initial and final generation, side by side
    fig = plt.figure(figsize=(12, 6))
    fig.suptitle('Initial and final orbit bundle', fontweight='bold', y=0.95)
    title = {0: 'Initial orbit bundle', 1: 'Final orbit bundle'}
    for ax_index, current_generation in enumerate(simulation_output.values()):
        current_ax = fig.add_subplot(1, 2, 1 + ax_index, projection='3d')
        individual_outputs = current_generation[0]
        for ind_index in range(len(current_generation[2])):
            state_history = np.vstack(
                list(individual_outputs[ind_index][0].values()))
            current_ax.plot(state_history[:, 0],
                            state_history[:, 1],
                            state_history[:, 2],
                            linewidth=0.5)
        current_ax.set_xlabel('X [m]')
        current_ax.set_ylabel('Y [m]')
        current_ax.set_zlabel('Z [m]')
        current_ax.set_title(title[ax_index], y=1.0, pad=15)
    fig.savefig('orbit_bundles_initial_final_gen.png', bbox_inches='tight')

    # Final-generation orbits, coloured per design variable by sub-range
    fig = plt.figure(figsize=(12, 8))
    fig.suptitle('Final orbit bundle by decision variable',
                 fontweight='bold',
                 y=0.95)
    current_generation = simulation_output[number_of_evolutions - 1]
    current_population = current_generation[2]
    for var in range(4):
        current_ax = fig.add_subplot(2, 2, 1 + var, projection='3d')
        lower_bound, upper_bound = decision_variable_range[var]
        symbol = decision_variable_symbols[var]
        unit = decision_variable_units[var]
        for ind_index, individual in enumerate(current_population):
            # Values equal to a bound count as inside the range; previously a
            # value equal to the lower bound was labelled as above the upper one
            if individual[var] < lower_bound:
                plt_color = 'r'
                label = symbol + ' < ' + str(lower_bound) + unit
            elif individual[var] <= upper_bound:
                plt_color = 'b'
                label = str(lower_bound) + ' < ' + symbol + \
                        ' < ' + str(upper_bound) + unit
            else:
                plt_color = 'g'
                label = symbol + ' > ' + str(upper_bound) + unit

            state_history = np.vstack(
                list(current_generation[0][ind_index][0].values()))
            current_ax.plot(state_history[:, 0],
                            state_history[:, 1],
                            state_history[:, 2],
                            color=plt_color,
                            linewidth=0.5,
                            label=label)
        current_ax.set_xlabel('X [m]')
        current_ax.set_ylabel('Y [m]')
        current_ax.set_zlabel('Z [m]')
        current_ax.set_title(decision_variable_names[var], y=1.0, pad=10)
        # Keep one legend entry per distinct label
        handles, decision_variable_legend = current_ax.get_legend_handles_labels()
        decision_variable_legend, ids = np.unique(decision_variable_legend,
                                                  return_index=True)
        handles = [handles[i] for i in ids]
        current_ax.legend(handles,
                          decision_variable_legend,
                          loc='lower right',
                          bbox_to_anchor=(0.3, 0.6))
    fig.savefig('orbit_bundle_final_gen_by_variable.png', bbox_inches='tight')

    # Show plot
    plt.show()