from apps.onedimensional.advection import advection
from pydogpack.utils import x_functions
from pydogpack import main
from pydogpack.visualize import plot


class SmoothSystemExample(problem.Problem):
    """Advection example problem set up from an initial condition object.

    Builds an ``advection.Advection`` app together with the matching
    ``advection.ExactSolution`` and hands both to ``problem.Problem``.
    """

    def __init__(self, wavespeed=1.0, initial_condition=None, source_function=None):
        """Create the app, initial condition and exact solution.

        Args:
            wavespeed: Passed to ``advection.Advection`` and
                ``advection.ExactSolution``; also given to the base class as
                the maximum wavespeed.
            initial_condition: Initial condition object. When ``None``,
                ``x_functions.Sine()`` is used instead.
            source_function: Forwarded unchanged to ``advection.Advection``.
        """
        self.wavespeed = wavespeed

        # Resolve the default up front so that the exact solution and the base
        # class receive the same object that is stored on the instance.
        # Previously they were handed the raw argument, i.e. ``None`` whenever
        # the default was requested.
        if initial_condition is None:
            initial_condition = x_functions.Sine()
        self.initial_condition = initial_condition

        app = advection.Advection(wavespeed, source_function)
        max_wavespeed = wavespeed
        exact_solution = advection.ExactSolution(self.initial_condition, wavespeed)

        super().__init__(
            app, self.initial_condition, max_wavespeed, exact_solution
        )


if __name__ == "__main__":
    # Script entry point: build the example with a two-entry composed initial
    # condition (Sine, then Cosine) and run it through main.run.
    initial_condition = x_functions.ComposedVector(
        [
            x_functions.Sine(),
            x_functions.Cosine(),
        ]
    )
    wavespeed = 1.0
    problem = SmoothSystemExample(
        wavespeed=wavespeed, initial_condition=initial_condition
    )
    final_solution = main.run(problem)
        if discontinuity_locations is None:
            self.discontinuity_locations = [-0.6, -0.4]
        else:
            self.discontinuity_locations = discontinuity_locations

        app_ = advection.Advection(wavespeed, source_function)
        riemann_problems = []
        for i in range(len(left_states)):
            riemann_problems.append(x_functions.RiemannProblem(
                left_states[i], right_states[i], discontinuity_locations[i]
            ))

        initial_condition = x_functions.ComposedVector(riemann_problems)
        max_wavespeed = wavespeed
        exact_solution = advection.ExactSolution(initial_condition, self.wavespeed)

        super().__init__(
            app_, initial_condition, max_wavespeed, exact_solution
        )


if __name__ == "__main__":
    # Script entry point: configure RiemannSystemExample with two-entry lists
    # of left states, right states and discontinuity locations, then run it.
    discontinuity_locations = [-0.6, -0.4]
    left_state = [1.0, -2.0]
    right_state = [-1.0, 2.0]
    wavespeed = 1.0

    example_args = (wavespeed, left_state, right_state, discontinuity_locations)
    problem = RiemannSystemExample(*example_args)
    main.run(problem)