Example 1
    def setup(self, config: Dict):
        self._finished = False
        num_workers = self._num_workers
        assert self._function

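        # Wrap the user-supplied training function into a function Trainable
        # class, then turn it into a Ray actor class so one copy can run per
        # worker.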
        func_trainable = wrap_function(self.__class__._function)
        remote_trainable = ray.remote(func_trainable)
        remote_option, self._placement_group = \
            PlacementGroupUtil.get_remote_worker_options(
                self._num_workers, self._num_cpus_per_worker,
                self._num_gpus_per_worker,
                self._num_workers_per_host, self._timeout_s)
        remote_trainable = remote_trainable.options(**remote_option)
        new_config = DistributedTrainable.build_config(self, config)
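        # Launch one remote trainable actor per worker with the config built
        # above.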
        self.workers = [
            remote_trainable.remote(config=new_config)
            for _ in range(num_workers)
        ]

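        # Collect the address each worker will serve on; the full list is
        # handed back to every worker below so the group can find its peers.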
        addresses = [
            ray.get(worker.execute.remote(lambda _: setup_address()))
            for worker in self.workers
        ]

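        # Initialize the process group on every worker with the full address
        # list and that worker's own index. The lambda is pickled at the
        # execute.remote() call, so each worker receives its own index value.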
        from functools import partial
        setup_on_worker = partial(setup_process_group,
                                  worker_addresses=addresses)
        ray.get([
            w.execute.remote(lambda _: setup_on_worker(index=index))
            for index, w in enumerate(self.workers)
        ])
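
This setup looks like the one used by Ray Tune's TensorFlow integration, where a distributed trainable is normally obtained through DistributedTrainableCreator. A minimal usage sketch under that assumption (Ray 1.x API); the training function, worker count, and config values are illustrative, not taken from the source:

# Assumed API: DistributedTrainableCreator from the Ray 1.x Tune integrations.
from ray import tune
from ray.tune.integration.tensorflow import DistributedTrainableCreator


def train_func(config):
    # Runs on every distributed worker; report a dummy metric back to Tune.
    tune.report(mean_loss=0.0)


# Illustrative resources: 2 workers, 1 CPU each.
tf_trainable = DistributedTrainableCreator(
    train_func, num_workers=2, num_cpus_per_worker=1)

tune.run(tf_trainable, config={"lr": 1e-3}, num_samples=1)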
Example 2
    def setup(self, config: Dict):
        self._finished = False
        num_workers = self._num_workers
        logdir = self.logdir
        assert self._function

        func_trainable = wrap_function(self.__class__._function)

        remote_trainable = ray.remote(func_trainable)
        (
            remote_option,
            self._placement_group,
        ) = PlacementGroupUtil.get_remote_worker_options(
            self._num_workers,
            self._num_cpus_per_worker,
            self._num_gpus_per_worker,
            self._num_workers_per_host,
            self._timeout_s,
        )
        remote_trainable = remote_trainable.options(**remote_option)
        new_config = DistributedTrainable.build_config(self, config)

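        # Give each worker a logger rooted at the trial logdir and keyed by
        # its rank.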
        self.workers = [
            remote_trainable.remote(
                config=new_config,
                logger_creator=lambda cfg: logger_creator(cfg, logdir, rank),
            )
            for rank in range(num_workers)
        ]

        # Address has to be IP of rank 0 worker's node.
        address = ray.get(self.workers[0].execute.remote(lambda _: setup_address()))

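        # Combine the rank-0 URL with the default process group parameters
        # (e.g. backend and timeout), then initialize the group on every
        # worker with its own world rank.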
        pgroup_params = self.default_process_group_parameters()
        from functools import partial

        setup_on_worker = partial(
            setup_process_group, url=address, world_size=num_workers, **pgroup_params
        )
        ray.get(
            [
                w.execute.remote(lambda _: setup_on_worker(world_rank=rank))
                for rank, w in enumerate(self.workers)
            ]
        )

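        # Flag each worker process as running inside a distributed trainable.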
        ray.get(
            [
                w.execute.remote(lambda _: enable_distributed_trainable())
                for w in self.workers
            ]
        )
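
Example 2 is the PyTorch-flavored variant of the same setup: the process group URL comes from the rank-0 worker, and enable_distributed_trainable is called on every worker. A sketch of the corresponding entry point, again assuming the Ray 1.x DistributedTrainableCreator, this time from ray.tune.integration.torch; the backend and resource values are illustrative:

# Assumed API: DistributedTrainableCreator from the Ray 1.x Tune integrations.
from ray import tune
from ray.tune.integration.torch import DistributedTrainableCreator


def train_func(config):
    # Runs on every worker after setup() has initialized the process group.
    tune.report(mean_loss=0.0)


# Illustrative resources: 2 workers, 1 CPU each, gloo backend.
torch_trainable = DistributedTrainableCreator(
    train_func, num_workers=2, num_cpus_per_worker=1, backend="gloo")

tune.run(torch_trainable, config={"lr": 1e-3}, num_samples=1)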