Exemplo n.º 1
0
    def get_theta_bounds_pyo(model: pyo.ConcreteModel, input):
        """Return the extreme values of ``input[2]`` and ``input[3]`` over ``model``.

        Four optimisation problems are solved in sequence on the same model:
        maximise and minimise ``input[2]`` (reported as theta), then maximise
        and minimise ``input[3]`` (reported as theta_dot). ``model.obj`` is
        replaced before every solve.

        Args:
            model: Pyomo model whose constraints define the region to bound.
                Any objective already stored as ``model.obj`` is discarded.
            input: Indexable collection of Pyomo expressions/variables; only
                positions 2 and 3 are read.

        Returns:
            The tuple ``(max_theta, min_theta, max_theta_dot, min_theta_dot)``.

        Raises:
            AssertionError: If a solve does not finish with solver status
                ``ok`` and termination condition ``optimal``.

        Note:
            There is no ``self`` parameter although other code in this file
            calls it as ``self.get_theta_bounds_pyo(model, input)`` —
            presumably a ``@staticmethod`` whose decorator is declared just
            above this definition; confirm.
            The objective of the last solve (minimise ``input[3]``) is left
            attached to the model as ``model.obj``.
        """
        # An earlier stage may have left an objective behind; remove it so the
        # first assignment below does not clash with an existing component.
        if model.component("obj"):
            model.del_component(model.obj)

        # (index into ``input``, optimisation sense), in the order of the
        # returned tuple.
        queries = (
            (2, pyo.maximize),
            (2, pyo.minimize),
            (3, pyo.maximize),
            (3, pyo.minimize),
        )
        bounds = []
        for position, (index, sense) in enumerate(queries):
            if position > 0:
                # The objective installed by the previous iteration is still
                # attached and must be removed before a new one is assigned.
                model.del_component(model.obj)
            model.obj = pyo.Objective(expr=input[index], sense=sense)
            result = Experiment.solve(model, solver=Experiment.use_solver)
            assert (result.solver.status == SolverStatus.ok) and (
                result.solver.termination_condition
                == TerminationCondition.optimal
            ), f"bound query on input[{index}] was not solved to optimality"
            bounds.append(pyo.value(model.obj))

        max_theta, min_theta, max_theta_dot, min_theta_dot = bounds
        return max_theta, min_theta, max_theta_dot, min_theta_dot
 def __init__(self):
     """Configure the experiment for a six-variable state.

     Registers the successor, network-loading and plotting callbacks, builds
     the input and analysis templates through ``get_template``, declares the
     unsafe zone and points ``nn_path`` at a saved checkpoint.
     """
     state_dim: int = 6
     super().__init__(state_dim)

     # Callbacks the base experiment invokes.
     self.post_fn_remote = self.post_milp
     self.get_nn_fn = self.get_nn
     self.plot_fn = self.plot

     # Two projection rows: variable 0, and variable 0 minus variable 1.
     self.template_2d: np.ndarray = np.array([
         [1, 0, 0, 0, 0, 0],
         [1, -1, 0, 0, 0, 0],
     ])

     # Mode 0 provides the initial region (boundaries + template); mode 1
     # provides the template used during the analysis.
     initial_bounds, initial_template = self.get_template(0)
     self.input_boundaries: List = initial_bounds
     self.input_template: np.ndarray = initial_template
     _, fixed_point_template = self.get_template(1)
     self.analysis_template: np.ndarray = fixed_point_template

     # Unsafe zone: the direction e0 - e1 paired with a bound of 0 (the
     # original code names these "distance" and "collision_distance").
     gap_direction = Experiment.e(6, 0) - Experiment.e(6, 1)

     self.rounding_value = 2**10
     self.use_rounding = False
     self.time_horizon = 40000
     self.unsafe_zone: List[Tuple] = [([gap_direction], np.array([0]))]
     self.input_epsilon = 0

     # Checkpoint 58 of trial 00001; an earlier revision pointed at
     # checkpoint 39 of trial 00000 from the same tune run.
     checkpoint = "tune_PPO_stopping_car/PPO_StoppingCar_acc24_00001_1_cost_fn=0,epsilon_input=0_2021-01-21_02-30-49/checkpoint_58/checkpoint-58"
     self.nn_path = os.path.join(utils.get_save_dir(), checkpoint)
Exemplo n.º 3
0
    def post_milp(self, x, nn, output_flag, t, template):
        """Compute the successor regions of ``x`` for both actions via MILP.

        For each action (0 and 1) a fresh Gurobi model is built over the
        region described by ``template``/``x``. A two-component observation
        is tied to the state, within ``self.input_epsilon / 2`` on each
        side, and constrained to the precomputed observable region of that
        action. If the model is solvable, the dynamics are applied and the
        resulting region is collected.

        Args:
            x: Bounds of the current region with respect to ``template``.
            nn: Unused here; the action guard comes from
                ``self.observable_templates`` / ``self.observable_results``.
            output_flag: Value for Gurobi's ``OutputFlag`` parameter.
            t: Unused here.
            template: Template directions describing the region.

        Returns:
            A list with one tuple of bounds per successor that was found.
        """
        successors = []
        # (observation index, minuend index, subtrahend index,
        #  upper-bound constraint name, lower-bound constraint name)
        observation_links = (
            (1, 0, 1, "obs_constr21", "obs_constr22"),
            (0, 2, 3, "obs_constr11", "obs_constr12"),
        )
        for action in range(2):
            guard_template = self.observable_templates[action]
            guard_bounds = self.observable_results[action]

            model = grb.Model()
            model.setParam('OutputFlag', output_flag)
            model.setParam('Threads', 2)
            state = Experiment.generate_input_region(model, template, x,
                                                     self.env_input_size)
            # NOTE(review): the Gurobi variable is labelled "input" although
            # it holds the observation — kept as-is.
            observation = model.addMVar(shape=(2, ),
                                        lb=float("-inf"),
                                        ub=float("inf"),
                                        name="input")

            slack = self.input_epsilon / 2
            for obs_idx, first, second, upper_name, lower_name in observation_links:
                model.addConstr(
                    observation[obs_idx] <= state[first] - state[second] + slack,
                    name=upper_name)
                model.addConstr(
                    observation[obs_idx] >= state[first] - state[second] - slack,
                    name=lower_name)

            # Restrict the observation to the region stored for this action.
            Experiment.generate_region_constraints(model, guard_template,
                                                   observation, guard_bounds,
                                                   2)
            model.optimize()
            # Status code 2 is Gurobi's "optimal"; anything else means the
            # action cannot be taken from this region.
            if model.status != 2:
                continue

            next_state = StoppingCarExperiment.apply_dynamic(
                state,
                model,
                action=action,
                env_input_size=self.env_input_size)
            model.update()
            model.optimize()
            found, next_bounds = self.h_repr_to_plot(model, template,
                                                     next_state)
            if found:
                successors.append(tuple(next_bounds))
        return successors
Exemplo n.º 4
0
 def post_milp(self, x, nn, output_flag, t, template):
     """Compute the successor regions of ``x`` for both actions via MILP.

     For each action (0 and 1) a fresh Gurobi model is built over the region
     described by ``template``/``x``. The theta bounds and the sin/cos table
     are computed first, then the network guard decides whether the action
     is feasible. Feasible actions get the angle encoding and the dynamics
     applied, and the resulting region is collected.

     Args:
         x: Bounds of the current region with respect to ``template``.
         nn: Network handed to ``generate_nn_guard``.
         output_flag: Value for Gurobi's ``OutputFlag`` parameter.
         t: Unused here.
         template: Template directions describing the region.

     Returns:
         A list with one tuple of bounds per successor that was found.
     """
     successors = []
     for action in range(2):
         model = grb.Model()
         model.setParam('OutputFlag', output_flag)
         model.setParam('Threads', 2)
         state = Experiment.generate_input_region(model, template, x,
                                                  self.env_input_size)

         # Computed before the guard is added, exactly as in the original
         # statement order.
         theta_hi, theta_lo, theta_dot_hi, theta_dot_lo = self.get_theta_bounds(
             model, state)
         sin_cos_table = self.get_sin_cos_table(theta_hi,
                                                theta_lo,
                                                theta_dot_hi,
                                                theta_dot_lo,
                                                action=action)

         action_feasible = CartpoleExperiment.generate_nn_guard(
             model, state, nn, action_ego=action)
         if not action_feasible:
             continue

         thetaacc, xacc = CartpoleExperiment.generate_angle_milp(
             model, state, sin_cos_table)
         next_state = self.apply_dynamic(state,
                                         model,
                                         thetaacc=thetaacc,
                                         xacc=xacc,
                                         env_input_size=self.env_input_size)
         model.update()
         model.optimize()
         found, next_bounds = self.h_repr_to_plot(model, template, next_state)
         if found:
             successors.append(tuple(next_bounds))
     return successors
Exemplo n.º 5
0
 def get_template(self, mode=0):
     """Return ``(input_boundaries, template)`` for the requested mode.

     Args:
         mode: ``0`` selects the axis-aligned template (one positive and one
             negative unit direction per state variable) together with
             fixed boundaries; ``1`` selects three directions built from
             the first two unit vectors and no boundaries.

     Returns:
         A ``(boundaries, numpy array)`` pair, or ``None`` for any other
         mode.
     """
     size = self.env_input_size
     p = Experiment.e(size, 0)
     v = Experiment.e(size, 1)

     if mode == 0:
         # NOTE(review): four boundary values versus 2 * env_input_size
         # template rows — confirm how callers pair them up.
         boundaries = [9, -8, 0, 0.1]
         directions = [
             signed
             for axis in range(size)
             for signed in (Experiment.e(size, axis),
                            -Experiment.e(size, axis))
         ]
         return boundaries, np.array(directions)

     if mode == 1:
         # Directions chosen to make a fixed point easy to find.
         return None, np.array([v + p, -v - p, -p])

     return None
Exemplo n.º 6
0
 def generate_nn_polyhedral_guard(self, nn, chosen_action, output_flag):
     """Over-approximate with an octagon where the network picks an action.

     Builds a Gurobi model with a free two-component observation, adds the
     network guard for ``chosen_action`` and optimises along the directions
     of ``Experiment.octagon(2)``.

     Args:
         nn: Network handed to ``Experiment.generate_nn_guard``.
         chosen_action: Action whose guard region is computed.
         output_flag: Value for Gurobi's ``OutputFlag`` parameter.

     Returns:
         ``(observable_template, observable_result)``: the octagon template
         and the values ``self.optimise`` returns for it.
     """
     gurobi_model = grb.Model()
     gurobi_model.setParam('OutputFlag', output_flag)
     gurobi_model.setParam('Threads', 2)
     observation = gurobi_model.addMVar(shape=(2, ),
                                        lb=float("-inf"),
                                        ub=float("inf"),
                                        name="observation")
     # The guard's return value (feasibility) is ignored, as before.
     Experiment.generate_nn_guard(gurobi_model,
                                  observation,
                                  nn,
                                  action_ego=chosen_action)
     observable_template = Experiment.octagon(2)

     # ``optimise`` is run with env_input_size temporarily set to 2, the
     # size of the observation. Restore the previous value afterwards
     # (falling back to the formerly hard-coded 6), and do so even if
     # ``optimise`` raises, so the instance is never left with the wrong size.
     previous_size = getattr(self, "env_input_size", 6)
     self.env_input_size = 2
     try:
         observable_result = self.optimise(observable_template, gurobi_model,
                                           observation)
     finally:
         self.env_input_size = previous_size
     return observable_template, observable_result
Exemplo n.º 7
0
 def __init__(self):
     """Configure the experiment for a two-variable state.

     Registers the successor, network-loading and plotting callbacks, takes
     both the input and the analysis template from ``get_template(0)``,
     declares the unsafe zone and points ``nn_path`` at a saved checkpoint.
     """
     state_dim: int = 2
     super().__init__(state_dim)

     # Callbacks the base experiment invokes.
     self.post_fn_remote = self.post_milp
     self.get_nn_fn = self.get_nn
     self.plot_fn = self.plot

     # Projection rows: variable 1 first, then variable 0.
     self.template_2d: np.ndarray = np.array([[0, 1], [1, 0]])

     initial_bounds, initial_template = self.get_template(0)
     self.input_boundaries: List = initial_bounds
     self.input_template: np.ndarray = initial_template
     # A second call, so the analysis template is a separate object from
     # the input template.
     _, analysis_template = self.get_template(0)
     self.analysis_template: np.ndarray = analysis_template

     self.time_horizon = 500
     self.rounding_value = 2**8

     # Unsafe zone: directions (p, -v, v) paired with bounds (0, 1, 0),
     # where p and v are the first and second unit vectors.
     p = Experiment.e(self.env_input_size, 0)
     v = Experiment.e(self.env_input_size, 1)
     unsafe_directions = [p, -v, v]
     unsafe_bounds = np.array([0, 1, 0])
     self.unsafe_zone: List[Tuple] = [(unsafe_directions, unsafe_bounds)]

     checkpoint = "tune_PPO_bouncing_ball/PPO_BouncingBall_c7326_00000_0_2021-01-16_05-43-36/checkpoint_36/checkpoint-36"
     self.nn_path = os.path.join(utils.get_save_dir(), checkpoint)
Exemplo n.º 8
0
def run_parameterised_experiment(config):
    """Run one cartpole experiment from a tune ``config`` and report it.

    Args:
        config: Mapping with ``"main_params"`` (a 3-item sequence whose last
            item is a dict holding ``"folder"``, ``"tau"`` and
            ``"template"``) and ``"n_workers"``.

    The outcome is reported through ``tune.report`` with ``safe`` encoded as
    ``1`` (safe), ``-1`` (unsafe) or ``0`` (``safe`` is ``None``).
    """
    trial_dir = tune.get_trial_dir()
    # Only the third item is used here; the unpacking still requires
    # exactly three entries.
    _problem, _method, settings = config["main_params"]
    worker_count = config["n_workers"]

    experiment = CartpoleExperiment()
    experiment.nn_path = settings["folder"]
    experiment.tau = settings["tau"]

    template_kind = settings["template"]
    size = experiment.env_input_size
    if template_kind == 2:
        experiment.analysis_template = Experiment.octagon(size)
    elif template_kind == 0:
        experiment.analysis_template = Experiment.box(size)
    else:
        # Any other value selects the experiment's own mode-1 template.
        experiment.analysis_template = experiment.get_template(1)[1]

    experiment.n_workers = worker_count
    experiment.show_progressbar = False
    experiment.show_progress_plot = False
    experiment.save_dir = trial_dir
    experiment.update_progress_fn = update_progress
    elapsed_seconds, safe, max_t = experiment.run_experiment()

    if safe is None:
        safe_value = 0
    else:
        safe_value = 1 if safe else -1
    tune.report(elapsed_seconds=elapsed_seconds,
                safe=safe_value,
                max_t=max_t,
                done=True)
Exemplo n.º 9
0
 def get_template(self, mode=0):
     """Return ``(input_boundaries, template)`` for the requested mode.

     Args:
         mode: Which template to build.
             ``0``: one positive and one negative unit direction per state
             variable, with boundaries of 0.05 everywhere.
             ``1``: theta, theta_dot, their sum and their difference, each
             in both signs; no boundaries.
             ``2``: theta and theta_dot in both signs; no boundaries.
             ``3``: theta, theta_dot, -theta_dot; no boundaries.
             ``4``: theta, theta_dot, -theta_dot, sum and difference, with
             five boundaries.
             ``5``: theta, sum and difference, with three boundaries.

     Returns:
         A ``(boundaries, numpy array)`` pair, or ``None`` for any other
         mode.
     """
     size = self.env_input_size
     # x and x_dot are not used by any template below; they are kept so the
     # same helper calls are made as before.
     x = Experiment.e(size, 0)
     x_dot = Experiment.e(size, 1)
     theta = Experiment.e(size, 2)
     theta_dot = Experiment.e(size, 3)

     if mode == 0:
         boundaries = [0.05] * 8
         rows = [
             signed
             for axis in range(size)
             for signed in (Experiment.e(size, axis),
                            -Experiment.e(size, axis))
         ]
         return boundaries, np.array(rows)

     theta_sum = theta + theta_dot
     theta_diff = theta - theta_dot

     if mode == 1:
         # Directions chosen to make a fixed point easy to find.
         boundaries = None
         rows = [
             theta, -theta, theta_dot, -theta_dot, theta_sum, -theta_sum,
             theta_diff, -theta_diff
         ]
     elif mode == 2:
         boundaries = None
         rows = [theta, -theta, theta_dot, -theta_dot]
     elif mode == 3:
         boundaries = None
         rows = [theta, theta_dot, -theta_dot]
     elif mode == 4:
         boundaries = [0.09375, 0.625, 0.625, 0.0625, 0.1875]
         rows = [theta, theta_dot, -theta_dot, theta_sum, theta_diff]
     elif mode == 5:
         boundaries = [0.125, 0.0625, 0.1875]
         rows = [theta, theta_sum, theta_diff]
     else:
         return None
     return boundaries, np.array(rows)
    def post_milp(self, x, nn, output_flag, t, template):
        """Compute every successor region of ``x`` for the hybrid dynamics.

        Six branches are explored, each on its own Gurobi model. In every
        branch the region ``x`` is rebuilt, ``apply_dynamic`` produces ``z``
        and ``generate_guard`` checks the branch condition on ``z``. Branches
        tied to an action also require the *input* (not ``z``) to intersect
        the precomputed observable region of that action. For a feasible
        branch, ``z`` is projected onto ``template``, ``apply_dynamic2`` is
        applied to that projection in a fresh model, and the resulting
        bounds are collected.

        Args:
            x: Bounds of the current region with respect to ``template``.
            nn: Unused here; the action guards come from
                ``self.observable_templates`` / ``self.observable_results``.
            output_flag: Value for Gurobi's ``OutputFlag`` parameter.
            t: Unused here.
            template: Template directions describing the region.

        Returns:
            A list with one tuple of bounds per successor found, in branch
            order.
        """
        post = []
        # Looked up eagerly, and in the original order, so a missing guard
        # region fails before any model is built.
        action_guards = {
            1: (self.observable_templates[1], self.observable_results[1]),
            0: (self.observable_templates[0], self.observable_results[0]),
        }

        def new_model():
            """Create a Gurobi model with the requested output verbosity."""
            model = grb.Model()
            model.setParam('OutputFlag', output_flag)
            return model

        def action_is_possible(model, state, action):
            """Constrain ``state`` to the guard region of ``action`` and solve."""
            guard_template, guard_result = action_guards[action]
            Experiment.generate_region_constraints(model, guard_template,
                                                   state, guard_result, 2)
            model.optimize()
            # Status code 2 is Gurobi's "optimal".
            return model.status == 2

        def successor(model, z, dynamic_case):
            """Project ``z`` onto the template and apply the second dynamics."""
            x_prime_results = self.optimise(template, model, z)
            model = new_model()
            input2 = self.generate_input_region(model, template,
                                                x_prime_results,
                                                self.env_input_size)
            x_second = self.apply_dynamic2(input2,
                                           model,
                                           case=dynamic_case,
                                           env_input_size=self.env_input_size)
            return self.h_repr_to_plot(model, template, x_second)

        # (guard case, required action or None, case for apply_dynamic2)
        branches = (
            (0, None, 0),  # bounce; the action is irrelevant
            (1, 1, 1),  # ball going down and hit
            (2, 1, 2),  # ball going up and hit
            (1, 0, 3),  # ball going down and NO hit
            (2, 0, 3),  # ball going up and NO hit
            (3, None, 3),  # out of reach, no bounce; the action is irrelevant
        )
        for guard_case, action, dynamic_case in branches:
            gurobi_model = new_model()
            state = self.generate_input_region(gurobi_model, template, x,
                                               self.env_input_size)
            z = self.apply_dynamic(state, gurobi_model, self.env_input_size)
            if not self.generate_guard(gurobi_model, z, case=guard_case):
                continue
            if action is not None and not action_is_possible(
                    gurobi_model, state, action):
                continue
            found_successor, x_second_results = successor(
                gurobi_model, z, dynamic_case)
            if found_successor:
                post.append(tuple(x_second_results))

        return post
Exemplo n.º 11
0
    def post_milp(self, x, nn, output_flag, t, template):
        """Compute the successor regions of ``x`` for both actions.

        Depending on the module-level ``USE_GUROBI`` flag, the work is done
        with a Gurobi model or with a Pyomo model.

        * Gurobi path: the region is intersected with the precomputed
          observable region of the action; if the model is solvable, the
          region is projected onto ``template``, the angle encoding and the
          dynamics are applied, and the resulting region is collected.
        * Pyomo path: feasibility comes from the network guard; the same
          projection / angle encoding / dynamics steps follow, using the
          ``*_pyo`` helpers.

        Args:
            x: Bounds of the current region with respect to ``template``.
            nn: Network handed to the Pyomo guard; unused on the Gurobi path.
            output_flag: Value for Gurobi's ``OutputFlag`` parameter.
            t: Unused here.
            template: Template directions describing the region.

        Returns:
            A list with one tuple of bounds per successor that was found.

        Raises:
            AssertionError: On the Pyomo path, if the intermediate solve is
                not completed with status ``ok`` / ``optimal``.
        """
        post = []
        for chosen_action in range(2):
            # Looked up on both paths, although only the Gurobi path uses
            # them.
            observable_template = self.observable_templates[chosen_action]
            observable_result = self.observable_results[chosen_action]
            if USE_GUROBI:
                gurobi_model = grb.Model()
                gurobi_model.setParam('OutputFlag', output_flag)
                gurobi_model.setParam('Threads', 2)
                input = Experiment.generate_input_region(
                    gurobi_model, template, x, self.env_input_size)
                Experiment.generate_region_constraints(
                    gurobi_model,
                    observable_template,
                    input,
                    observable_result,
                    env_input_size=self.env_input_size)
                gurobi_model.optimize()
                # Status code 2 is Gurobi's "optimal"; anything else means
                # the action cannot be taken from this region.
                if gurobi_model.status != 2:
                    continue
                max_theta, min_theta, max_theta_dot, min_theta_dot = self.get_theta_bounds(
                    gurobi_model, input)
                sin_cos_table = self.get_sin_cos_table(
                    max_theta,
                    min_theta,
                    max_theta_dot,
                    min_theta_dot,
                    action=chosen_action,
                    step_thetaacc=100)
                x_prime_results = self.optimise(template, gurobi_model,
                                                input)  # h representation
                x_prime = Experiment.generate_input_region(
                    gurobi_model, template, x_prime_results,
                    self.env_input_size)
                thetaacc, xacc = CartpoleExperiment.generate_angle_milp(
                    gurobi_model, x_prime, sin_cos_table)
                # apply dynamic
                x_second = self.apply_dynamic(
                    x_prime,
                    gurobi_model,
                    thetaacc=thetaacc,
                    xacc=xacc,
                    env_input_size=self.env_input_size)
                gurobi_model.update()
                gurobi_model.optimize()
                found_successor, x_second_results = self.h_repr_to_plot(
                    gurobi_model, template, x_second)
                if found_successor:
                    post.append(tuple(x_second_results))
            else:
                model = pyo.ConcreteModel()
                input = Experiment.generate_input_region_pyo(
                    model, template, x, self.env_input_size)
                feasible_action = ORACartpoleExperiment.generate_nn_guard_pyo(
                    model, input, nn, action_ego=chosen_action, M=1e04)
                if not feasible_action:
                    continue
                max_theta, min_theta, max_theta_dot, min_theta_dot = self.get_theta_bounds_pyo(
                    model, input)
                sin_cos_table = self.get_sin_cos_table(
                    max_theta,
                    min_theta,
                    max_theta_dot,
                    min_theta_dot,
                    action=chosen_action,
                    step_thetaacc=100)
                x_prime_results = self.optimise_pyo(template, model, input)
                x_prime = Experiment.generate_input_region_pyo(
                    model,
                    template,
                    x_prime_results,
                    self.env_input_size,
                    name="x_prime_input")
                thetaacc, xacc = ORACartpoleExperiment.generate_angle_milp_pyo(
                    model, x_prime, sin_cos_table)

                # Solve once with thetaacc as the objective; only the
                # solver status is inspected.
                model.del_component(model.obj)
                model.obj = pyo.Objective(expr=thetaacc, sense=pyo.maximize)
                result = Experiment.solve(model, solver=Experiment.use_solver)
                assert (result.solver.status == SolverStatus.ok) and (
                    result.solver.termination_condition
                    == TerminationCondition.optimal
                ), f"LP wasn't optimally solved {x}"
                # apply dynamic
                x_second = self.apply_dynamic_pyo(
                    x_prime,
                    model,
                    thetaacc=thetaacc,
                    xacc=xacc,
                    env_input_size=self.env_input_size,
                    action=chosen_action)
                x_second_results = self.optimise_pyo(template, model,
                                                     x_second)
                # Fix: test the successor bounds themselves. The earlier
                # code tested ``x_prime_results``, which had already been
                # consumed above, so a missing successor reached
                # ``tuple(None)``. Presumably ``optimise_pyo`` returns None
                # when it has no result — confirm.
                if x_second_results is not None:
                    post.append(tuple(x_second_results))
        return post
Exemplo n.º 12
0
def run_parameterised_experiment(config):
    """Configure and run the experiment described by ``config`` and report it to Tune.

    ``config["main_params"]`` is unpacked as ``(problem, method, other_config)``
    and ``config["n_workers"]`` is forwarded to the experiment.

    * ``problem`` picks the experiment family: ``"bouncing_ball"``,
      ``"stopping_car"``, or the cartpole family for any other value.
    * ``method == "ora"`` picks the ``ORA*`` class of that family.
    * ``other_config["nn_path"]`` is the key into the per-problem checkpoint
      table; the entry is joined onto ``utils.get_save_dir()``.
    * ``other_config["template"]``: ``2`` selects ``Experiment.octagon``,
      ``0`` selects ``Experiment.box``.  Any other value uses
      ``experiment.get_template(1)``, except for the bouncing ball, which
      raises ``NotImplementedError``.
    * The stopping car reads ``other_config["epsilon_input"]``; the other
      problems read ``other_config["tau"]``.

    The result of ``run_experiment()`` is passed to ``tune.report`` with
    ``safe`` encoded as ``1`` (truthy), ``-1`` (falsy) or ``0`` (``None``).
    """
    trial_dir = tune.get_trial_dir()
    problem, method, other_config = config["main_params"]
    n_workers = config["n_workers"]

    use_ora = method == "ora"
    is_bouncing_ball = problem == "bouncing_ball"
    is_stopping_car = problem == "stopping_car"

    # Pick the experiment class and the checkpoint table that belongs to it.
    if is_bouncing_ball:
        experiment = ORABouncingBallExperiment() if use_ora else BouncingBallExperiment()
        checkpoint_table = nn_paths_bouncing_ball
    elif is_stopping_car:
        experiment = ORAStoppingCarExperiment() if use_ora else StoppingCarExperiment()
        checkpoint_table = nn_paths_stopping_car
    else:
        experiment = ORACartpoleExperiment() if use_ora else CartpoleExperiment()
        checkpoint_table = nn_paths_cartpole

    experiment.nn_path = os.path.join(
        utils.get_save_dir(), checkpoint_table[other_config["nn_path"]])

    # Problem-specific hyperparameter.
    if is_stopping_car:
        experiment.input_epsilon = other_config["epsilon_input"]
    else:
        experiment.tau = other_config["tau"]

    # Analysis template.
    template_choice = other_config["template"]
    if template_choice == 2:  # octagon
        experiment.analysis_template = Experiment.octagon(experiment.env_input_size)
    elif template_choice == 0:  # box
        experiment.analysis_template = Experiment.box(experiment.env_input_size)
    elif is_bouncing_ball:
        # The bouncing ball only supports the octagon and box templates.
        raise NotImplementedError()
    else:
        _, standard_template = experiment.get_template(1)
        experiment.analysis_template = standard_template  # standard

    # Settings shared by every experiment.
    experiment.n_workers = n_workers
    experiment.show_progressbar = False
    experiment.show_progress_plot = False
    experiment.save_dir = trial_dir
    experiment.update_progress_fn = update_progress
    elapsed_seconds, safe, max_t = experiment.run_experiment()

    # Encode the tri-state outcome as an integer for the report.
    if safe is None:
        safe_value = 0
    elif safe:
        safe_value = 1
    else:
        safe_value = -1
    tune.report(elapsed_seconds=elapsed_seconds, safe=safe_value, max_t=max_t, done=True)
    def get_template(self, mode=0):
        """Return ``(input_boundaries, template)`` for one of the predefined modes.

        ``template`` is a NumPy array built from one direction per row and
        ``input_boundaries`` is the hard-coded list of numbers returned with
        it (how callers consume the boundaries is not visible here).

        Modes:
            0: box directions with intervals, built from ``Experiment.e(6, i)``
               and its negation for every ``i`` (presumably the i-th unit
               vector -- confirm against ``Experiment.e``).
            1: directions to easily find a fixed point.
            2: ``+/-`` every axis plus one extra row with ``+1`` at index 0
               and ``-1`` at index 1.
            3: single point box directions plus a diagonal: ``+/-`` every
               axis plus one extra row with ``-1`` at index 0 and ``+1`` at
               index 1.
            4: octagon over every pair of variables.

        Any other ``mode`` matches no branch and the method returns ``None``.
        """
        n_dims = 6
        x_lead, x_ego, v_lead, v_ego, a_lead, a_ego = [
            Experiment.e(n_dims, index) for index in range(n_dims)
        ]

        def sparse_row(*entries):
            # Length-6 list of integer zeros with the given (index, value)
            # pairs filled in.
            row = [0] * n_dims
            for index, value in entries:
                row[index] = value
            return row

        def signed_axis_rows():
            # For each axis in order: the +1 row followed by the -1 row.
            return [
                sparse_row((axis, sign))
                for axis in range(n_dims)
                for sign in (1, -1)
            ]

        if mode == 0:  # box directions with intervals
            input_boundaries = [
                50, -40, 10, -0, 28, -28, 36, -36, 0, -0, 0, -0, 0
            ]
            directions = []
            for axis in range(n_dims):
                directions += [
                    Experiment.e(n_dims, axis), -Experiment.e(n_dims, axis)
                ]
            return input_boundaries, np.array(directions)

        if mode == 1:  # directions to easily find fixed point
            v_gap = v_lead - v_ego
            x_gap = x_lead - x_ego
            directions = np.array([
                a_lead, -a_lead, a_ego, -a_ego, -v_lead, v_lead,
                -v_gap, v_gap, -x_gap, x_gap
            ])
            return [20], directions

        if mode == 2:
            input_boundaries = [
                0, -100, 30, -31, 20, -30, 0, -35, 0, -0, -10, -10, 20
            ]
            axis_block = np.array(signed_axis_rows())
            extra_row = sparse_row((0, 1), (1, -1))
            return input_boundaries, np.vstack([axis_block, extra_row])

        if mode == 3:  # single point box directions +diagonal
            input_boundaries = [
                30, -30, 0, -0, 28, -28, 36, -36, 0, -0, 0, -0, 0
            ]
            axis_block = np.array(signed_axis_rows())
            extra_row = sparse_row((0, -1), (1, 1))
            return input_boundaries, np.vstack([axis_block, extra_row])

        if mode == 4:  # octagon, every pair of variables
            pair_signs = ((1, -1), (-1, 1), (1, 1), (-1, -1))
            rows = []
            for axis in range(n_dims):
                rows.append(sparse_row((axis, 1)))
                rows.append(sparse_row((axis, -1)))
                for other_axis in range(axis + 1, n_dims):
                    rows.extend(
                        sparse_row((axis, sign), (other_axis, other_sign))
                        for sign, other_sign in pair_signs)
            return [20], np.array(rows)

        # Unrecognised mode: no template is defined.
        return None
    def get_nn(self):
        """Restore the PPO checkpoint at ``self.nn_path`` and return its network.

        A ``ppo.PPOTrainer`` is created from ``get_PPO_config(1234)``, the
        checkpoint stored at ``self.nn_path`` is restored into it, and the
        trainer's policy is converted with
        ``convert_ray_policy_to_sequential``.  ``.cpu()`` is called on the
        converted network before it is returned.
        """
        # NOTE(review): the meaning of the 1234 argument is defined by
        # get_PPO_config, which is not visible here -- confirm there.
        ppo_config = get_PPO_config(1234)
        restored_trainer = ppo.PPOTrainer(config=ppo_config)
        restored_trainer.restore(self.nn_path)
        restored_policy = restored_trainer.get_policy()
        return convert_ray_policy_to_sequential(restored_policy).cpu()


if __name__ == "__main__":
    # Stand-alone run: one stopping-car experiment with an octagon template.
    ray.init(log_to_driver=False, local_mode=False)

    experiment = StoppingCarExperiment()
    experiment.plotting_time_interval = 2 * 60
    experiment.show_progressbar = True
    experiment.show_progress_plot = False
    # Octagon template sized from the experiment's own input dimension.
    experiment.analysis_template = Experiment.octagon(experiment.env_input_size)
    experiment.input_boundaries = [
        40, -30, 10, -0, 28, -28, 36, -36, 0, -0, 0, -0, 0
    ]
    experiment.time_horizon = 150
    experiment.run_experiment()