Example no. 1
0
class RayPredictor(BasePredictor):
    """Runs batch prediction/evaluation on a pool of Horovod-on-Ray workers.

    A `RayExecutor` is started once at construction time with `RemotePredictor`
    instances; each public method ships the model to the workers and delegates
    the actual work to them.
    """

    def __init__(self, horovod_kwargs, predictor_kwargs):
        # TODO ray: investigate using Dask for prediction instead of Horovod
        settings = RayExecutor.create_settings(timeout_s=30)
        merged_kwargs = {**get_horovod_kwargs(), **horovod_kwargs}
        self.executor = RayExecutor(settings, **merged_kwargs)
        self.executor.start(executable_cls=RemotePredictor,
                            executable_kwargs=predictor_kwargs)

    def batch_predict(self, model, *args, **kwargs):
        """Run batch prediction on every worker; return the first worker's result."""
        remote_model = RayRemoteModel(model)
        outputs = self.executor.execute(
            lambda predictor: predictor.batch_predict(
                remote_model.load(), *args, **kwargs))
        return outputs[0]

    def batch_evaluation(self, model, *args, **kwargs):
        """Run batch evaluation on every worker; return the first worker's result."""
        remote_model = RayRemoteModel(model)
        outputs = self.executor.execute(
            lambda predictor: predictor.batch_evaluation(
                remote_model.load(), *args, **kwargs))
        return outputs[0]

    def batch_collect_activations(self, model, *args, **kwargs):
        """Collect activations on a single worker (no need to fan out)."""
        remote_model = RayRemoteModel(model)
        return self.executor.execute_single(
            lambda predictor: predictor.batch_collect_activations(
                remote_model.load(), *args, **kwargs))

    def shutdown(self):
        """Tear down the Horovod-on-Ray worker pool."""
        self.executor.shutdown()
Example no. 2
0
class RayLegacyTrainer(BaseTrainer):
    """Legacy trainer that distributes training over a Horovod-on-Ray executor.

    Datasets are split into per-worker shards via their ``pipeline()`` API and
    handed to ``legacy_train_fn`` on each remote ``RemoteTrainer`` instance.
    """

    def __init__(self, horovod_kwargs, executable_kwargs):
        # TODO ray: make this more configurable by allowing YAML overrides of timeout_s, etc.
        setting = RayExecutor.create_settings(timeout_s=30)

        # Project-level defaults from get_horovod_kwargs() are applied first,
        # so caller-supplied horovod_kwargs override them.
        self.executor = RayExecutor(
            setting, **{
                **get_horovod_kwargs(),
                **horovod_kwargs
            })
        self.executor.start(executable_cls=RemoteTrainer,
                            executable_kwargs=executable_kwargs)

    def train(self,
              model,
              training_set,
              validation_set=None,
              test_set=None,
              **kwargs):
        """Train `model` across all workers and return the per-worker results list.

        Each dataset is split into one shard per worker; `locality_hints`
        co-locates shards with the worker actors. Only the training split uses
        `equal=True` (uniform shard sizes); validation/test are not shuffled.
        """
        workers = self.executor.driver.workers
        train_shards = training_set.pipeline().split(n=len(workers),
                                                     locality_hints=workers,
                                                     equal=True)
        # Validation/test shards are optional — None when the split is absent.
        val_shards = (validation_set.pipeline(
            shuffle=False).split(n=len(workers), locality_hints=workers)
                      if validation_set else None)
        test_shards = (test_set.pipeline(shuffle=False).split(
            n=len(workers), locality_hints=workers) if test_set else None)

        # NOTE(review): unlike RayTrainer, the model is captured directly in the
        # closure (no RayRemoteModel wrapper) — presumably legacy_train_fn
        # handles serialization itself; confirm before changing.
        results = self.executor.execute(lambda trainer: legacy_train_fn(
            trainer,
            model,
            training_set.training_set_metadata,
            training_set.features,
            train_shards,
            val_shards,
            test_shards,
            **kwargs,
        ))

        return results

    def train_online(self, model, *args, **kwargs):
        """Run online training on all workers; return the first worker's result."""
        results = self.executor.execute(
            lambda trainer: trainer.train_online(model, *args, **kwargs))

        return results[0]

    @property
    def validation_field(self):
        # Queried from a single worker; all workers share the same config.
        return self.executor.execute_single(
            lambda trainer: trainer.validation_field)

    @property
    def validation_metric(self):
        # Queried from a single worker; all workers share the same config.
        return self.executor.execute_single(
            lambda trainer: trainer.validation_metric)

    def shutdown(self):
        """Tear down the Horovod-on-Ray worker pool."""
        self.executor.shutdown()
Example no. 3
0
class RayTrainer(BaseTrainer):
    """Trainer that distributes model training over a Horovod-on-Ray executor.

    Workers run `RayRemoteTrainer`; after training, the learned weights from
    the first worker are copied back into the caller's model object.
    """

    def __init__(self, horovod_kwargs, trainer_kwargs):
        # TODO ray: make this more configurable by allowing YAML overrides of timeout_s, etc.
        settings = RayExecutor.create_settings(timeout_s=30)
        executor_kwargs = {**get_horovod_kwargs(), **horovod_kwargs}
        self.executor = RayExecutor(settings, **executor_kwargs)
        self.executor.start(executable_cls=RayRemoteTrainer,
                            executable_kwargs=trainer_kwargs)

    def train(self, model, *args, **kwargs):
        """Train on all workers; sync weights back and return (model, *stats)."""
        remote_model = RayRemoteModel(model)
        outputs = self.executor.execute(
            lambda trainer: trainer.train(remote_model.load(), *args, **kwargs))

        # First worker's result: (weights, stat_1, stat_2, ...)
        weights, *stats = outputs[0]
        model.set_weights(weights)
        return (model, *stats)

    def train_online(self, model, *args, **kwargs):
        """Online-train on all workers; sync weights back and return the model."""
        remote_model = RayRemoteModel(model)
        outputs = self.executor.execute(
            lambda trainer: trainer.train_online(
                remote_model.load(), *args, **kwargs))

        model.set_weights(outputs[0])
        return model

    @property
    def validation_field(self):
        """Validation field, queried from a single worker."""
        return self.executor.execute_single(
            lambda trainer: trainer.validation_field)

    @property
    def validation_metric(self):
        """Validation metric, queried from a single worker."""
        return self.executor.execute_single(
            lambda trainer: trainer.validation_metric)

    def shutdown(self):
        """Tear down the Horovod-on-Ray worker pool."""
        self.executor.shutdown()
Example no. 4
0
class _HorovodTrainable(tune.Trainable):
    """Abstract Trainable class for Horovod.

    Subclasses configure the class-level attributes below; `setup` then starts
    a `RayExecutor` worker pool sized by hosts x slots and wraps `_function`
    as the per-worker trainable.
    """
    # Callable function for training.
    _function = None
    # Number of hosts (nodes) to allocate per trial
    _num_hosts: int = 1
    # Number of workers (slots) to place on each host.
    _num_slots: int = 1
    # Number of CPU resources to reserve for each worker.
    _num_cpus_per_slot: int = 1
    # Whether to reserve and pass GPU resources through.
    _use_gpu: bool = False
    # Whether the function has completed training.
    _finished: bool = False

    # Horovod settings
    _ssh_str: str = None
    _ssh_identity_file: str = None
    _timeout_s: int = 30

    @property
    def num_workers(self):
        """Total worker count across all hosts."""
        return self._num_hosts * self._num_slots

    def _create_settings(self):
        """Build RayExecutor settings from the class-level Horovod options."""
        return RayExecutor.create_settings(self._timeout_s,
                                           self._ssh_identity_file,
                                           self._ssh_str)

    def setup(self, config: Dict):
        """Start the Horovod worker pool and launch the wrapped trainable on it."""
        trainable = wrap_function(self.__class__._function)
        # We use a filelock here to ensure that the file-writing
        # process is safe across different trainables.
        if self._ssh_identity_file:
            with FileLock(self._ssh_identity_file + ".lock"):
                settings = self._create_settings()
        else:
            settings = self._create_settings()

        self.executor = RayExecutor(settings,
                                    cpus_per_slot=self._num_cpus_per_slot,
                                    use_gpu=self._use_gpu,
                                    num_hosts=self._num_hosts,
                                    num_slots=self._num_slots)

        # We can't put `self` in the lambda closure, so we
        # resolve the variable ahead of time.
        logdir_ = str(self.logdir)

        # Starts the workers as specified by the resources above.
        self.executor.start(executable_cls=trainable,
                            executable_kwargs={
                                "config":
                                config,
                                "logger_creator":
                                lambda cfg: logger_creator(cfg, logdir_)
                            })

    def step(self) -> Dict:
        """Run one training step on all workers; return the first worker's result.

        Raises:
            RuntimeError: if called after training has finished.
        """
        if self._finished:
            raise RuntimeError("Training has already finished.")
        result = self.executor.execute(lambda w: w.step())[0]
        if RESULT_DUPLICATE in result:
            self._finished = True
        return result

    def save_checkpoint(self, checkpoint_dir: str) -> str:
        """Serialize one worker's state into `checkpoint_dir`; return its path."""
        # TODO: optimize if colocated
        save_obj = self.executor.execute_single(lambda w: w.save_to_object())
        checkpoint_path = TrainableUtil.create_from_pickle(
            save_obj, checkpoint_dir)
        return checkpoint_path

    def load_checkpoint(self, checkpoint_dir: str):
        """Restore all workers from a checkpoint directory via the object store."""
        checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
        x_id = ray.put(checkpoint_obj)
        return self.executor.execute(lambda w: w.restore_from_object(x_id))

    def stop(self):
        """Stop all workers and shut down the executor."""
        self.executor.execute(lambda w: w.stop())
        self.executor.shutdown()