def reset(self, batch_size: int = 1, env_states: StatesEnv = None, *args, **kwargs) -> StatesEnv:
    """
    Restart every remote environment and the local one for a new episode.

    All the remote workers are reset first (their handles are awaited so the
    reset is guaranteed to have finished), and the returned :class:`StatesEnv`
    comes from resetting the local environment with the same arguments.

    Args:
        batch_size: Number of walkers contained in the returned states.
        env_states: :class:`StatesEnv` with the data to set in the environment.
        *args: Forwarded to each environment ``reset``.
        **kwargs: Forwarded to each environment ``reset``.

    Returns:
        :class:`StatesEnv` describing the restarted environment; the batch
        dimension of its tensors equals ``batch_size``.
    """
    pending_resets = [
        env.reset.remote(batch_size=batch_size, env_states=env_states, *args, **kwargs)
        for env in self.envs
    ]
    # Block until every remote worker has finished resetting.
    ray.get(pending_resets)
    return self._local_env.reset(batch_size=batch_size, env_states=env_states, *args, **kwargs)
def reset(self, root_walker: OneWalker = None):
    """Reset the internal data of the swarms and parameter server."""
    self._epoch = 0
    # Launch all resets first so they run concurrently, then wait on them.
    pending_server = self.param_server.reset.remote()
    pending_swarms = [s.reset.remote(root_walker=root_walker) for s in self.swarms]
    ray.get(pending_server)
    ray.get(pending_swarms)
def __init__(
    self,
    swarm: Callable,
    n_swarms: int = 2,  # was annotated "n_swarms: 2", a literal used as a type with no default
    n_import: int = 2,
    n_export: int = 2,
    export_best: bool = True,
    import_best: bool = True,
    max_len: int = 20,
    add_global_best: bool = True,
    swarm_kwargs: dict = None,
    report_interval: int = numpy.inf,
):
    """
    Initialize a :class:`DistributedExport`.

    Args:
        swarm: Callable that returns a :class:`Swarm`. Accepts keyword \
               arguments defined in ``swarm_kwargs``.
        n_swarms: Number of :class:`ExportSwarm` that will be used to run \
                  the search process.
        n_import: Number of walkers that will be imported from an external \
                 :class:`ExportedWalkers`.
        n_export: Number of walkers that will be exported as :class:`ExportedWalkers`.
        export_best: The best walkers of the :class:`Swarm` will always be exported.
        import_best: The best walker of the imported :class:`ExportedWalkers` \
                     will be compared to the best walkers of the \
                     :class:`Swarm`. If it improves the current best value \
                     found, the best walker of the :class:`Swarm` will be updated.
        max_len: Maximum number of :class:`ExportedWalkers` that the \
                 :class:`ParamServer` will keep in its buffer.
        add_global_best: Add the best value found during the search to all \
                         the exported walkers that the :class:`ParamServer` \
                         returns.
        swarm_kwargs: Dictionary containing keyword that will be passed to ``swarm``.
        report_interval: Display the algorithm progress every ``report_interval`` epochs.

    """
    self.report_interval = report_interval
    self.swarms = [
        RemoteExportSwarm.remote(
            swarm=swarm,
            n_export=n_export,
            n_import=n_import,
            import_best=import_best,
            export_best=export_best,
            swarm_kwargs=swarm_kwargs,
        )
        for _ in range(n_swarms)
    ]
    self.n_swarms = n_swarms
    # Mirror the configuration of the first remote swarm locally so the
    # driver can test for convergence without extra remote calls per epoch.
    self.minimize = ray.get(self.swarms[0].get.remote("minimize"))
    self.max_epochs = ray.get(self.swarms[0].get.remote("max_epochs"))
    self.reward_limit = ray.get(self.swarms[0].get.remote("reward_limit"))
    self.param_server = RemoteParamServer.remote(
        max_len=max_len, minimize=self.minimize, add_global_best=add_global_best
    )
    self._epoch = 0
def step(self, model_states: StatesModel, env_states: StatesEnv) -> StatesEnv:
    """
    Step the environment on the target states by applying the given actions.

    The batch of walkers is split into ``self.n_workers`` chunks and each
    chunk is stepped on a different remote environment in parallel. The
    partial results are merged back into a single :class:`StatesEnv`.

    Args:
        model_states: :class:`StatesModel` with the data used to act on \
                      the environment.
        env_states: :class:`StatesEnv` with the data to set in the environment.

    Returns:
        :class:`StatesEnv` describing the new state of the environment.

    """
    model_chunks = model_states.split_states(self.n_workers)
    env_chunks = env_states.split_states(self.n_workers)
    pending = [
        env.step.remote(model_states=ms, env_states=es)
        for env, ms, es in zip(self.envs, model_chunks, env_chunks)
    ]
    partial_states = ray.get(pending)
    merged: StatesEnv = StatesEnv.merge_states(partial_states)
    return merged
def get_params_dict(self) -> StateDict:
    """
    Return a dictionary containing the param_dict to build an instance \
    of :class:`StatesEnv` that can handle all the data generated by an \
    :class:`Environment`.
    """
    # All workers share the same configuration, so asking the first is enough.
    return ray.get(self.envs[0].get_params_dict.remote())
def _make_transitions(self, split_results):
    """Fan the pre-split chunks out to the remote environments and gather the results."""
    pending = []
    for env, chunk in zip(self.envs, split_results):
        # kwargs_mode decides whether each chunk is a mapping or a positional tuple.
        if self.kwargs_mode:
            pending.append(env.make_transitions.remote(**chunk))
        else:
            pending.append(env.make_transitions.remote(*chunk))
    return ray.get(pending)
def _make_transitions(self, chunk_data):
    """Dispatch each data chunk to a remote environment and collect the transition dicts."""
    from fragile.distributed.ray import ray

    pending = []
    for env, chunk in zip(self.envs, chunk_data):
        # kwargs_mode decides whether each chunk is a mapping or a positional tuple.
        if self.kwargs_mode:
            pending.append(env.make_transitions.remote(**chunk))
        else:
            pending.append(env.make_transitions.remote(*chunk))
    return ray.get(pending)
def distribute(self, name, **kwargs):
    """Execute the target function in all the different workers."""
    chunk_data = self._split_inputs_in_chunks(**kwargs)
    from fragile.distributed.ray import ray

    pending = [env.execute.remote(name=name, **chunk) for env, chunk in zip(self.envs, chunk_data)]
    split_results = ray.get(pending)
    # Dict results are merged key-wise; anything else is treated as a batch of tensors.
    if isinstance(split_results[0], dict):
        return self._merge_data(split_results)
    tensors = [judo.to_backend(res) for res in split_results]
    return judo.concatenate(tensors)
def function(self, points: numpy.ndarray) -> numpy.ndarray:
    """
    Run the target :class:`Function` function in parallel.

    Args:
        points: Array of batched points to be evaluated. The length of the \
                first dimension of ``points`` equals the batch size.

    Returns:
        Rewards associated with each point.

    """
    batches = split_similar_chunks(points, self.n_workers)
    pending_rewards = [
        env.function.remote(batch) for env, batch in zip(self.envs, batches)
    ]
    reward_chunks = ray.get(pending_rewards)
    return numpy.concatenate(reward_chunks)
def observs_shape(self) -> tuple:
    """Return the shape of the observations state of the :class:`Environment`."""
    # The workers are homogeneous, so the first one is representative.
    return ray.get(self.envs[0].get.remote("observs_shape"))
def states_shape(self) -> tuple:
    """Return the shape of the internal state of the :class:`Environment`."""
    # The workers are homogeneous, so the first one is representative.
    return ray.get(self.envs[0].get.remote("states_shape"))
def get_best(self) -> BestWalker:
    """Return the best walkers found during the algorithm run."""
    best_ref = self.param_server.get.remote("best")
    return ray.get(best_ref)
def __getattr__(self, item):
    """Proxy unknown attribute lookups to the first remote swarm."""
    remote_value = self.swarms[0].get.remote(item)
    return ray.get(remote_value)