Example #1
    def setup(self, config):
        self._finished = False
        num_workers = self._num_workers
        logdir = self.logdir
        assert self._function

        func_trainable = wrap_function(self.__class__._function)

        remote_trainable = ray.remote(func_trainable)
        remote_trainable = remote_trainable.options(
            **self.get_remote_worker_options())

        # Compute the tcp://ip:port URL workers will use to rendezvous.
        address = setup_address()
        self.workers = [
            remote_trainable.remote(
                config=config,
                logger_creator=lambda cfg: logger_creator(cfg, logdir, rank))
            for rank in range(num_workers)
        ]

        pgroup_params = self.default_process_group_parameters()
        from functools import partial
        setup_on_worker = partial(setup_process_group,
                                  url=address,
                                  world_size=num_workers,
                                  **pgroup_params)
        # Each submitted closure is serialized immediately, so every worker
        # receives the `rank` value from its own loop iteration.
        ray.get([
            w.execute.remote(lambda _: setup_on_worker(world_rank=rank))
            for rank, w in enumerate(self.workers)
        ])

        ray.get([
            w.execute.remote(lambda _: enable_distributed_trainable())
            for w in self.workers
        ])
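Example #1 assumes a `logger_creator` helper that routes each worker's output to its own directory. That helper is not shown here; a minimal sketch, assuming a per-rank subdirectory and Ray Tune's `NoopLogger` (the logger class choice is an assumption of this sketch):

import os

from ray.tune.logger import NoopLogger

def logger_creator(log_config, logdir, rank):
    # Give every worker its own subdirectory so per-rank logs don't collide.
    worker_dir = os.path.join(logdir, f"worker_{rank}")
    os.makedirs(worker_dir, exist_ok=True)
    return NoopLogger(log_config, worker_dir)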
Example #2
    def setup(self, config: Dict):
        self._finished = False
        num_workers = self._num_workers
        logdir = self.logdir
        assert self._function

        func_trainable = wrap_function(self.__class__._function)

        remote_trainable = ray.remote(func_trainable)
        (
            remote_option,
            self._placement_group,
        ) = PlacementGroupUtil.get_remote_worker_options(
            self._num_workers,
            self._num_cpus_per_worker,
            self._num_gpus_per_worker,
            self._num_workers_per_host,
            self._timeout_s,
        )
        remote_trainable = remote_trainable.options(**remote_option)
        new_config = DistributedTrainable.build_config(self, config)

        self.workers = [
            remote_trainable.remote(
                config=new_config,
                logger_creator=lambda cfg: logger_creator(cfg, logdir, rank),
            )
            for rank in range(num_workers)
        ]

        # Address has to be IP of rank 0 worker's node.
        address = ray.get(self.workers[0].execute.remote(lambda _: setup_address()))

        pgroup_params = self.default_process_group_parameters()
        from functools import partial

        setup_on_worker = partial(
            setup_process_group, url=address, world_size=num_workers, **pgroup_params
        )
        ray.get(
            [
                w.execute.remote(lambda _: setup_on_worker(world_rank=rank))
                for rank, w in enumerate(self.workers)
            ]
        )

        ray.get(
            [
                w.execute.remote(lambda _: enable_distributed_trainable())
                for w in self.workers
            ]
        )
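Both setup examples join the workers into one collective via `setup_process_group`, which is not shown. A minimal sketch that wraps `torch.distributed.init_process_group` (the "gloo" default backend and 30-minute timeout are assumptions of this sketch):

from datetime import timedelta

import torch.distributed as dist

def setup_process_group(url, world_size, world_rank,
                        timeout=timedelta(minutes=30), backend="gloo"):
    # Blocks until all world_size processes have joined the group.
    dist.init_process_group(
        backend=backend,   # use "nccl" for GPU training
        init_method=url,   # the tcp://ip:port address from setup_address()
        rank=world_rank,
        world_size=world_size,
        timeout=timeout,
    )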
Example #3
    def start_workers(self, num_workers):
        logger.debug(f"start_workers: Setting {num_workers} workers.")

        if num_workers == 1:
            self.local_worker = TorchRunner(**self._params)
            if self._initialization_hook:
                self.apply_all_workers(self._initialization_hook)
            self.local_worker.setup_operator()
            return True
        else:
            try:
                # Start local worker
                self.local_worker = LocalDistributedRunner(
                    num_cpus=self._num_cpus_per_worker,
                    num_gpus=int(self._use_gpu),
                    **{
                        **self._params,
                        **self._dist_params
                    },
                )
                self.remote_worker_group._init_dist_workers(num_workers - 1)
                if self._initialization_hook:
                    self.apply_all_workers(self._initialization_hook)

                # Compute URL for initializing distributed PyTorch.
                address = setup_address()

                remote_pgs = self.remote_worker_group._setup_process_group(
                    address=address, world_size=num_workers, starting_rank=1)
                # Use the local worker as rank 0. Helps with debugging.
                self.local_worker.setup_process_group(
                    url=address,
                    world_rank=0,
                    world_size=num_workers,
                    timeout=timedelta(seconds=self._timeout_s),
                )
                ray.get(remote_pgs)

                local_node_ip = ray.util.get_node_ip_address()
                # Workers sharing a node get consecutive local ranks; the
                # local (rank 0) worker claims local_rank 0 on this node.
                rank_dict = defaultdict(int)
                self.local_worker.set_local_rank(local_rank=0)
                rank_dict[local_node_ip] += 1
                self.remote_worker_group._setup_local_rank(rank_dict)

                remote_operators = self.remote_worker_group._setup_operator()
                self.local_worker.setup_operator()
                ray.get(remote_operators)
                return True
            except RayActorError:
                return False
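Because `start_workers` returns `False` when a `RayActorError` interrupts setup, the caller can shrink the pool and retry. A hypothetical usage sketch (`worker_group` and its `shutdown` method are assumptions standing in for the surrounding trainer):

# Retry with progressively fewer workers after an actor failure,
# e.g. when a node was lost mid-setup.
num_workers = 4
while num_workers > 0 and not worker_group.start_workers(num_workers):
    worker_group.shutdown(force=True)  # assumed cleanup hook
    num_workers -= 1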
Example #4
    def start_workers(self, num_workers):
        logger.debug(f"start_workers: Setting %d workers." % num_workers)

        if num_workers == 1:
            self.local_worker = TorchRunner(**self._params)
            if self._initialization_hook:
                self.apply_all_workers(self._initialization_hook)
            self.local_worker.setup_operator()
        else:
            # Start local worker
            self.local_worker = LocalDistributedRunner(
                num_cpus=self._num_cpus_per_worker,
                num_gpus=int(self._use_gpu),
                **{
                    **self._params,
                    **self._dist_params
                })
            self.remote_worker_group._init_dist_workers(num_workers - 1)
            if self._initialization_hook:
                self.apply_all_workers(self._initialization_hook)

            # Compute URL for initializing distributed PyTorch.
            address = setup_address()

            remote_pgs = self.remote_worker_group._setup_process_group(
                address=address, world_size=num_workers, starting_rank=1)
            # Use the local worker as rank 0. This will help with debugging.
            self.local_worker.setup_process_group(url=address,
                                                  world_rank=0,
                                                  world_size=num_workers,
                                                  timeout=timedelta(
                                                      seconds=self._timeout_s))
            ray.get(remote_pgs)

            remote_operators = self.remote_worker_group._setup_operator()
            self.local_worker.setup_operator()
            ray.get(remote_operators)
Example #5
    def setup_address(self):
        # Delegates to the module-level setup_address() helper.
        return setup_address()
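The module-level `setup_address` that this wrapper (and the other examples) call is expected to return a `tcp://ip:port` URL usable as a `torch.distributed` init method. A sketch under that assumption, with the free-port probe being an implementation guess:

import socket
from contextlib import closing

import ray

def setup_address():
    # Publish this node's IP plus a free port as the rendezvous URL.
    ip = ray.util.get_node_ip_address()
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.bind(("", 0))
        port = s.getsockname()[1]
    return f"tcp://{ip}:{port}"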
Example #6
    def _start_workers(self, num_workers):
        logger.debug(f"start_workers: Setting %d workers." % num_workers)
        worker_config = self.config.copy()
        batch_size_per_worker = self._configure_and_split_batch(num_workers)
        if batch_size_per_worker:
            worker_config[BATCH_SIZE] = batch_size_per_worker

        params = dict(
            training_operator_cls=self.training_operator_cls,
            config=worker_config,
            serialize_data_creation=self.serialize_data_creation,
            use_fp16=self.use_fp16,
            use_gpu=self.use_gpu,
            use_tqdm=self.use_tqdm,
            apex_args=self.apex_args,
            scheduler_step_freq=self.scheduler_step_freq)

        if num_workers == 1:
            # Start local worker
            self.local_worker = TorchRunner(**params)
            if self.initialization_hook:
                self.apply_all_workers(self.initialization_hook)
            self.local_worker.setup_operator()
        else:
            params.update(
                backend=self.backend,
                add_dist_sampler=self.add_dist_sampler,
                wrap_ddp=self.wrap_ddp)

            # Start local worker
            self.local_worker = LocalDistributedRunner(
                num_cpus=self.num_cpus_per_worker,
                num_gpus=int(self.use_gpu),
                **params)

            # Generate actor class
            RemoteRunner = ray.remote(
                num_cpus=self.num_cpus_per_worker,
                num_gpus=int(self.use_gpu))(DistributedTorchRunner)
            # Start workers
            self.remote_workers = [
                RemoteRunner.remote(**params) for i in range(num_workers - 1)
            ]
            if self.initialization_hook:
                self.apply_all_workers(self.initialization_hook)

            # Compute URL for initializing distributed PyTorch
            address = setup_address()

            # Setup the process group among all workers.
            remote_pgroup_setups = [
                worker.setup_process_group.remote(
                    address, i + 1, num_workers,
                    timedelta(seconds=self.timeout_s))
                for i, worker in enumerate(self.remote_workers)
            ]
            self.local_worker.setup_process_group(
                address, 0, num_workers, timedelta(seconds=self.timeout_s))
            # Get setup tasks in order to throw errors on failure
            ray.get(remote_pgroup_setups)

            # Runs code that requires all creator functions to have run.
            remote_operator_setups = [
                worker.setup_operator.remote()
                for worker in self.remote_workers
            ]
            self.local_worker.setup_operator()
            # Get setup tasks in order to throw errors on failure
            ray.get(remote_operator_setups)
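`_configure_and_split_batch` is referenced but not defined in Example #6; presumably it divides a global `BATCH_SIZE` config value across the workers. A hypothetical sketch:

def _configure_and_split_batch(self, num_workers):
    # Hypothetical: split the global batch size evenly across workers;
    # returns None when the user did not set BATCH_SIZE at all.
    if BATCH_SIZE not in self.config:
        return None
    batch_size = self.config[BATCH_SIZE]
    per_worker = batch_size // num_workers
    if batch_size % num_workers:
        logger.warning(
            f"{BATCH_SIZE}={batch_size} is not divisible by {num_workers} "
            f"workers; using {per_worker} per worker.")
    return per_worker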