Exemple #1
0
class ConstrainedQLearner(Agent):
    """
    Defines an Agent modelling the Q-Values with a MaternGP updated with the Q-Learning update. The Agent also has a
    model for the underlying safety measure, either via a SafetyTruth or a SafetyModel, but this model is then
    not updated. The Agent then acts with a ConstrainedEpsilonGreedy policy, staying in the safe set.
    """
    def __init__(self,
                 env,
                 safety_measure,
                 greed,
                 step_size,
                 discount_rate,
                 safety_threshold,
                 x_seed,
                 y_seed,
                 gp_params=None,
                 keep_seed_in_data=True):
        """
        Initializer
        :param env: the environment
        :param safety_measure: either SafetyTruth or SafetyModel of the environment
        :param greed: the epsilon parameter of the ConstrainedEpsilonGreedy policy
        :param step_size: the step size in the Q-Learning update
        :param discount_rate: the discount rate
        :param safety_threshold: the lambda threshold used to evaluate safety. This is 0 theoretically, but an Agent
            that is at the exact boundary of the viability kernel still fails due to rounding errors. Hence, this should
            be a small, positive value.
        :param x_seed: the seed input of the GP
        :param y_seed: the seed output of the GP
        :param gp_params: the parameters defining the GP. See edge.models.inference.MaternGP for more information
        :param keep_seed_in_data: whether to keep the seed data in the GP dataset. Should be True, otherwise GPyTorch
            fails.
        """
        Q_model = GPQLearning(env.stateaction_space,
                              step_size,
                              discount_rate,
                              x_seed=x_seed,
                              y_seed=y_seed,
                              gp_params=gp_params)
        super(ConstrainedQLearner, self).__init__(env, Q_model)

        self.Q_model = Q_model
        self.safety_measure = safety_measure
        self.constrained_value_policy = ConstrainedEpsilonGreedy(
            self.env.stateaction_space, greed)
        self.safety_maximization_policy = SafetyMaximization(
            self.safety_measure.stateaction_space)
        self.safety_threshold = safety_threshold
        self.keep_seed_in_data = keep_seed_in_data
        if not keep_seed_in_data:
            self.Q_model.empty_data()

    @property
    def greed(self):
        """
        Returns the epsilon parameter of the ConstrainedEpsilonGreedy policy
        :return: epsilon parameter
        """
        return self.constrained_value_policy.greed

    @greed.setter
    def greed(self, new_greed):
        """
        Sets the epsilon parameter of the ConstrainedEpsilonGreedy policy
        """
        self.constrained_value_policy.greed = new_greed

    def get_next_action(self):
        """
        Chooses the next action from the current state. An action is considered viable when its safety measure is
        strictly above safety_threshold; the constrained value policy picks among the viable actions using the
        Q-Values of the current state. If it returns None, the safety maximization policy is queried with the safety
        values of the current state instead.
        :return: the chosen action
        :raises NoActionError: if neither policy returns an action
        """
        all_actions = self.Q_model.space.action_space[:].reshape(-1, 1)
        action_is_viable = np.array([
            self.safety_measure.measure(self.state, a) > self.safety_threshold
            for a in all_actions
        ])
        q_values = self.Q_model[self.state, :]

        action = self.constrained_value_policy.get_action(
            q_values, action_is_viable)
        if action is None:
            print('No viable action, taking the safest...')
            safety_values = self.safety_measure.measure(
                self.state, slice(None, None, None))
            action = self.safety_maximization_policy.get_action(safety_values)
        if action is None:
            raise NoActionError('The agent could not find a suitable action')

        return action

    def get_random_safe_state(self):
        """
        Returns a random state that is classified as safe by the safety model
        :return: a safe state
        """
        is_viable = self.safety_measure.state_measure > 0
        # squeeze() collapses a single viable index to a 0-d array, on which len() fails; np.atleast_1d keeps the
        # result indexable in that case.
        viable_indexes = np.atleast_1d(np.argwhere(is_viable).squeeze())
        state_index = viable_indexes[np.random.choice(len(viable_indexes))]
        state = self.safety_measure.stateaction_space.state_space[state_index]
        return state

    def update_models(self, state, action, next_state, reward, failed):
        """
        Updates the Q-Values model with the observed transition. The safety measure is not updated.
        """
        self.Q_model.update(state, action, next_state, reward, failed)

    def fit_models(self, train_x, train_y, epochs, **optimizer_kwargs):
        """
        Fits the Q-Values model on the given data, then empties its dataset again if the seed data should not be kept
        :param train_x: the training inputs, forwarded to Q_model.fit
        :param train_y: the training outputs, forwarded to Q_model.fit
        :param epochs: the number of epochs, forwarded to Q_model.fit
        :param optimizer_kwargs: additional keyword arguments forwarded to Q_model.fit
        """
        self.Q_model.fit(train_x, train_y, epochs, **optimizer_kwargs)
        if not self.keep_seed_in_data:
            self.Q_model.empty_data()
class ControlledSafetyLearner(Agent):
    """
    Defines an Agent that follows a controller (provided by subclasses through get_controller_action) and uses a
    safety model to check, and if needed replace, the controller's action. The safety model is updated with the
    transitions observed in step when learning is enabled.
    """
    def __init__(self, env, s_gp_params, gamma_cautious, lambda_cautious,
                 gamma_optimistic, checks_safety=True, learn_safety=True,
                 is_free_from_safety=False, always_update_safety=False,
                 safety_model=None,
                 *models):
        """
        Initializer
        :param env: the environment
        :param s_gp_params: the parameters of the safety GP, including the 'train_x' and 'train_y' seed data. Only
            used when safety_model is None. The dictionary is copied and the caller's object is left untouched.
        :param gamma_cautious: pair (start, end) of values for the cautious gamma threshold
        :param lambda_cautious: pair (start, end) of values for the cautious lambda threshold
        :param gamma_optimistic: pair (start, end) of values for the optimistic gamma threshold
        :param checks_safety: whether get_next_action checks the controller's action against the cautious set
        :param learn_safety: whether the safety model is updated in step
        :param is_free_from_safety: if True, an action outside the cautious set is only flagged as a constraint
            violation and is not replaced
        :param always_update_safety: stored on the Agent; not read by the code of this class
        :param safety_model: an existing safety model to use instead of creating a MaternSafety
        :param models: additional models forwarded to the Agent initializer
        """
        self.gamma_cautious_s, self.gamma_cautious_e = gamma_cautious
        self.lambda_cautious_s, self.lambda_cautious_e = lambda_cautious
        self.gamma_optimistic_s, self.gamma_optimistic_e = gamma_optimistic
        self.gamma_cautious = self.gamma_cautious_s
        self.lambda_cautious = self.lambda_cautious_s

        if safety_model is not None:
            self.safety_model = safety_model
        else:
            # Work on a copy: popping from the caller's dictionary would silently strip the seed data from it and
            # break any later reuse of the same parameters.
            s_gp_params = dict(s_gp_params)
            x_seed = s_gp_params.pop('train_x')
            y_seed = s_gp_params.pop('train_y')
            self.safety_model = MaternSafety(
                env,
                gamma_measure=self.gamma_optimistic_s,
                x_seed=x_seed,
                y_seed=y_seed,
                gp_params=s_gp_params
            )

        super().__init__(env, self.safety_model, *models)
        self.safety_learning_policy = SafetyInformationMaximization(
            env.stateaction_space
        )
        self.safe_projection_policy = SafeProjectionPolicy(
            env.stateaction_space
        )
        self.safety_maximization_policy = SafetyMaximization(
            self.env.stateaction_space
        )
        self.active_sampling_policy = SafetyActiveSampling(
            self.env.stateaction_space
        )
        self.last_controller_action = None
        self.safety_update = None
        self.checks_safety = checks_safety
        self.followed_controller = None
        self.always_update_safety = always_update_safety
        self.violated_constraint = None
        self.is_free_from_safety = is_free_from_safety
        self.learn_safety = learn_safety

    def get_controller_action(self, *args, **kwargs):
        """
        Returns the action proposed by the controller. Must be implemented by subclasses.
        """
        raise NotImplementedError

    @property
    def gamma_optimistic(self):
        """
        Returns the optimistic gamma threshold, stored as the gamma_measure of the safety model
        """
        return self.safety_model.gamma_measure

    @gamma_optimistic.setter
    def gamma_optimistic(self, new_gamma_optimistic):
        """
        Sets the optimistic gamma threshold on the safety model
        """
        self.safety_model.gamma_measure = new_gamma_optimistic

    @property
    def do_safety_update(self):
        """
        Returns whether the safety model should be updated after the current step. The finer-grained conditions
        (always_update_safety, violated_constraint, not followed_controller, failed) are currently disabled, so this
        only depends on learn_safety.
        """
        return bool(self.learn_safety)

    def update_safety_params(self, t):
        """
        Interpolates the cautious and optimistic thresholds between their start and end values
        :param t: the interpolation parameter passed to affine_interpolation
        """
        self.gamma_cautious = affine_interpolation(t, self.gamma_cautious_s,
                                                   self.gamma_cautious_e)
        self.lambda_cautious = affine_interpolation(t, self.lambda_cautious_s,
                                                    self.lambda_cautious_e)
        self.gamma_optimistic = affine_interpolation(t, self.gamma_optimistic_s,
                                                     self.gamma_optimistic_e)

    def __get_projection_with_thresholds(self, lambda_t, gamma_t,
                                         original_action):
        """
        Projects original_action with the safe projection policy, using as constraints the level set of the safety
        model at the current state for the given thresholds
        :return: the action returned by the projection policy
        """
        constraints = self.safety_model.level_set(
            self.state,
            lambda_threshold=lambda_t,
            gamma_threshold=gamma_t
        )
        projected_action = self.safe_projection_policy.get_action(
            to_project=original_action,
            constraints=constraints
        )
        return projected_action

    def __get_alternative_with_thresholds(self, lambda_t, gamma_t,
                                          maximize_safety_proba=False,
                                          use_covar_slice=False):
        """
        Looks for an alternative action in the level set of the safety model at the current state
        :param lambda_t: the lambda threshold of the level set
        :param gamma_t: the gamma threshold of the level set
        :param maximize_safety_proba: if True, the level set is ignored and the safety maximization policy is queried
            with the safety probabilities
        :param use_covar_slice: if True, the active sampling policy is queried with the covariance slice; otherwise
            the safety learning policy is queried with the row of the covariance matrix at the index of the last
            controller action
        :return: the alternative action, or None if the level set is empty and maximize_safety_proba is False
        """
        alt_set, safety_proba, covar_slice, covar_matrix = \
            self.safety_model.level_set(
                self.state,
                lambda_threshold=lambda_t,
                gamma_threshold=gamma_t,
                return_proba=True,
                return_covar=True,
                return_covar_matrix=True,
            )
        if maximize_safety_proba:
            safety_proba = safety_proba.squeeze()
            return self.safety_maximization_policy.get_action(safety_proba)

        alt_set = alt_set.squeeze()
        if not alt_set.any():
            return None

        ctrlr_idx = self.env.action_space.get_index_of(
            self.last_controller_action, around_ok=True
        )
        if use_covar_slice:
            # Use get_action explicitly, like every other policy call site.
            alternative = self.active_sampling_policy.get_action(
                covar_slice.squeeze(), alt_set
            )
        else:
            alternative = self.safety_learning_policy.get_action(
                covar_matrix[ctrlr_idx, :].squeeze(), alt_set
            )
        return alternative

    def get_next_action(self):
        """
        Returns the controller's action, replaced by a safer one when safety is checked and the controller's action
        is outside the cautious level set. The replacement is, in order of preference: the projection on the cautious
        set, the projection on the optimistic set (lambda threshold 0), and the action maximizing the safety
        probability. The flags followed_controller and violated_constraint are updated accordingly.
        :return: the action to take
        """
        self.followed_controller = True
        self.violated_constraint = False
        self.last_controller_action = self.get_controller_action()
        action = self.last_controller_action
        if not self.checks_safety:
            return action

        controller_is_cautious = self.safety_model.is_in_level_set(
            self.state, action, self.lambda_cautious, self.gamma_cautious
        )
        if controller_is_cautious:
            return action

        if self.is_free_from_safety:
            # The violation is recorded but the controller's action is kept
            self.violated_constraint = True
            return action

        # NOTE: alternatives are obtained by projection; the covariance-based
        # __get_alternative_with_thresholds is only used as the last resort.
        alternative = self.__get_projection_with_thresholds(
            self.lambda_cautious, self.gamma_cautious, action
        )
        if alternative is not None:
            # We found a cautious alternative
            self.violated_constraint = False
            self.followed_controller = False
            return alternative

        self.violated_constraint = True
        self.followed_controller = False
        alternative = self.__get_projection_with_thresholds(
            0., self.gamma_optimistic, action
        )
        if alternative is not None:
            # We found an optimistic alternative
            return alternative

        # No cautious or optimistic action available:
        # maximize safety probability
        return self.__get_alternative_with_thresholds(
            0., self.gamma_optimistic,
            maximize_safety_proba=True
        )

    def update_models(self, state, action, next_state, reward, failed, done):
        """
        Updates the safety model with the observed transition
        :return: the value returned by the update of the safety model
        """
        return self.safety_model.update(state, action, next_state, reward,
                                        failed, done)

    def step(self):
        """
        Chooses an action according to the policy, takes a step in the Environment, and updates the models. The action
        taken is available in self.last_action, and the result of the safety update (or None if no update was made)
        in self.safety_update.
        :return: new_state, reward, failed, done
        """
        old_state = self.state
        self.last_action = self.get_next_action()
        self.state, reward, failed = self.env.step(self.last_action)
        done = self.env.done
        if self.training_mode and self.do_safety_update:
            self.safety_update = self.update_models(
                old_state, self.last_action, self.state, reward, failed,
                done
            )
        else:
            self.safety_update = None
        return self.state, reward, failed, done
class EpsCorlLearner(Agent):
    """
    Defines an Agent modelling the Q-Values with a MaternGP updated with the Q-Learning update. The Agent also has a
    model for the underlying safety measure with a SafetyModel, and the Q-Learning update is constrained to stay in the
    current estimate of the safe set. At each step, with probability greed the Agent explores by querying the active
    sampling policy of the safety model; otherwise it queries the constrained value policy with the Q-Values.
    """
    def __init__(self,
                 env,
                 greed,
                 step_size,
                 discount_rate,
                 q_x_seed,
                 q_y_seed,
                 gamma_optimistic,
                 gamma_cautious,
                 lambda_cautious,
                 s_x_seed,
                 s_y_seed,
                 q_gp_params=None,
                 s_gp_params=None,
                 keep_seed_in_data=True):
        """
        Initializer
        :param env: the environment
        :param greed: the probability of taking an exploration (safety active sampling) step
        :param step_size: the step size in the Q-Learning update
        :param discount_rate: the discount rate
        :param q_x_seed: the seed input of the GP for the Q-Values model
        :param q_y_seed: the seed output of the GP for the Q-Values model
        :param gamma_optimistic: the gamma parameter for Q_optimistic
        :param gamma_cautious: the gamma parameter for Q_cautious
        :param lambda_cautious: the lambda parameter for Q_cautious
        :param s_x_seed: the seed input of the GP for the safety model
        :param s_y_seed: the seed output of the GP for the safety model
        :param q_gp_params: the parameters defining the GP for the Q-Values model. See edge.models.inference.MaternGP
            for more information
        :param s_gp_params: the parameters defining the GP for the safety model. See edge.models.inference.MaternGP
            for more information
        :param keep_seed_in_data: whether to keep the seed data in the GPs datasets. Should be True, otherwise GPyTorch
            fails.
        """
        Q_model = GPQLearning(env.stateaction_space,
                              step_size,
                              discount_rate,
                              x_seed=q_x_seed,
                              y_seed=q_y_seed,
                              gp_params=q_gp_params)
        safety_model = MaternSafety(env.stateaction_space,
                                    gamma_optimistic,
                                    x_seed=s_x_seed,
                                    y_seed=s_y_seed,
                                    gp_params=s_gp_params)
        super(EpsCorlLearner, self).__init__(env, Q_model, safety_model)

        self.Q_model = Q_model
        self.safety_model = safety_model
        self.lambda_cautious = lambda_cautious
        self.gamma_cautious = gamma_cautious
        self._gamma_optimistic = gamma_optimistic

        # The value policy is built with epsilon 0: the exploration/exploitation balance is handled by this Agent
        # through self.greed in get_next_action.
        self.constrained_value_policy = ConstrainedEpsilonGreedy(
            self.env.stateaction_space, 0)
        self.safety_maximization_policy = SafetyMaximization(
            self.env.stateaction_space)
        self.active_sampling_policy = SafetyActiveSampling(
            self.env.stateaction_space)

        self.keep_seed_in_data = keep_seed_in_data
        if not keep_seed_in_data:
            self.Q_model.empty_data()

        self.has_explored = None
        self._greed = greed

    @property
    def greed(self):
        """
        Returns the epsilon parameter balancing the exploration/exploitation
        :return: epsilon parameter
        """
        return self._greed

    @greed.setter
    def greed(self, new_greed):
        """
        Sets the epsilon parameter balancing the exploration/exploitation
        """
        self._greed = new_greed

    @property
    def gamma_optimistic(self):
        """
        Returns the gamma parameter for Q_optimistic
        """
        return self._gamma_optimistic

    @gamma_optimistic.setter
    def gamma_optimistic(self, new_gamma_optimistic):
        """
        Sets the gamma parameter for Q_optimistic, both on the Agent and on the safety model
        """
        self._gamma_optimistic = new_gamma_optimistic
        self.safety_model.gamma_measure = new_gamma_optimistic

    def get_next_action(self):
        """
        Chooses the next action within the cautious level set of the current state. A Bernoulli draw of parameter
        greed decides whether the action comes from the active sampling policy (exploration) or from the constrained
        value policy. The outcome of the draw is stored in self.has_explored. If the selected policy returns None,
        the safety maximization policy is queried with the safety probabilities.
        :return: the chosen action
        """
        self.has_explored = np.random.binomial(n=1, p=self.greed) == 1

        is_cautious, proba_slice, covar_slice = self.safety_model.level_set(
            self.state,
            lambda_threshold=self.lambda_cautious,
            gamma_threshold=self.gamma_cautious,
            return_proba=True,
            return_covar=True)
        is_cautious = is_cautious.squeeze()
        proba_slice = proba_slice.squeeze()
        covar_slice = covar_slice.squeeze()

        if self.has_explored:
            action = self.active_sampling_policy.get_action(
                covar_slice, is_cautious)
        else:
            q_values = self.Q_model[self.state, :]
            action = self.constrained_value_policy.get_action(
                q_values, is_cautious)

        if action is None:
            print('No viable action, taking the safest...')
            action = self.safety_maximization_policy.get_action(proba_slice)

        return action

    def get_random_safe_state(self):
        """
        Returns a random state that is classified as safe by the safety model
        :return: a safe state, or None if no state is classified as safe
        """
        is_viable = self.safety_model.measure(
            slice(None, None, None),
            lambda_threshold=self.lambda_cautious,
            gamma_threshold=self.gamma_cautious) > 0
        viable_indexes = np.atleast_1d(np.argwhere(is_viable).squeeze())
        try:
            state_index = viable_indexes[np.random.choice(len(viable_indexes))]
        except ValueError as e:
            # np.random.choice raises ValueError when there is nothing to choose from, i.e., no viable state
            print('ERROR:', str(e))
            return None
        state = self.env.stateaction_space.state_space[state_index]
        return state

    def update_models(self, state, action, next_state, reward, failed):
        """
        Updates both the Q-Values model and the safety model with the observed transition
        """
        self.Q_model.update(state, action, next_state, reward, failed)
        self.safety_model.update(state, action, next_state, reward, failed)

    def fit_models(self,
                   q_train_x=None,
                   q_train_y=None,
                   q_epochs=None,
                   q_optimizer_kwargs=None,
                   s_train_x=None,
                   s_train_y=None,
                   s_epochs=None,
                   s_optimizer_kwargs=None):
        """
        Fits the Q-Values model and/or the safety model. A model is only fitted when its training input is given.
        After fitting, the dataset of the model is emptied if the seed data should not be kept.
        :param q_train_x: the training inputs of the Q-Values model, or None to skip fitting it
        :param q_train_y: the training outputs of the Q-Values model
        :param q_epochs: the number of epochs for the Q-Values model
        :param q_optimizer_kwargs: keyword arguments forwarded to Q_model.fit
        :param s_train_x: the training inputs of the safety model, or None to skip fitting it
        :param s_train_y: the training outputs of the safety model
        :param s_epochs: the number of epochs for the safety model
        :param s_optimizer_kwargs: keyword arguments forwarded to safety_model.fit
        """
        if q_train_x is not None:
            if q_optimizer_kwargs is None:
                q_optimizer_kwargs = {}
            self.Q_model.fit(q_train_x, q_train_y, q_epochs,
                             **q_optimizer_kwargs)
            if not self.keep_seed_in_data:
                self.Q_model.empty_data()

        if s_train_x is not None:
            if s_optimizer_kwargs is None:
                s_optimizer_kwargs = {}
            self.safety_model.fit(s_train_x, s_train_y, s_epochs,
                                  **s_optimizer_kwargs)
            if not self.keep_seed_in_data:
                self.safety_model.empty_data()

    def step(self):
        """
        Chooses an action according to the policy, takes a step in the Environment, and updates the models. The action
        taken is available in self.last_action. When the Agent explored or failed, only the safety model is updated;
        otherwise, both models are.
        :return: new_state, reward, failed
        """
        old_state = self.state
        action = self.get_next_action()
        new_state, reward, failed = self.env.step(action)
        if self.has_explored or failed:
            self.safety_model.update(old_state, action, new_state, reward,
                                     failed)
        else:
            self.update_models(old_state, action, new_state, reward, failed)
        self.state = new_state
        self.last_action = action
        return new_state, reward, failed
class CoRLSafetyLearner(ControlledSafetyLearner):
    """
    ControlledSafetyLearner whose controller is a user-provided callable. In training mode, actions are selected by
    the active sampling policy of the safety model instead of following the controller; otherwise, the behaviour of
    ControlledSafetyLearner is used.
    """
    def __init__(self, env, s_gp_params, gamma_cautious, lambda_cautious,
                 gamma_optimistic, base_controller):
        """
        Initializer
        :param env: the environment
        :param s_gp_params: the parameters of the safety GP, forwarded to ControlledSafetyLearner
        :param gamma_cautious: forwarded to ControlledSafetyLearner
        :param lambda_cautious: forwarded to ControlledSafetyLearner
        :param gamma_optimistic: forwarded to ControlledSafetyLearner
        :param base_controller: callable mapping the current state to the controller's action
        """
        super().__init__(
            env,
            s_gp_params,
            gamma_cautious,
            lambda_cautious,
            gamma_optimistic,
            checks_safety=True,
            learn_safety=True,
            is_free_from_safety=False,
            always_update_safety=True,
            safety_model=None,
        )
        self.policy = base_controller
        self.active_sampling_policy = SafetyActiveSampling(self.env.stateaction_space)
        self.safety_maximization_policy = SafetyMaximization(self.env.stateaction_space)

    def get_controller_action(self, *args, **kwargs):
        """
        Returns the action of the base controller in the current state
        """
        return self.policy(self.state)

    def __get_active_sampling_with_thresholds(self, lambda_t, gamma_t,
                                              maximize_safety_proba=False):
        """
        Queries the level set of the safety model at the current state and selects an action from it
        :param lambda_t: the lambda threshold of the level set
        :param gamma_t: the gamma threshold of the level set
        :param maximize_safety_proba: if True, the safety maximization policy is queried with the safety
            probabilities; otherwise, the active sampling policy is queried with the covariance slice and the level set
        :return: the action returned by the selected policy
        """
        level_set_outputs = self.safety_model.level_set(
            self.state,
            lambda_threshold=lambda_t,
            gamma_threshold=gamma_t,
            return_proba=True,
            return_covar=True
        )
        in_set, proba_slice, covar_slice = level_set_outputs

        if maximize_safety_proba:
            probabilities = proba_slice.squeeze()
            return self.safety_maximization_policy.get_action(probabilities)

        candidates = in_set.squeeze()
        covariances = covar_slice.squeeze()
        return self.active_sampling_policy.get_action(covariances, candidates)

    def get_next_action(self):
        """
        Returns the next action. Outside of training mode, this defers to ControlledSafetyLearner. In training mode,
        the action is actively sampled in the cautious set; if that yields None, in the optimistic set (lambda
        threshold 0), and the constraint is flagged as violated; if that also yields None, the action maximizing the
        safety probability is taken.
        :return: the action to take
        """
        if not self.training_mode:
            return super().get_next_action()

        self.last_controller_action = self.get_controller_action()
        self.followed_controller = False
        self.violated_constraint = False

        chosen = self.__get_active_sampling_with_thresholds(
            self.lambda_cautious, self.gamma_cautious
        )
        found_cautious = chosen is not None
        self.violated_constraint = not found_cautious
        if not found_cautious:
            # Nothing in the cautious set: relax to the optimistic set
            chosen = self.__get_active_sampling_with_thresholds(
                0., self.gamma_optimistic
            )

        if chosen is None:
            # Last resort: maximize the safety probability
            chosen = self.__get_active_sampling_with_thresholds(
                0., self.gamma_optimistic, maximize_safety_proba=True
            )
        return chosen
class SoftHardLearner(Agent):
    """
    Defines an Agent modelling the Q-Values with a MaternGP updated with the Q-Learning update. The Agent also has a
    model for the underlying safety measure with a SafetyModel, and the Q-Learning update is constrained to stay in the
    current estimate of the safe set (the hard constraint). The safety model is updated with a sample when the action
    taken is outside of the soft set, or when the Agent failed.
    """

    def __init__(self, env,
                 greed, step_size, discount_rate, q_x_seed, q_y_seed,
                 gamma_optimistic, gamma_hard, lambda_hard, gamma_soft, s_x_seed, s_y_seed,
                 q_gp_params=None, s_gp_params=None, keep_seed_in_data=True):
        """
        Initializer
        :param env: the environment
        :param greed: the epsilon parameter of the ConstrainedEpsilonGreedy policy
        :param step_size: the step size in the Q-Learning update
        :param discount_rate: the discount rate
        :param q_x_seed: the seed input of the GP for the Q-Values model
        :param q_y_seed: the seed output of the GP for the Q-Values model
        :param gamma_optimistic: the gamma parameter for Q_optimistic
        :param gamma_hard: the gamma parameter for Q_hard, the set where Q-Learning is constrained (~ Q_cautious)
        :param lambda_hard: the lambda parameter for Q_hard AND Q_soft
        :param gamma_soft: the gamma parameter for Q_soft, the set outside of which the safety measure is updated
        :param s_x_seed: the seed input of the GP for the safety model
        :param s_y_seed: the seed output of the GP for the safety model
        :param q_gp_params: the parameters defining the GP for the Q-Values model. See edge.models.inference.MaternGP
            for more information
        :param s_gp_params: the parameters defining the GP for the safety model. See edge.models.inference.MaternGP
            for more information
        :param keep_seed_in_data: whether to keep the seed data in the GPs datasets. Should be True, otherwise GPyTorch
            fails.
        """
        Q_model = GPQLearning(env.stateaction_space, step_size, discount_rate,
                              x_seed=q_x_seed, y_seed=q_y_seed,
                              gp_params=q_gp_params)
        safety_model = MaternSafety(env.stateaction_space, gamma_optimistic,
                                    x_seed=s_x_seed, y_seed=s_y_seed,
                                    gp_params=s_gp_params)
        super(SoftHardLearner, self).__init__(env, Q_model, safety_model)

        self.Q_model = Q_model
        self.safety_model = safety_model
        self.lambda_hard = lambda_hard
        self.gamma_hard = gamma_hard
        self.gamma_soft = gamma_soft
        self._gamma_optimistic = gamma_optimistic

        self.constrained_value_policy = ConstrainedEpsilonGreedy(
            self.env.stateaction_space, greed)
        self.safety_maximization_policy = SafetyMaximization(
            self.env.stateaction_space)
        self.active_sampling_policy = SafetyActiveSampling(
            self.env.stateaction_space)

        self.keep_seed_in_data = keep_seed_in_data
        if not keep_seed_in_data:
            self.Q_model.empty_data()

        self.violated_soft_constraint = None
        self.updated_safety = None

    @property
    def greed(self):
        """
        Returns the epsilon parameter of the ConstrainedEpsilonGreedy policy
        :return: epsilon parameter
        """
        return self.constrained_value_policy.greed

    @greed.setter
    def greed(self, new_greed):
        """
        Sets the epsilon parameter of the ConstrainedEpsilonGreedy policy
        """
        self.constrained_value_policy.greed = new_greed

    @property
    def step_size(self):
        """
        Returns the step_size parameter of the Q-Learning model
        :return: step_size parameter
        """
        return self.Q_model.step_size

    @step_size.setter
    def step_size(self, new_step_size):
        """
        Sets the step_size parameter of the Q-Learning model
        """
        self.Q_model.step_size = new_step_size

    @property
    def gamma_optimistic(self):
        """
        Returns the gamma parameter for Q_optimistic
        """
        return self._gamma_optimistic

    @gamma_optimistic.setter
    def gamma_optimistic(self, new_gamma_optimistic):
        """
        Sets the gamma parameter for Q_optimistic, both on the Agent and on the safety model
        """
        self._gamma_optimistic = new_gamma_optimistic
        self.safety_model.gamma_measure = new_gamma_optimistic

    def get_next_action(self):
        """
        Chooses the next action with the constrained value policy restricted to the hard set of the current state.
        If the policy returns None, the safety maximization policy is queried with the safety probabilities of the
        hard set, and the soft constraint is considered violated. Otherwise, self.violated_soft_constraint records
        whether the chosen action is outside of the soft set.
        :return: the chosen action
        """
        # Both level sets are computed in a single query: index 0 is the hard set, index 1 the soft set
        is_cautious_list, proba_slice_list = self.safety_model.level_set(
            self.state,
            lambda_threshold=[self.lambda_hard, self.lambda_hard],
            gamma_threshold=[self.gamma_hard, self.gamma_soft],
            return_proba=True,
            return_covar=False
        )
        is_cautious_hard, is_cautious_soft = is_cautious_list
        is_cautious_hard = is_cautious_hard.squeeze()
        is_cautious_soft = is_cautious_soft.squeeze()
        proba_slice_hard = proba_slice_list[0].squeeze()

        q_values = self.Q_model[self.state, :]
        action = self.constrained_value_policy.get_action(
            q_values, is_cautious_hard
        )

        if action is None:
            logger.info('No viable action, taking the safest...')
            action = self.safety_maximization_policy.get_action(proba_slice_hard)
            self.violated_soft_constraint = True
        else:
            action_idx = self.env.action_space.get_index_of(action)
            self.violated_soft_constraint = not is_cautious_soft[action_idx]

        return action

    def get_random_safe_state(self):
        """
        Returns a random state that is classified as safe by the safety model
        :return: a safe state, or None if no state is classified as safe
        """
        is_viable = self.safety_model.measure(
            slice(None, None, None),
            lambda_threshold=self.lambda_hard,
            gamma_threshold=self.gamma_hard
        ) > 0
        viable_indexes = np.atleast_1d(np.argwhere(is_viable).squeeze())
        try:
            state_index = viable_indexes[np.random.choice(len(viable_indexes))]
        except ValueError as e:
            # np.random.choice raises ValueError when there is nothing to choose from, i.e., no viable state.
            # The message needs a %s placeholder: logging formats it lazily as msg % args.
            logger.error('ERROR: %s', e)
            return None
        state = self.env.stateaction_space.state_space[state_index]
        return state

    def update_models(self, state, action, next_state, reward, failed):
        """
        Updates both the Q-Values model and the safety model with the observed transition
        """
        self.Q_model.update(state, action, next_state, reward, failed)
        self.safety_model.update(state, action, next_state, reward, failed)

    def fit_models(self, q_train_x=None, q_train_y=None, q_epochs=None, q_optimizer_kwargs=None,
                   s_train_x=None, s_train_y=None, s_epochs=None, s_optimizer_kwargs=None):
        """
        Fits the Q-Values model and/or the safety model. A model is only fitted when its training input is given.
        After fitting, the dataset of the model is emptied if the seed data should not be kept.
        :param q_train_x: the training inputs of the Q-Values model, or None to skip fitting it
        :param q_train_y: the training outputs of the Q-Values model
        :param q_epochs: the number of epochs for the Q-Values model
        :param q_optimizer_kwargs: keyword arguments forwarded to Q_model.fit
        :param s_train_x: the training inputs of the safety model, or None to skip fitting it
        :param s_train_y: the training outputs of the safety model
        :param s_epochs: the number of epochs for the safety model
        :param s_optimizer_kwargs: keyword arguments forwarded to safety_model.fit
        """
        if q_train_x is not None:
            if q_optimizer_kwargs is None:
                q_optimizer_kwargs = {}
            self.Q_model.fit(q_train_x, q_train_y, q_epochs, **q_optimizer_kwargs)
            if not self.keep_seed_in_data:
                self.Q_model.empty_data()

        if s_train_x is not None:
            if s_optimizer_kwargs is None:
                s_optimizer_kwargs = {}
            self.safety_model.fit(s_train_x, s_train_y, s_epochs, **s_optimizer_kwargs)
            if not self.keep_seed_in_data:
                self.safety_model.empty_data()

    def step(self):
        """
        Chooses an action according to the policy, takes a step in the Environment, and updates the models. The action
        taken is available in self.last_action. The Q-Values model is always updated; the safety model is updated only
        when the soft constraint was violated or the Agent failed, which is recorded in self.updated_safety.
        :return: new_state, reward, failed, done
        """
        old_state = self.state
        action = self.get_next_action()
        new_state, reward, failed = self.env.step(action)

        self.Q_model.update(old_state, action, new_state, reward, failed)
        if self.violated_soft_constraint or failed:
            self.safety_model.update(old_state, action, new_state, reward, failed)
            self.updated_safety = True
        else:
            self.updated_safety = False

        self.state = new_state
        self.last_action = action
        return new_state, reward, failed, self.env.done