def setup(self, config: Dict):
    trainable = wrap_function(self.__class__._function)
    # We use a filelock here to ensure that the file-writing
    # process is safe across different trainables.
    if self._ssh_identity_file:
        with FileLock(self._ssh_identity_file + ".lock"):
            settings = RayExecutor.create_settings(
                self._timeout_s, self._ssh_identity_file, self._ssh_str
            )
    else:
        settings = RayExecutor.create_settings(
            self._timeout_s, self._ssh_identity_file, self._ssh_str
        )

    self.executor = RayExecutor(
        settings,
        cpus_per_worker=self._num_cpus_per_worker,
        use_gpu=self._use_gpu,
        num_workers=self._num_workers,
    )

    new_config = DistributedTrainable.build_config(self, config)

    # We can't put `self` in the lambda closure, so we
    # resolve the variable ahead of time.
    logdir_ = str(self.logdir)

    # Starts the workers as specified by the resources above.
    self.executor.start(
        executable_cls=trainable,
        executable_kwargs={
            "config": new_config,
            "logger_creator": lambda cfg: logger_creator(cfg, logdir_),
        },
    )
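# For context, a minimal usage sketch of how a Horovod trainable like the one
# above is typically created and launched. The creator-function name, import
# path, and keyword arguments are assumptions based on older Ray 1.x releases,
# not something this snippet confirms on its own.
from ray import tune
from ray.tune.integration.horovod import DistributedTrainableCreator


def train_fn(config):
    import horovod.torch as hvd

    hvd.init()
    # ... build the model and optimizer, train, and report metrics here.


trainable_cls = DistributedTrainableCreator(
    train_fn, num_workers=2, num_cpus_per_worker=1, use_gpu=False
)
tune.run(trainable_cls, config={"lr": 0.1})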
def setup(self, config: Dict):
    self._finished = False
    num_workers = self._num_workers
    assert self._function

    func_trainable = wrap_function(self.__class__._function)
    remote_trainable = ray.remote(func_trainable)
    (
        remote_option,
        self._placement_group,
    ) = PlacementGroupUtil.get_remote_worker_options(
        self._num_workers,
        self._num_cpus_per_worker,
        self._num_gpus_per_worker,
        self._num_workers_per_host,
        self._timeout_s,
    )
    remote_trainable = remote_trainable.options(**remote_option)
    new_config = DistributedTrainable.build_config(self, config)

    # Launch one remote trainable actor per worker.
    self.workers = [
        remote_trainable.remote(config=new_config) for _ in range(num_workers)
    ]

    # Collect every worker's address so each one can see all of its peers.
    addresses = [
        ray.get(worker.execute.remote(lambda _: setup_address()))
        for worker in self.workers
    ]

    from functools import partial

    setup_on_worker = partial(setup_process_group, worker_addresses=addresses)
    ray.get(
        [
            w.execute.remote(lambda _: setup_on_worker(index=index))
            for index, w in enumerate(self.workers)
        ]
    )
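# For context, a hedged sketch of what a `setup_process_group` helper like the
# one called above plausibly does for TensorFlow's MultiWorkerMirroredStrategy:
# given every peer's address and this worker's index, it writes the TF_CONFIG
# environment variable that tf.distribute reads for rendezvous. The name and
# signature mirror the call sites above; the body is an assumption, not the
# verbatim implementation.
import json
import os
from typing import List


def setup_process_group_sketch(worker_addresses: List[str], index: int) -> None:
    # TF_CONFIG describes the full cluster and identifies this worker in it.
    os.environ["TF_CONFIG"] = json.dumps(
        {
            "cluster": {"worker": worker_addresses},
            "task": {"type": "worker", "index": index},
        }
    )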
def setup(self, config: Dict):
    self._finished = False
    num_workers = self._num_workers
    logdir = self.logdir
    assert self._function

    func_trainable = wrap_function(self.__class__._function)
    remote_trainable = ray.remote(func_trainable)
    (
        remote_option,
        self._placement_group,
    ) = PlacementGroupUtil.get_remote_worker_options(
        self._num_workers,
        self._num_cpus_per_worker,
        self._num_gpus_per_worker,
        self._num_workers_per_host,
        self._timeout_s,
    )
    remote_trainable = remote_trainable.options(**remote_option)
    new_config = DistributedTrainable.build_config(self, config)

    self.workers = [
        remote_trainable.remote(
            config=new_config,
            logger_creator=lambda cfg: logger_creator(cfg, logdir, rank),
        )
        for rank in range(num_workers)
    ]

    # Address has to be IP of rank 0 worker's node.
    address = ray.get(self.workers[0].execute.remote(lambda _: setup_address()))

    pgroup_params = self.default_process_group_parameters()
    from functools import partial

    setup_on_worker = partial(
        setup_process_group, url=address, world_size=num_workers, **pgroup_params
    )
    ray.get(
        [
            w.execute.remote(lambda _: setup_on_worker(world_rank=rank))
            for rank, w in enumerate(self.workers)
        ]
    )

    ray.get(
        [
            w.execute.remote(lambda _: enable_distributed_trainable())
            for rank, w in enumerate(self.workers)
        ]
    )
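# For context, a hedged sketch of the torch-side rendezvous performed by the
# `setup_process_group` helper above. The body is an assumption mirroring
# torch.distributed's standard TCP rendezvous, not the verbatim Ray code;
# `url` would be something like "tcp://10.0.0.1:29500" from setup_address().
from datetime import timedelta

import torch.distributed as dist


def setup_process_group_sketch(
    url: str,
    world_rank: int,
    world_size: int,
    backend: str = "gloo",
    timeout_s: int = 1800,
) -> None:
    # Every worker dials the same URL; the rank 0 worker's node hosts the store.
    dist.init_process_group(
        backend=backend,
        init_method=url,
        rank=world_rank,
        world_size=world_size,
        timeout=timedelta(seconds=timeout_s),
    )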