Пример #1
0
    def setup(self, config: Dict):
        """Start the executor-managed workers for this trainable.

        Steps:
          1. Wrap the class-level training function as a trainable.
          2. Build executor settings. When an SSH identity file is set,
             this happens while holding a ``FileLock`` on
             ``<identity_file>.lock``.
          3. Create a ``RayExecutor`` from those settings and this
             trainable's worker / CPU-per-worker / GPU attributes, and
             store it on ``self.executor``.
          4. Start the workers with the distributed config and a logger
             factory rooted at ``self.logdir``.

        Args:
            config: Trial config. It is passed through
                ``DistributedTrainable.build_config`` before it reaches
                the workers.
        """
        wrapped_trainable = wrap_function(self.__class__._function)

        def build_settings():
            # One place for the settings call, so both branches below pass
            # exactly the same arguments.
            return RayExecutor.create_settings(self._timeout_s,
                                               self._ssh_identity_file,
                                               self._ssh_str)

        if not self._ssh_identity_file:
            settings = build_settings()
        else:
            # Settings creation involves writing files. The lock (whose path
            # is derived from the identity file) keeps that step safe across
            # different trainables.
            with FileLock(self._ssh_identity_file + ".lock"):
                settings = build_settings()

        self.executor = RayExecutor(
            settings,
            num_workers=self._num_workers,
            cpus_per_worker=self._num_cpus_per_worker,
            use_gpu=self._use_gpu,
        )

        worker_config = DistributedTrainable.build_config(self, config)

        # ``self`` must stay out of the logger-factory closure, so turn the
        # log directory into a plain string up front.
        log_dir = str(self.logdir)

        worker_kwargs = {
            "config": worker_config,
            "logger_creator": lambda cfg: logger_creator(cfg, log_dir),
        }
        # Launch the workers using the resources configured above.
        self.executor.start(executable_cls=wrapped_trainable,
                            executable_kwargs=worker_kwargs)
Пример #2
0
    def setup(self, config: Dict):
        """Create the remote workers and set up a process group across them.

        Wraps the class-level training function as a Ray actor class and
        applies the options returned by
        ``PlacementGroupUtil.get_remote_worker_options``. The placement group
        it returns is stored on ``self._placement_group``. Then:

          1. Starts ``self._num_workers`` actors with the config produced by
             ``DistributedTrainable.build_config``.
          2. Runs ``setup_address()`` on every worker and collects the
             results in worker order.
          3. Runs ``setup_process_group`` on every worker, passing the
             collected ``worker_addresses`` and that worker's ``index``.

        Args:
            config: Trial config for this trainable.
        """
        self._finished = False
        num_workers = self._num_workers
        assert self._function

        func_trainable = wrap_function(self.__class__._function)
        remote_trainable = ray.remote(func_trainable)
        remote_option, self._placement_group = \
            PlacementGroupUtil.get_remote_worker_options(
                self._num_workers, self._num_cpus_per_worker,
                self._num_gpus_per_worker,
                self._num_workers_per_host, self._timeout_s)
        remote_trainable = remote_trainable.options(**remote_option)
        new_config = DistributedTrainable.build_config(self, config)
        self.workers = [
            remote_trainable.remote(config=new_config)
            for _ in range(num_workers)
        ]

        # Submit every lookup before blocking, so the workers run
        # setup_address() concurrently instead of one round-trip at a time.
        # ray.get on a list returns the results in the same (worker) order.
        addresses = ray.get([
            worker.execute.remote(lambda _: setup_address())
            for worker in self.workers
        ])

        from functools import partial
        setup_on_worker = partial(setup_process_group,
                                  worker_addresses=addresses)
        # Bind ``index`` as a default argument so each closure carries its
        # own value. A bare reference would late-bind to the comprehension
        # variable that all iterations share.
        ray.get([
            w.execute.remote(
                lambda _, index=index: setup_on_worker(index=index))
            for index, w in enumerate(self.workers)
        ])
Пример #3
0
    def setup(self, config: Dict):
        """Create remote workers and initialize a process group among them.

        Wraps the class-level training function as a Ray actor class and
        applies the options returned by
        ``PlacementGroupUtil.get_remote_worker_options``. The placement group
        it returns is stored on ``self._placement_group``. Then:

          1. Starts ``self._num_workers`` actors. Each gets the config from
             ``DistributedTrainable.build_config`` and a logger factory
             built from ``self.logdir`` and that worker's rank.
          2. Runs ``setup_address()`` on the rank-0 worker and uses the
             result as the ``url`` for the process group.
          3. Runs ``setup_process_group`` on every worker with ``url``,
             ``world_size``, ``world_rank`` and the keyword arguments from
             ``self.default_process_group_parameters()``.
          4. Runs ``enable_distributed_trainable()`` on every worker.

        Args:
            config: Trial config for this trainable.
        """
        self._finished = False
        num_workers = self._num_workers
        logdir = self.logdir
        assert self._function

        func_trainable = wrap_function(self.__class__._function)

        remote_trainable = ray.remote(func_trainable)
        (
            remote_option,
            self._placement_group,
        ) = PlacementGroupUtil.get_remote_worker_options(
            self._num_workers,
            self._num_cpus_per_worker,
            self._num_gpus_per_worker,
            self._num_workers_per_host,
            self._timeout_s,
        )
        remote_trainable = remote_trainable.options(**remote_option)
        new_config = DistributedTrainable.build_config(self, config)

        # Each lambda binds ``rank`` as a default argument so it keeps its
        # own value. A bare reference would late-bind to the comprehension
        # variable that every iteration shares.
        self.workers = [
            remote_trainable.remote(
                config=new_config,
                logger_creator=lambda cfg, rank=rank: logger_creator(
                    cfg, logdir, rank
                ),
            )
            for rank in range(num_workers)
        ]

        # Address has to be IP of rank 0 worker's node.
        address = ray.get(self.workers[0].execute.remote(lambda _: setup_address()))

        pgroup_params = self.default_process_group_parameters()
        from functools import partial

        setup_on_worker = partial(
            setup_process_group, url=address, world_size=num_workers, **pgroup_params
        )
        # All ranks are submitted before blocking on any of them. ``rank`` is
        # bound per closure for the same late-binding reason as above.
        ray.get(
            [
                w.execute.remote(
                    lambda _, rank=rank: setup_on_worker(world_rank=rank)
                )
                for rank, w in enumerate(self.workers)
            ]
        )

        ray.get(
            [
                w.execute.remote(lambda _: enable_distributed_trainable())
                for w in self.workers
            ]
        )