def __init__(self,
             env,
             safety_measure,
             greed,
             step_size,
             discount_rate,
             safety_threshold,
             x_seed,
             y_seed,
             gp_params=None,
             keep_seed_in_data=True):
        """
        Build a ConstrainedQLearner: a GPQLearning model of the Q-Values plus the two policies used to act
        (a ConstrainedEpsilonGreedy policy on Q-Values and a SafetyMaximization fallback).
        :param env: the environment
        :param safety_measure: either SafetyTruth or SafetyModel of the environment
        :param greed: the epsilon parameter of the ConstrainedEpsilonGreedy policy
        :param step_size: the step size in the Q-Learning update
        :param discount_rate: the discount rate
        :param safety_threshold: the lambda threshold used to evaluate safety. This is 0 theoretically, but an Agent
            that is at the exact boundary of the viability kernel still fails due to rounding errors. Hence, this should
            be a small, positive value.
        :param x_seed: the seed input of the GP
        :param y_seed: the seed output of the GP
        :param gp_params: the parameters defining the GP. See edge.models.inference.MaternGP for more information
        :param keep_seed_in_data: whether to keep the seed data in the GP dataset. Should be True, otherwise GPyTorch
            fails.
        """
        # NOTE(review): other versions of this class pass env.stateaction_space to GPQLearning; this one passes
        # env itself -- confirm which signature the installed GPQLearning expects.
        q_model = GPQLearning(env, step_size, discount_rate,
                              x_seed=x_seed, y_seed=y_seed,
                              gp_params=gp_params)
        super(ConstrainedQLearner, self).__init__(env, q_model)

        self.Q_model = q_model
        self.safety_measure = safety_measure
        self.safety_threshold = safety_threshold
        self.keep_seed_in_data = keep_seed_in_data

        value_space = self.env.stateaction_space
        self.constrained_value_policy = ConstrainedEpsilonGreedy(value_space, greed)
        self.safety_maximization_policy = SafetyMaximization(
            self.safety_measure.stateaction_space)

        # Optionally drop the seed points from the Q-model's GP dataset
        if not self.keep_seed_in_data:
            self.Q_model.empty_data()
Example #2
0
class DiscreteQLearner(Agent):
    """
    Defines an Agent doing Q-Learning on a Discrete StateActionSpace. The Agent can also have a
    model for the underlying safety measure, either via a SafetyTruth or a SafetyModel, but this model is then
    not updated. If it has a safety model, the Agent then acts with a ConstrainedEpsilonGreedy policy, staying in the
    safe set. Otherwise, it uses an EpsilonGreedy policy.
    """
    def __init__(self,
                 env,
                 greed,
                 step_size,
                 discount_rate,
                 constraint=None,
                 safety_threshold=0.05):
        """
        Initializer
        :param env: the environment
        :param greed: the epsilon parameter in the EpsilonGreedy/ConstrainedEpsilonGreedy policy
        :param step_size: the Q-Learning step size
        :param discount_rate: the discount rate
        :param constraint: either None, SafetyTruth, or SafetyModel. The model of safety, if any
        :param safety_threshold: the lambda threshold used to evaluate safety. This is 0 theoretically, but an Agent
            that is at the exact boundary of the viability kernel still fails due to rounding errors. Hence, this should
            be a small, positive value.
        """
        Q_model = QLearning(env.stateaction_space, step_size, discount_rate)

        super(DiscreteQLearner, self).__init__(env, Q_model)

        self.Q_model = Q_model
        self.constraint = constraint
        self.safety_threshold = safety_threshold
        if self.constraint is None:
            self.policy = EpsilonGreedy(self.env.stateaction_space, greed)
            self.default_policy = None
            self.is_constrained = False
        else:
            self.policy = ConstrainedEpsilonGreedy(self.env.stateaction_space,
                                                   greed)
            # Fallback used by get_next_action when no action is classified as viable
            self.default_policy = EpsilonGreedy(self.env.stateaction_space,
                                                greed)
            self.is_constrained = True

    @property
    def greed(self):
        """
        Returns the epsilon parameter of the ConstrainedEpsilonGreedy/EpsilonGreedy policy
        :return: epsilon parameter
        """
        return self.policy.greed

    @greed.setter
    def greed(self, new_greed):
        """
        Sets the epsilon parameter of the ConstrainedEpsilonGreedy/EpsilonGreedy policy (and of the fallback policy,
        if any)
        """
        self.policy.greed = new_greed
        if self.default_policy is not None:
            self.default_policy.greed = new_greed

    def get_next_action(self):
        """
        Chooses the next action from the Q-Values of the current state. When constrained, only actions whose safety
        measure exceeds safety_threshold are eligible; if none is, the unconstrained EpsilonGreedy fallback is used.
        :return: the chosen action
        """
        q_values = self.Q_model[self.state, :]
        if self.is_constrained:
            all_actions = self.Q_model.space.action_space[:].reshape(-1, 1)
            action_is_viable = np.array([
                self.constraint.measure(self.state, a) > self.safety_threshold
                for a in all_actions
            ])

            action = self.policy.get_action(q_values, action_is_viable)
            if action is None:
                action = self.default_policy(q_values)
        else:
            action = self.policy(q_values)
        return action

    def get_random_safe_state(self):
        """
        Returns a random state that is classified as safe by the safety model
        :return: a safe state, or None if the Agent is unconstrained or if no state is classified as safe
        """
        if not self.is_constrained:
            return None
        is_viable = self.constraint.state_measure > self.safety_threshold
        # atleast_1d: squeeze() collapses a single match to a 0-d array, which has no len()
        viable_indexes = np.atleast_1d(np.argwhere(is_viable).squeeze())
        if len(viable_indexes) == 0:
            # np.random.choice(0) would raise: report "no safe state" the same way as the unconstrained case
            return None
        state_index = viable_indexes[np.random.choice(len(viable_indexes))]
        state = self.constraint.stateaction_space.state_space[state_index]
        return state

    def update_models(self, state, action, next_state, reward, failed):
        """
        Updates the Q-Learning model with one transition. The constraint (safety model) is never updated.
        """
        self.Q_model.update(state, action, next_state, reward, failed)
Example #3
0
class ConstrainedQLearner(Agent):
    """
    Defines an Agent modelling the Q-Values with a MaternGP updated with the Q-Learning update. The Agent also has a
    model for the underlying safety measure, either via a SafetyTruth or a SafetyModel, but this model is then
    not updated. The Agent then acts with a ConstrainedEpsilonGreedy policy, staying in the safe set.
    """
    def __init__(self,
                 env,
                 safety_measure,
                 greed,
                 step_size,
                 discount_rate,
                 safety_threshold,
                 x_seed,
                 y_seed,
                 gp_params=None,
                 keep_seed_in_data=True):
        """
        Initializer
        :param env: the environment
        :param safety_measure: either SafetyTruth or SafetyModel of the environment
        :param greed: the epsilon parameter of the ConstrainedEpsilonGreedy policy
        :param step_size: the step size in the Q-Learning update
        :param discount_rate: the discount rate
        :param safety_threshold: the lambda threshold used to evaluate safety. This is 0 theoretically, but an Agent
            that is at the exact boundary of the viability kernel still fails due to rounding errors. Hence, this should
            be a small, positive value.
        :param x_seed: the seed input of the GP
        :param y_seed: the seed output of the GP
        :param gp_params: the parameters defining the GP. See edge.models.inference.MaternGP for more information
        :param keep_seed_in_data: whether to keep the seed data in the GP dataset. Should be True, otherwise GPyTorch
            fails.
        """
        Q_model = GPQLearning(env.stateaction_space,
                              step_size,
                              discount_rate,
                              x_seed=x_seed,
                              y_seed=y_seed,
                              gp_params=gp_params)
        super(ConstrainedQLearner, self).__init__(env, Q_model)

        self.Q_model = Q_model
        self.safety_measure = safety_measure
        self.constrained_value_policy = ConstrainedEpsilonGreedy(
            self.env.stateaction_space, greed)
        self.safety_maximization_policy = SafetyMaximization(
            self.safety_measure.stateaction_space)
        self.safety_threshold = safety_threshold
        self.keep_seed_in_data = keep_seed_in_data
        if not keep_seed_in_data:
            self.Q_model.empty_data()

    @property
    def greed(self):
        """
        Returns the epsilon parameter of the ConstrainedEpsilonGreedy policy
        :return: epsilon parameter
        """
        return self.constrained_value_policy.greed

    @greed.setter
    def greed(self, new_greed):
        """
        Sets the epsilon parameter of the ConstrainedEpsilonGreedy policy
        """
        self.constrained_value_policy.greed = new_greed

    def get_next_action(self):
        """
        Chooses the next action with the ConstrainedEpsilonGreedy policy, restricted to actions whose safety measure
        exceeds safety_threshold. If no action is viable, takes the action maximizing the safety measure.
        :return: the chosen action
        :raises NoActionError: if neither policy returns an action
        """
        all_actions = self.Q_model.space.action_space[:].reshape(-1, 1)
        action_is_viable = np.array([
            self.safety_measure.measure(self.state, a) > self.safety_threshold
            for a in all_actions
        ])
        q_values = self.Q_model[self.state, :]

        action = self.constrained_value_policy.get_action(
            q_values, action_is_viable)
        if action is None:
            print('No viable action, taking the safest...')
            safety_values = self.safety_measure.measure(
                self.state, slice(None, None, None))
            action = self.safety_maximization_policy.get_action(safety_values)
        if action is None:
            raise NoActionError('The agent could not find a suitable action')

        return action

    def get_random_safe_state(self):
        """
        Returns a random state that is classified as safe by the safety model, i.e., whose state measure exceeds
        safety_threshold (the same threshold used to select viable actions)
        :return: a safe state, or None if no state is classified as safe
        """
        is_viable = self.safety_measure.state_measure > self.safety_threshold
        # atleast_1d: squeeze() collapses a single match to a 0-d array, which has no len()
        viable_indexes = np.atleast_1d(np.argwhere(is_viable).squeeze())
        if len(viable_indexes) == 0:
            # np.random.choice(0) would raise
            return None
        state_index = viable_indexes[np.random.choice(len(viable_indexes))]
        state = self.safety_measure.stateaction_space.state_space[state_index]
        return state

    def update_models(self, state, action, next_state, reward, failed):
        """
        Updates the Q-Values model with one transition. The safety measure is never updated.
        """
        self.Q_model.update(state, action, next_state, reward, failed)

    def fit_models(self, train_x, train_y, epochs, **optimizer_kwargs):
        """
        Fits the hyperparameters of the Q-Values GP, then empties its dataset again if keep_seed_in_data is False
        :param train_x: training inputs
        :param train_y: training outputs
        :param epochs: number of training epochs
        :param optimizer_kwargs: forwarded to the model's fit method
        """
        self.Q_model.fit(train_x, train_y, epochs, **optimizer_kwargs)
        if not self.keep_seed_in_data:
            self.Q_model.empty_data()
Example #4
0
    def __init__(self,
                 env,
                 greed,
                 step_size,
                 discount_rate,
                 q_x_seed,
                 q_y_seed,
                 gamma_optimistic,
                 gamma_cautious,
                 lambda_cautious,
                 s_x_seed,
                 s_y_seed,
                 q_gp_params=None,
                 s_gp_params=None,
                 keep_seed_in_data=True):
        """
        Initializer of the ValuesAndSafetyCombinator. Builds a GPQLearning model of the Q-Values and a MaternSafety
        model of the safety measure, and sets up the ConstrainedEpsilonGreedy and SafetyMaximization policies.
        :param env: the environment
        :param greed: the epsilon parameter of the ConstrainedEpsilonGreedy
            policy
        :param step_size: the step size in the Q-Learning update
        :param discount_rate: the discount rate
        :param q_x_seed: the seed input of the GP for the Q-Values model
        :param q_y_seed: the seed output of the GP for the Q-Values model
        :param gamma_optimistic: the gamma parameter for Q_optimistic, given as a
            (start, end) pair. The start value is used to build the safety model
        :param gamma_cautious: the gamma parameter for Q_cautious, given as a
            (start, end) pair
        :param lambda_cautious: the lambda parameter for Q_cautious, given as a
            (start, end) pair
        :param s_x_seed: the seed input of the GP for the safety model
        :param s_y_seed: the seed output of the GP for the safety model
        :param q_gp_params: the parameters defining the GP for the Q-Values
            model. See edge.models.inference.MaternGP
            for more information
        :param s_gp_params: the parameters defining the GP for the safety model.
            See edge.models.inference.MaternGP for more information
        :param keep_seed_in_data: whether to keep the seed data in the GPs
            datasets. Should be True, otherwise GPyTorch fails.
        """
        # Each of these hyperparameters is unpacked as a (start, end) pair;
        # the current cautious values are initialized to their start values
        self.lambda_cautious_start, self.lambda_cautious_end = lambda_cautious
        self.gamma_cautious_start, self.gamma_cautious_end = gamma_cautious
        self.gamma_optimistic_start, self.gamma_optimistic_end = \
            gamma_optimistic
        self.lambda_cautious = self.lambda_cautious_start
        self.gamma_cautious = self.gamma_cautious_start

        # NOTE(review): not used in this snippet -- presumably an index for
        # step-size decreases performed elsewhere; confirm
        self._step_size_decrease_index = 1

        Q_model = GPQLearning(env.stateaction_space,
                              step_size,
                              discount_rate,
                              x_seed=q_x_seed,
                              y_seed=q_y_seed,
                              gp_params=q_gp_params)
        safety_model = MaternSafety(env.stateaction_space,
                                    self.gamma_optimistic_start,
                                    x_seed=s_x_seed,
                                    y_seed=s_y_seed,
                                    gp_params=s_gp_params)
        # NOTE(review): the parent initializer receives the same Q-model
        # arguments; if it builds its own Q-model, that instance is replaced by
        # `self.Q_model = Q_model` below -- confirm the double construction is
        # intended.
        super(ValuesAndSafetyCombinator, self).__init__(
            env=env,
            greed=greed,  # Unused: we define another policy
            step_size=step_size,
            discount_rate=discount_rate,
            x_seed=q_x_seed,
            y_seed=q_y_seed,
            gp_params=q_gp_params,
            keep_seed_in_data=keep_seed_in_data)

        # Override whatever models the parent initializer set
        self.Q_model = Q_model
        self.safety_model = safety_model

        self.constrained_value_policy = ConstrainedEpsilonGreedy(
            self.env.stateaction_space, greed)
        self.safety_maximization_policy = SafetyMaximization(
            self.env.stateaction_space)
        # Keep a copy of the initial greed (read through self.greed);
        # presumably restored after non-training phases -- confirm
        self._training_greed = self.greed

        self.keep_seed_in_data = keep_seed_in_data
        # NOTE(review): only the Q-model dataset is emptied here, not the
        # safety model's -- confirm this is intended
        if not keep_seed_in_data:
            self.Q_model.empty_data()
Example #5
0
    def __init__(self,
                 env,
                 greed,
                 step_size,
                 discount_rate,
                 q_x_seed,
                 q_y_seed,
                 gamma_optimistic,
                 gamma_cautious,
                 lambda_cautious,
                 s_x_seed,
                 s_y_seed,
                 q_gp_params=None,
                 s_gp_params=None,
                 keep_seed_in_data=True):
        """
        Build the EpsCorlLearner: a GPQLearning model of the Q-Values, a MaternSafety model of the safety measure,
        and the constrained value, safety maximization and active sampling policies.
        :param env: the environment
        :param greed: the exploration parameter. It is stored in self._greed; the ConstrainedEpsilonGreedy policy
            itself is built with epsilon 0
        :param step_size: the step size in the Q-Learning update
        :param discount_rate: the discount rate
        :param q_x_seed: the seed input of the GP for the Q-Values model
        :param q_y_seed: the seed output of the GP for the Q-Values model
        :param gamma_optimistic: the gamma parameter for Q_optimistic
        :param gamma_cautious: the gamma parameter for Q_cautious
        :param lambda_cautious: the lambda parameter for Q_cautious
        :param s_x_seed: the seed input of the GP for the safety model
        :param s_y_seed: the seed output of the GP for the safety model
        :param q_gp_params: the parameters defining the GP for the Q-Values model. See edge.models.inference.MaternGP
            for more information
        :param s_gp_params: the parameters defining the GP for the safety model. See edge.models.inference.MaternGP
            for more information
        :param keep_seed_in_data: whether to keep the seed data in the GPs datasets. Should be True, otherwise GPyTorch
            fails.
        """
        q_model = GPQLearning(env.stateaction_space, step_size, discount_rate,
                              x_seed=q_x_seed, y_seed=q_y_seed,
                              gp_params=q_gp_params)
        s_model = MaternSafety(env.stateaction_space, gamma_optimistic,
                               x_seed=s_x_seed, y_seed=s_y_seed,
                               gp_params=s_gp_params)
        super(EpsCorlLearner, self).__init__(env, q_model, s_model)

        self.Q_model = q_model
        self.safety_model = s_model

        # Level-set hyperparameters and exploration bookkeeping
        self._gamma_optimistic = gamma_optimistic
        self.gamma_cautious = gamma_cautious
        self.lambda_cautious = lambda_cautious
        self._greed = greed
        self.has_explored = None
        self.keep_seed_in_data = keep_seed_in_data

        space = self.env.stateaction_space
        # epsilon fixed to 0 here; greed is stored separately in self._greed
        self.constrained_value_policy = ConstrainedEpsilonGreedy(space, 0)
        self.safety_maximization_policy = SafetyMaximization(space)
        self.active_sampling_policy = SafetyActiveSampling(space)

        if not self.keep_seed_in_data:
            self.Q_model.empty_data()
Example #6
0
class EpsCorlLearner(Agent):
    """
    Defines an Agent modelling the Q-Values with a MaternGP updated with the Q-Learning update. The Agent also has a
    model for the underlying safety measure with a SafetyModel, and the Q-Learning update is constrained to stay in the
    current estimate of the safe set. At each step, with probability `greed` the Agent explores by actively sampling
    the safety measure (only the safety model is then updated); otherwise it follows the constrained Q-Values and both
    models are updated.
    """
    def __init__(self,
                 env,
                 greed,
                 step_size,
                 discount_rate,
                 q_x_seed,
                 q_y_seed,
                 gamma_optimistic,
                 gamma_cautious,
                 lambda_cautious,
                 s_x_seed,
                 s_y_seed,
                 q_gp_params=None,
                 s_gp_params=None,
                 keep_seed_in_data=True):
        """
        Initializer
        :param env: the environment
        :param greed: the probability of taking an exploration (active sampling) step instead of following the
            constrained Q-Values
        :param step_size: the step size in the Q-Learning update
        :param discount_rate: the discount rate
        :param q_x_seed: the seed input of the GP for the Q-Values model
        :param q_y_seed: the seed output of the GP for the Q-Values model
        :param gamma_optimistic: the gamma parameter for Q_optimistic
        :param gamma_cautious: the gamma parameter for Q_cautious
        :param lambda_cautious: the lambda parameter for Q_cautious
        :param s_x_seed: the seed input of the GP for the safety model
        :param s_y_seed: the seed output of the GP for the safety model
        :param q_gp_params: the parameters defining the GP for the Q-Values model. See edge.models.inference.MaternGP
            for more information
        :param s_gp_params: the parameters defining the GP for the safety model. See edge.models.inference.MaternGP
            for more information
        :param keep_seed_in_data: whether to keep the seed data in the GPs datasets. Should be True, otherwise GPyTorch
            fails.
        """
        Q_model = GPQLearning(env.stateaction_space,
                              step_size,
                              discount_rate,
                              x_seed=q_x_seed,
                              y_seed=q_y_seed,
                              gp_params=q_gp_params)
        safety_model = MaternSafety(env.stateaction_space,
                                    gamma_optimistic,
                                    x_seed=s_x_seed,
                                    y_seed=s_y_seed,
                                    gp_params=s_gp_params)
        super(EpsCorlLearner, self).__init__(env, Q_model, safety_model)

        self.Q_model = Q_model
        self.safety_model = safety_model
        self.lambda_cautious = lambda_cautious
        self.gamma_cautious = gamma_cautious
        self._gamma_optimistic = gamma_optimistic

        # epsilon is 0: exploration is decided in get_next_action from self.greed
        self.constrained_value_policy = ConstrainedEpsilonGreedy(
            self.env.stateaction_space, 0)
        self.safety_maximization_policy = SafetyMaximization(
            self.env.stateaction_space)
        self.active_sampling_policy = SafetyActiveSampling(
            self.env.stateaction_space)

        self.keep_seed_in_data = keep_seed_in_data
        if not keep_seed_in_data:
            self.Q_model.empty_data()

        # Set by get_next_action; read by step to decide which models to update
        self.has_explored = None
        self._greed = greed

    @property
    def greed(self):
        """
        Returns the epsilon parameter balancing the exploration/exploitation
        :return: epsilon parameter
        """
        return self._greed

    @greed.setter
    def greed(self, new_greed):
        """
        Sets the epsilon parameter balancing the exploration/exploitation
        """
        self._greed = new_greed

    @property
    def gamma_optimistic(self):
        """
        Returns the gamma parameter for Q_optimistic
        """
        return self._gamma_optimistic

    @gamma_optimistic.setter
    def gamma_optimistic(self, new_gamma_optimistic):
        """
        Sets the gamma parameter for Q_optimistic and propagates it to the safety model's gamma_measure
        """
        self._gamma_optimistic = new_gamma_optimistic
        self.safety_model.gamma_measure = new_gamma_optimistic

    def get_next_action(self):
        """
        Chooses the next action. With probability greed, explores with the active sampling policy; otherwise takes
        the ConstrainedEpsilonGreedy action on the Q-Values. Both are restricted to the cautious set. If neither
        yields an action, falls back to the action maximizing the safety probability.
        :return: the chosen action
        """
        self.has_explored = np.random.binomial(n=1, p=self.greed) == 1

        is_cautious, proba_slice, covar_slice = self.safety_model.level_set(
            self.state,
            lambda_threshold=self.lambda_cautious,
            gamma_threshold=self.gamma_cautious,
            return_proba=True,
            return_covar=True)
        is_cautious = is_cautious.squeeze()
        proba_slice = proba_slice.squeeze()
        covar_slice = covar_slice.squeeze()

        if self.has_explored:
            action = self.active_sampling_policy.get_action(
                covar_slice, is_cautious)
        else:
            q_values = self.Q_model[self.state, :]
            action = self.constrained_value_policy.get_action(
                q_values, is_cautious)

        if action is None:
            print('No viable action, taking the safest...')
            action = self.safety_maximization_policy.get_action(proba_slice)

        return action

    def get_random_safe_state(self):
        """
        Returns a random state that is classified as safe by the cautious estimate of the safety model
        :return: a safe state, or None if no state is classified as safe
        """
        is_viable = self.safety_model.measure(
            slice(None, None, None),
            lambda_threshold=self.lambda_cautious,
            gamma_threshold=self.gamma_cautious) > 0
        # atleast_1d: squeeze() collapses a single match to a 0-d array, which has no len()
        viable_indexes = np.atleast_1d(np.argwhere(is_viable).squeeze())
        if len(viable_indexes) == 0:
            # Explicit check instead of a broad `except Exception`, which could mask unrelated errors
            print('ERROR: no state is classified as safe by the cautious safety estimate')
            return None
        state_index = viable_indexes[np.random.choice(len(viable_indexes))]
        state = self.env.stateaction_space.state_space[state_index]
        return state

    def update_models(self, state, action, next_state, reward, failed):
        """
        Updates both the Q-Values model and the safety model with one transition
        """
        self.Q_model.update(state, action, next_state, reward, failed)
        self.safety_model.update(state, action, next_state, reward, failed)

    def fit_models(self,
                   q_train_x=None,
                   q_train_y=None,
                   q_epochs=None,
                   q_optimizer_kwargs=None,
                   s_train_x=None,
                   s_train_y=None,
                   s_epochs=None,
                   s_optimizer_kwargs=None):
        """
        Fits the hyperparameters of the Q-Values GP and/or of the safety GP. A model is only fitted when its training
        inputs are given; its dataset is emptied again afterwards if keep_seed_in_data is False.
        """
        if q_train_x is not None:
            if q_optimizer_kwargs is None:
                q_optimizer_kwargs = {}
            self.Q_model.fit(q_train_x, q_train_y, q_epochs,
                             **q_optimizer_kwargs)
            if not self.keep_seed_in_data:
                self.Q_model.empty_data()

        if s_train_x is not None:
            if s_optimizer_kwargs is None:
                s_optimizer_kwargs = {}
            self.safety_model.fit(s_train_x, s_train_y, s_epochs,
                                  **s_optimizer_kwargs)
            if not self.keep_seed_in_data:
                self.safety_model.empty_data()

    def step(self):
        """
        Chooses an action according to the policy, takes a step in the Environment, and updates the models. The action
        taken is available in self.last_action. After an exploration step or a failure, only the safety model is
        updated; otherwise both models are.
        :return: new_state, reward, failed
        """
        old_state = self.state
        action = self.get_next_action()
        new_state, reward, failed = self.env.step(action)
        if self.has_explored or failed:
            self.safety_model.update(old_state, action, new_state, reward,
                                     failed)
        else:
            self.update_models(old_state, action, new_state, reward, failed)
        self.state = new_state
        self.last_action = action
        return new_state, reward, failed
Example #7
0
class SoftHardLearner(Agent):
    """
    Defines an Agent modelling the Q-Values with a MaternGP updated with the Q-Learning update. The Agent also has a
    model for the underlying safety measure with a SafetyModel, and the Q-Learning update is constrained to stay in the
    current estimate of the safe set (the "hard" set). If the action taken lies outside the conservative "soft" set,
    or the step fails, the sample is also used to update the safety model.
    """

    def __init__(self, env,
                 greed, step_size, discount_rate, q_x_seed, q_y_seed,
                 gamma_optimistic, gamma_hard, lambda_hard, gamma_soft, s_x_seed, s_y_seed,
                 q_gp_params=None, s_gp_params=None, keep_seed_in_data=True):
        """
        Initializer
        :param env: the environment
        :param greed: the epsilon parameter of the ConstrainedEpsilonGreedy policy
        :param step_size: the step size in the Q-Learning update
        :param discount_rate: the discount rate
        :param q_x_seed: the seed input of the GP for the Q-Values model
        :param q_y_seed: the seed output of the GP for the Q-Values model
        :param gamma_optimistic: the gamma parameter for Q_optimistic
        :param gamma_hard: the gamma parameter for Q_hard, the set where Q-Learning is constrained (~ Q_cautious)
        :param lambda_hard: the lambda parameter for Q_hard AND Q_soft
        :param gamma_soft: the gamma parameter for Q_soft, the set outside of which the safety measure is updated
        :param s_x_seed: the seed input of the GP for the safety model
        :param s_y_seed: the seed output of the GP for the safety model
        :param q_gp_params: the parameters defining the GP for the Q-Values model. See edge.models.inference.MaternGP
            for more information
        :param s_gp_params: the parameters defining the GP for the safety model. See edge.models.inference.MaternGP
            for more information
        :param keep_seed_in_data: whether to keep the seed data in the GPs datasets. Should be True, otherwise GPyTorch
            fails.
        """
        Q_model = GPQLearning(env.stateaction_space, step_size, discount_rate,
                              x_seed=q_x_seed, y_seed=q_y_seed,
                              gp_params=q_gp_params)
        safety_model = MaternSafety(env.stateaction_space, gamma_optimistic,
                                    x_seed=s_x_seed, y_seed=s_y_seed,
                                    gp_params=s_gp_params)
        super(SoftHardLearner, self).__init__(env, Q_model, safety_model)

        self.Q_model = Q_model
        self.safety_model = safety_model
        self.lambda_hard = lambda_hard
        self.gamma_hard = gamma_hard
        self.gamma_soft = gamma_soft
        self._gamma_optimistic = gamma_optimistic

        self.constrained_value_policy = ConstrainedEpsilonGreedy(
            self.env.stateaction_space, greed)
        self.safety_maximization_policy = SafetyMaximization(
            self.env.stateaction_space)
        self.active_sampling_policy = SafetyActiveSampling(
            self.env.stateaction_space)

        self.keep_seed_in_data = keep_seed_in_data
        if not keep_seed_in_data:
            self.Q_model.empty_data()

        # Set by get_next_action / step respectively
        self.violated_soft_constraint = None
        self.updated_safety = None

    @property
    def greed(self):
        """
        Returns the epsilon parameter of the ConstrainedEpsilonGreedy policy
        :return: epsilon parameter
        """
        return self.constrained_value_policy.greed

    @greed.setter
    def greed(self, new_greed):
        """
        Sets the epsilon parameter of the ConstrainedEpsilonGreedy policy
        """
        self.constrained_value_policy.greed = new_greed

    @property
    def step_size(self):
        """
        Returns the step_size parameter of the Q-Learning model
        :return: step_size parameter
        """
        return self.Q_model.step_size

    @step_size.setter
    def step_size(self, new_step_size):
        """
        Sets the step_size parameter of the Q-Learning model
        """
        self.Q_model.step_size = new_step_size

    @property
    def gamma_optimistic(self):
        """
        Returns the gamma parameter for Q_optimistic
        """
        return self._gamma_optimistic

    @gamma_optimistic.setter
    def gamma_optimistic(self, new_gamma_optimistic):
        """
        Sets the gamma parameter for Q_optimistic and propagates it to the safety model's gamma_measure
        """
        self._gamma_optimistic = new_gamma_optimistic
        self.safety_model.gamma_measure = new_gamma_optimistic

    def get_next_action(self):
        """
        Chooses the next action with the ConstrainedEpsilonGreedy policy restricted to the hard set. If no action is
        viable, takes the action maximizing the safety probability of the hard set. Also records in
        self.violated_soft_constraint whether the chosen action lies outside the soft set (always True for the
        fallback).
        :return: the chosen action
        """
        is_cautious_list, proba_slice_list = self.safety_model.level_set(
            self.state,
            lambda_threshold=[self.lambda_hard, self.lambda_hard],
            gamma_threshold=[self.gamma_hard, self.gamma_soft],
            return_proba=True,
            return_covar=False
        )
        is_cautious_hard, is_cautious_soft = is_cautious_list
        is_cautious_hard = is_cautious_hard.squeeze()
        is_cautious_soft = is_cautious_soft.squeeze()
        proba_slice_hard = proba_slice_list[0].squeeze()

        q_values = self.Q_model[self.state, :]
        action = self.constrained_value_policy.get_action(
            q_values, is_cautious_hard
        )

        if action is None:
            logger.info('No viable action, taking the safest...')
            action = self.safety_maximization_policy.get_action(proba_slice_hard)
            self.violated_soft_constraint = True
        else:
            action_idx = self.env.action_space.get_index_of(action)
            self.violated_soft_constraint = not is_cautious_soft[action_idx]

        return action

    def get_random_safe_state(self):
        """
        Returns a random state that is classified as safe by the hard estimate of the safety model
        :return: a safe state, or None if no state is classified as safe
        """
        is_viable = self.safety_model.measure(
            slice(None, None, None),
            lambda_threshold=self.lambda_hard,
            gamma_threshold=self.gamma_hard
        ) > 0
        # atleast_1d: squeeze() collapses a single match to a 0-d array, which has no len()
        viable_indexes = np.atleast_1d(np.argwhere(is_viable).squeeze())
        if len(viable_indexes) == 0:
            # Explicit check instead of a broad `except`; the previous
            # logger.error('ERROR:', str(e)) had no %s placeholder, so logging
            # itself failed to format the record and the message was lost.
            logger.error('ERROR: no state is classified as safe by the hard safety estimate')
            return None
        state_index = viable_indexes[np.random.choice(len(viable_indexes))]
        state = self.env.stateaction_space.state_space[state_index]
        return state

    def update_models(self, state, action, next_state, reward, failed):
        """
        Updates both the Q-Values model and the safety model with one transition
        """
        self.Q_model.update(state, action, next_state, reward, failed)
        self.safety_model.update(state, action, next_state, reward, failed)

    def fit_models(self, q_train_x=None, q_train_y=None, q_epochs=None, q_optimizer_kwargs=None,
                   s_train_x=None, s_train_y=None, s_epochs=None, s_optimizer_kwargs=None):
        """
        Fits the hyperparameters of the Q-Values GP and/or of the safety GP. A model is only fitted when its training
        inputs are given; its dataset is emptied again afterwards if keep_seed_in_data is False.
        """
        if q_train_x is not None:
            if q_optimizer_kwargs is None:
                q_optimizer_kwargs = {}
            self.Q_model.fit(q_train_x, q_train_y, q_epochs, **q_optimizer_kwargs)
            if not self.keep_seed_in_data:
                self.Q_model.empty_data()

        if s_train_x is not None:
            if s_optimizer_kwargs is None:
                s_optimizer_kwargs = {}
            self.safety_model.fit(s_train_x, s_train_y, s_epochs, **s_optimizer_kwargs)
            if not self.keep_seed_in_data:
                self.safety_model.empty_data()

    def step(self):
        """
        Chooses an action according to the policy, takes a step in the Environment, and updates the models. The action
        taken is available in self.last_action. The Q-Values model is always updated; the safety model is updated only
        when the soft constraint was violated or the step failed (recorded in self.updated_safety).
        :return: new_state, reward, failed, done
        """
        old_state = self.state
        action = self.get_next_action()
        new_state, reward, failed = self.env.step(action)
        # The Q-Values model learns from every transition
        self.Q_model.update(old_state, action, new_state, reward, failed)
        if self.violated_soft_constraint or failed:
            self.safety_model.update(old_state, action, new_state, reward, failed)
            self.updated_safety = True
        else:
            self.updated_safety = False
        self.state = new_state
        self.last_action = action
        return new_state, reward, failed, self.env.done