Code example #1
    def update_states(self, env_states, model_states, best_ix):
        """Update the data of the root state."""
        self.root_env_states.update(other=env_states)
        self.root_model_states.update(other=model_states)
        if self.accumulate_rewards:
            cum_rewards = self.root_walkers_states.cum_rewards
            cum_rewards = cum_rewards + self.root_env_states.rewards
        else:
            cum_rewards = self.root_env_states.rewards
        dt = getattr(self.root_model_states, "dt", 1.0)
        times = dt + self.root_walker.times
        root_id = tensor(self.walkers.states.id_walkers[best_ix])
        self.root_walkers_states.update(
            cum_rewards=cum_rewards,
            times=times,
            id_walkers=tensor([root_id]),
        )

        self.root_walker = OneWalker(
            reward=judo.copy(cum_rewards[0]),
            observ=judo.copy(self.root_env_states.observs[0]),
            state=judo.copy(self.root_env_states.states[0]),
            time=judo.copy(times[0]),
            id_walker=root_id.squeeze(),
        )
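The branch at the top of this method is the core bookkeeping: with ``accumulate_rewards`` enabled, the root's running total grows by the latest environment reward; otherwise only the latest reward is kept. The root's clock then advances by the model's ``dt``, falling back to 1.0 when the model defines none. A runnable toy restatement of that logic, with invented values:

    import numpy

    accumulate_rewards = True
    cum_rewards = numpy.array([1.5])    # running total kept by the root walker
    step_rewards = numpy.array([0.25])  # reward from the latest transition

    # Same branch as above: accumulate, or keep only the last reward.
    cum_rewards = cum_rewards + step_rewards if accumulate_rewards else step_rewards

    model = object()                # stand-in model with no dt attribute
    dt = getattr(model, "dt", 1.0)  # falls back to 1.0
    times = dt + 3.0                # root clock after this iteration
    print(cum_rewards, times)       # [1.75] 4.0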
Code example #2
File: test_swarm.py  Project: Guillemdb/fragile
 def test_reset_with_root_walker(self, swarm):
     swarm.reset()
     param_dict = swarm.walkers.env_states.get_params_dict()
     obs_dict = param_dict["observs"]
     state_dict = param_dict["states"]
     obs_size = obs_dict.get("size", obs_dict["shape"][1:])
     state_size = state_dict.get("size", state_dict["shape"][1:])
     obs = judo.astype(random_state.random(obs_size), obs_dict["dtype"])
     state = judo.astype(random_state.random(state_size),
                         state_dict["dtype"])
     reward = 160290
     root_walker = OneWalker(observ=obs, reward=reward, state=state)
     swarm.reset(root_walker=root_walker)
     swarm_best_id = swarm.best_id
     root_walker_id = root_walker.id_walkers
     assert (swarm.best_state == state).all()
     assert (swarm.best_obs == obs).all(), (obs, tensor(swarm.best_obs))
     assert swarm.best_reward == reward
     assert (swarm.walkers.env_states.observs == obs).all()
     assert (swarm.walkers.env_states.states == state).all()
     assert (swarm.walkers.env_states.rewards == reward).all()
     if Backend.is_numpy():
         assert (swarm.walkers.states.id_walkers == root_walker.id_walkers
                 ).all()
         assert swarm_best_id == root_walker_id[0]
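The ``size``/``shape`` fallback in this test handles param dicts that describe batched arrays: an entry may declare a per-walker ``size`` directly, or only a full ``shape`` whose first axis is the number of walkers. A standalone illustration with invented values; note that ``dict.get`` evaluates its default eagerly, so ``"shape"`` must be present either way:

    # Entry with only a batched shape: (n_walkers, observation_dim).
    obs_dict = {"shape": (8, 3), "dtype": "float32"}
    obs_size = obs_dict.get("size", obs_dict["shape"][1:])        # -> (3,)

    # Entry that also declares a per-walker size, which takes precedence.
    state_dict = {"size": (4,), "shape": (8, 4), "dtype": "float32"}
    state_size = state_dict.get("size", state_dict["shape"][1:])  # -> (4,)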
Code example #3
    def reset(
        self,
        root_walker: OneWalker = None,
        walkers_states: StatesWalkers = None,
        model_states: StatesModel = None,
        env_states: StatesEnv = None,
    ):
        """
        Reset the :class:`fragile.Walkers`, the :class:`Environment`, the \
        :class:`Model` and clear the internal data to start a new search process.

        Args:
            root_walker: Walker representing the initial state of the search. \
                         The walkers will be reset to this walker, and it will \
                         be added to the root of the :class:`StateTree` if any.
            model_states: :class:`StatesModel` that defines the initial state of \
                          the :class:`Model`.
            env_states: :class:`StatesEnv` that defines the initial state of \
                        the :class:`Environment`.
            walkers_states: :class:`StatesWalkers` that defines the internal \
                            states of the :class:`Walkers`.

        """
        self._epoch = 0
        self.internal_swarm.reset(
            root_walker=root_walker,
            walkers_states=walkers_states,
            env_states=env_states,
            model_states=model_states,
        )
        if self._use_tree:
            if root_walker is not None:
                self.tree.reset(root_hash=int(root_walker.id_walkers[0]))
            root_ids = numpy.array([self.tree.root_hash] * self.walkers.n)
            if root_walker is None:  # Otherwise the ids are already updated inside walkers.reset
                self.internal_swarm.walkers.states.id_walkers = root_ids
            self.tree.reset(
                env_states=self.internal_swarm.walkers.env_states[0],
                model_states=self.internal_swarm.walkers.model_states[0],
                walkers_states=self.internal_swarm.walkers.states[0],
            )
            ids: List[int] = [self.internal_swarm.walkers.states.id_walkers[0]]
            self.update_tree(states_ids=ids)
        # Reset root data
        self.root_model_states = self.walkers.model_states[0]
        self.root_env_states = self.walkers.env_states[0]
        self.root_walkers_states = self.walkers.states[0]
        self.root_walker = OneWalker(
            reward=self.root_env_states.rewards[0],
            observ=self.root_env_states.observs[0],
            state=self.root_env_states.states[0],
        )
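From the caller's side, passing ``root_walker`` to ``reset`` is how a search is seeded from a known state, exactly as the tests in examples #2 and #8 do. A hedged sketch of that call pattern; the ``swarm`` object is assumed to be already constructed, and the import path of ``OneWalker`` is an assumption, not part of this excerpt:

    import numpy

    from fragile.core import OneWalker  # import path assumed

    # `swarm` is any already-built swarm exposing the reset() shown above.
    obs = numpy.zeros(swarm.walkers.env_states.observs.shape[1:])
    state = numpy.zeros(swarm.walkers.env_states.states.shape[1:])
    root = OneWalker(observ=obs, reward=0.0, state=state)
    swarm.reset(root_walker=root)  # every walker now starts from `root`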
Code example #4
 def update_states(self):
     """Update the data of the root walker after an internal Swarm iteration has finished."""
     # The accumulation of rewards is already done in the internal Swarm
     cum_rewards = self.root_walkers_states.cum_rewards
     self.root_walkers_states.update(
         cum_rewards=cum_rewards,
         id_walkers=numpy.array([hash_numpy(self.root_env_states.states[0])]),
     )
     self.root_walker = OneWalker(
         reward=copy.deepcopy(cum_rewards[0]),
         observ=copy.deepcopy(self.root_env_states.observs[0]),
         state=copy.deepcopy(self.root_env_states.states[0]),
     )
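This variant derives the root id by hashing the state array instead of copying an existing walker id. A sketch of what such a hash can look like; the real ``hash_numpy`` in fragile may be implemented differently, so treat this stand-in as an assumption:

    import numpy

    def hash_numpy(x: numpy.ndarray) -> int:
        # Stand-in: hash of the raw bytes. Stable within one process, but
        # Python salts bytes hashes, so it varies across interpreter runs.
        return hash(x.tobytes())

    state = numpy.arange(4.0)
    root_id = numpy.array([hash_numpy(state)])  # same shape as in the example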
Code example #5
    def reset(
        self,
        root_walker: OneWalker = None,
        walkers_states: StatesWalkers = None,
        model_states: StatesModel = None,
        env_states: StatesEnv = None,
    ):
        """
        Reset the :class:`fragile.Walkers`, the :class:`Environment`, the \
        :class:`Model` and clear the internal data to start a new search process.

        Args:
            root_walker: Walker representing the initial state of the search. \
                         The walkers will be reset to this walker, and it will \
                         be added to the root of the :class:`StateTree` if any.
            model_states: :class:`StatesModel` that defines the initial state of \
                          the :class:`Model`.
            env_states: :class:`StatesEnv` that defines the initial state of \
                        the :class:`Environment`.
            walkers_states: :class:`StatesWalkers` that defines the internal \
                            states of the :class:`Walkers`.

        """
        self._epoch = 0
        self.internal_swarm.reset(
            root_walker=root_walker,
            walkers_states=walkers_states,
            env_states=env_states,
            model_states=model_states,
        )
        # Reset root data
        best_index = self.walkers.get_best_index()
        self.root_model_states = self.walkers.model_states[best_index]
        self.root_env_states = self.walkers.env_states[best_index]
        self.root_walkers_states = self.walkers.states[best_index]
        self.root_walker = OneWalker(
            reward=self.root_env_states.rewards[0],
            observ=self.root_env_states.observs[0],
            state=self.root_env_states.states[0],
            time=0,
            id_walker=self.root_walkers_states.id_walkers[0],
        )
        if self.tree is not None:
            self.tree.reset(
                root_id=self.best_id,
                env_states=self.root_env_states[0],
                model_states=self.root_model_states[0],
                walkers_states=self.root_walkers_states[0],
            )
Code example #6
    def update_states(self):
        """Update the data of the root state."""
        if self.accumulate_rewards:
            cum_rewards = self.root_walkers_states.cum_rewards
            cum_rewards = cum_rewards + self.root_env_states.rewards
        else:
            cum_rewards = self.root_env_states.rewards
        self.root_walkers_states.update(
            cum_rewards=cum_rewards,
            id_walkers=numpy.array(
                [hash_numpy(self.root_env_states.states[0])]),
        )

        self.root_walker = OneWalker(
            reward=copy.deepcopy(cum_rewards[0]),
            observ=copy.deepcopy(self.root_env_states.observs[0]),
            state=copy.deepcopy(self.root_env_states.states[0]),
        )
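Examples #4 and #6 snapshot the root data through ``copy.deepcopy`` (the judo-based variants use ``judo.copy``) so the ``OneWalker`` holds independent arrays rather than views into buffers that keep mutating on the next iteration. A minimal demonstration of the aliasing this avoids:

    import copy

    import numpy

    state = numpy.zeros(3)
    alias = state                    # no copy: later writes leak into `alias`
    snapshot = copy.deepcopy(state)  # independent array, as in the examples
    state[0] = 1.0
    print(alias[0], snapshot[0])     # 1.0 0.0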
Code example #7
 def update_states(self, best_ix):
     """Update the data of the root walker after an internal Swarm iteration has finished."""
     # The accumulation of rewards is already done in the internal Swarm
     cum_rewards = self.root_walkers_states.cum_rewards
     times = self.root_walkers_states.times + self.root_walker.times
     root_id = tensor(self.walkers.states.id_walkers[best_ix])
     self.root_walkers_states.update(
         cum_rewards=cum_rewards,
         id_walkers=tensor([root_id]),
         times=times,
     )
     self.root_walker = OneWalker(
         reward=judo.copy(cum_rewards[0]),
         observ=judo.copy(self.root_env_states.observs[0]),
         state=judo.copy(self.root_env_states.states[0]),
         time=judo.copy(times[0]),
         id_walker=root_id,
     )
Code example #8
File: test_swarm.py  Project: Zeta36/fragile
 def test_reset_with_root_walker(self, swarm):
     swarm.reset()
     param_dict = swarm.walkers.env_states.get_params_dict()
     obs_dict = param_dict["observs"]
     state_dict = param_dict["states"]
     obs_size = obs_dict.get("size", obs_dict["shape"][1:])
     state_size = state_dict.get("size", state_dict["shape"][1:])
     obs = numpy.random.random(obs_size).astype(obs_dict["dtype"])
     state = numpy.random.random(state_size).astype(state_dict["dtype"])
     reward = 160290
     root_walker = OneWalker(observ=obs, reward=reward, state=state)
     swarm.reset(root_walker=root_walker)
     assert (swarm.best_obs == obs).all()
     assert (swarm.best_state == state).all()
     assert swarm.best_reward == reward
     assert swarm.best_id == root_walker.id_walkers
     assert (swarm.walkers.env_states.observs == obs).all()
     assert (swarm.walkers.env_states.states == state).all()
     assert (swarm.walkers.env_states.rewards == reward).all()
     assert (
         swarm.walkers.states.id_walkers == root_walker.id_walkers).all()
Code example #9
File: step_swarm.py  Project: softmaxhuanchen/fragile
    def update_states(self, env_states, model_states):
        """Update the data of the root state."""
        self.root_env_states.update(other=env_states)
        self.root_model_states.update(other=model_states)
        if self.accumulate_rewards:
            cum_rewards = self.root_walkers_states.cum_rewards
            cum_rewards = cum_rewards + self.root_env_states.rewards
        else:
            cum_rewards = self.root_env_states.rewards

        times = self.root_walkers_states.times + self.root_walker.times
        self.root_walkers_states.update(
            cum_rewards=cum_rewards,
            id_walkers=numpy.array(
                [hash_numpy(self.root_env_states.states[0])]),
            times=times,
        )

        self.root_walker = OneWalker(
            reward=copy.deepcopy(cum_rewards[0]),
            observ=copy.deepcopy(self.root_env_states.observs[0]),
            state=copy.deepcopy(self.root_env_states.states[0]),
            time=copy.deepcopy(times[0]),
        )
Code example #10
    def __init__(self,
                 n_walkers: int,
                 step_epochs: int = None,
                 root_model: Callable[[], RootModel] = MajorityDiscreteModel,
                 tree: Callable[[], BaseStateTree] = None,
                 prune_tree: bool = True,
                 report_interval: int = numpy.inf,
                 show_pbar: bool = True,
                 walkers: Callable[..., StepWalkers] = StepWalkers,
                 swarm: Callable[..., Swarm] = Swarm,
                 reward_limit: Scalar = None,
                 max_epochs: int = None,
                 accumulate_rewards: bool = True,
                 minimize: bool = False,
                 use_notebook_widget: bool = True,
                 *args,
                 **kwargs):
        """
        Initialize a :class:`StepSwarm`.

        This class can be initialized the same way as a :class:`Swarm`. All the \
        parameters except ``max_epochs`` and ``tree`` will be used to initialize \
        the internal swarm, and the :class:`StepSwarm` will use them when necessary.

        The internal swarm will be initialized with no tree, and with its \
        notebook widgets deactivated.

        Args:
            n_walkers: Number of walkers of the internal swarm.
            step_epochs: Number of epochs that the internal swarm will be run \
                         before sampling an action.
            root_model: Callable that returns a :class:`RootModel` that will be \
                        used to sample the actions and dt of the root walker.
            tree: Disabled for now. It will be used by the root walker.
            prune_tree: Disabled for now.
            report_interval: Display the algorithm progress every \
                            ``report_interval`` epochs.
            show_pbar: If ``True``, a progress bar will display the progress of \
                       the algorithm run.
            walkers: A callable that returns an instance of :class:`StepWalkers`.
            swarm: A callable that returns an instance of :class:`Swarm` and \
                  takes as input all the corresponding parameters provided. It \
                  will be wrapped with a :class:`StoreInitAction` before being \
                  assigned to the ``internal_swarm`` attribute.
            reward_limit: The algorithm run will stop after reaching this \
                          reward value. If you are running a minimization process \
                          it will be considered the minimum reward possible, and \
                          if you are maximizing a reward it will be the maximum \
                          value.
            max_epochs: Maximum number of steps that the root walker is allowed \
                       to take.
            accumulate_rewards: If ``True`` the rewards obtained after transitioning \
                                to a new state will accumulate. If ``False`` only the last \
                                reward will be taken into account.
            minimize: If ``True`` the algorithm will perform a minimization \
                      process. If ``False`` it will be a maximization process.
            use_notebook_widget: If ``True`` and the class is running in an IPython \
                                kernel it will display the evolution of the swarm \
                                in a widget.
            *args: Passed to ``swarm``.
            **kwargs: Passed to ``swarm``.

        """
        self.internal_swarm = StoreInitAction(
            swarm(max_epochs=step_epochs,
                  show_pbar=False,
                  report_interval=numpy.inf,
                  n_walkers=n_walkers,
                  tree=None,
                  walkers=walkers,
                  accumulate_rewards=accumulate_rewards,
                  minimize=minimize,
                  use_notebook_widget=False,
                  *args,
                  **kwargs))
        self.internal_swarm.reset()
        self.root_model: RootModel = root_model()
        if reward_limit is None:
            reward_limit = -numpy.inf if self.internal_swarm.walkers.minimize else numpy.inf
        self.accumulate_rewards = accumulate_rewards
        # max_epochs defaults to None; treat that as no limit on root steps.
        self._max_epochs = int(max_epochs) if max_epochs is not None else numpy.inf
        self.reward_limit = reward_limit
        self.show_pbar = show_pbar
        self.report_interval = report_interval
        self.tree = tree() if tree is not None else tree
        self._prune_tree = prune_tree
        self._use_tree = tree is not None
        self._epoch = 0
        self._walkers: StepWalkers = self.internal_swarm.walkers
        self._model = self.internal_swarm.model
        self._env = self.internal_swarm.env
        self.cum_reward = -numpy.inf
        self.minimize = minimize
        self.root_model_states = self.walkers.model_states[0]
        self.root_env_states = self.walkers.env_states[0]
        self.root_walkers_states = self.walkers.states[0]
        self.root_walker = OneWalker(
            reward=self.root_env_states.rewards[0],
            observ=self.root_env_states.observs[0],
            state=self.root_env_states.states[0],
        )
        self._notebook_container = None
        self._use_notebook_widget = use_notebook_widget
        self.setup_notebook_container()
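Putting the signature above to use: a hedged construction sketch. The environment and model callables are placeholders forwarded to the internal :class:`Swarm` through ``**kwargs``, and the import path and ``run()`` entry point are assumptions about the surrounding API, not part of this excerpt:

    from fragile.core import StepSwarm  # import path assumed

    swarm = StepSwarm(
        n_walkers=64,      # size of the internal swarm
        step_epochs=25,    # internal-swarm epochs before each root action
        max_epochs=200,    # steps the root walker may take
        minimize=False,
        env=make_env,      # placeholder callable, forwarded via **kwargs
        model=make_model,  # placeholder callable, forwarded via **kwargs
    )
    swarm.run()            # assumed entry point, as in other fragile swarms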