class RayPredictor(BasePredictor):
    """Prediction backend that fans work out over a Horovod-on-Ray executor.

    A pool of ``RemotePredictor`` workers is started at construction time;
    each request ships the model through ``RayRemoteModel`` and re-loads it
    on the worker side before running.
    """

    def __init__(self, horovod_kwargs, predictor_kwargs):
        # TODO ray: investigate using Dask for prediction instead of Horovod
        settings = RayExecutor.create_settings(timeout_s=30)
        merged_kwargs = {**get_horovod_kwargs(), **horovod_kwargs}
        self.executor = RayExecutor(settings, **merged_kwargs)
        self.executor.start(
            executable_cls=RemotePredictor,
            executable_kwargs=predictor_kwargs,
        )

    def batch_predict(self, model, *args, **kwargs):
        remote_model = RayRemoteModel(model)
        outputs = self.executor.execute(
            lambda predictor: predictor.batch_predict(
                remote_model.load(), *args, **kwargs))
        # `execute` returns one result per worker; keep the first.
        return outputs[0]

    def batch_evaluation(self, model, *args, **kwargs):
        remote_model = RayRemoteModel(model)
        outputs = self.executor.execute(
            lambda predictor: predictor.batch_evaluation(
                remote_model.load(), *args, **kwargs))
        # One result per worker; the first copy is sufficient.
        return outputs[0]

    def batch_collect_activations(self, model, *args, **kwargs):
        remote_model = RayRemoteModel(model)
        # Activations are collected on a single worker only.
        return self.executor.execute_single(
            lambda predictor: predictor.batch_collect_activations(
                remote_model.load(), *args, **kwargs))

    def shutdown(self):
        self.executor.shutdown()
class RayLegacyTrainer(BaseTrainer):
    """Legacy distributed trainer driving ``RemoteTrainer`` workers via a
    Horovod-on-Ray executor, with dataset pipelines sharded per worker."""

    def __init__(self, horovod_kwargs, executable_kwargs):
        # TODO ray: make this more configurable by allowing YAML overrides of timeout_s, etc.
        settings = RayExecutor.create_settings(timeout_s=30)
        merged_kwargs = {**get_horovod_kwargs(), **horovod_kwargs}
        self.executor = RayExecutor(settings, **merged_kwargs)
        self.executor.start(
            executable_cls=RemoteTrainer,
            executable_kwargs=executable_kwargs,
        )

    def train(self, model, training_set, validation_set=None, test_set=None, **kwargs):
        workers = self.executor.driver.workers
        num_shards = len(workers)

        # Split each dataset pipeline into one shard per worker, hinting Ray
        # to co-locate shards with the worker processes. Only the training
        # split is forced to be equal-sized across workers.
        train_shards = training_set.pipeline().split(
            n=num_shards, locality_hints=workers, equal=True)

        val_shards = None
        if validation_set:
            val_shards = validation_set.pipeline(shuffle=False).split(
                n=num_shards, locality_hints=workers)

        test_shards = None
        if test_set:
            test_shards = test_set.pipeline(shuffle=False).split(
                n=num_shards, locality_hints=workers)

        return self.executor.execute(lambda trainer: legacy_train_fn(
            trainer,
            model,
            training_set.training_set_metadata,
            training_set.features,
            train_shards,
            val_shards,
            test_shards,
            **kwargs,
        ))

    def train_online(self, model, *args, **kwargs):
        outputs = self.executor.execute(
            lambda trainer: trainer.train_online(model, *args, **kwargs))
        # One result per worker; return the first.
        return outputs[0]

    @property
    def validation_field(self):
        return self.executor.execute_single(
            lambda trainer: trainer.validation_field)

    @property
    def validation_metric(self):
        return self.executor.execute_single(
            lambda trainer: trainer.validation_metric)

    def shutdown(self):
        self.executor.shutdown()
class RayTrainer(BaseTrainer):
    """Distributed trainer backed by ``RayRemoteTrainer`` workers on a
    Horovod-on-Ray executor. Trained weights are copied back into the
    caller's local model after each run."""

    def __init__(self, horovod_kwargs, trainer_kwargs):
        # TODO ray: make this more configurable by allowing YAML overrides of timeout_s, etc.
        settings = RayExecutor.create_settings(timeout_s=30)
        merged_kwargs = {**get_horovod_kwargs(), **horovod_kwargs}
        self.executor = RayExecutor(settings, **merged_kwargs)
        self.executor.start(
            executable_cls=RayRemoteTrainer,
            executable_kwargs=trainer_kwargs,
        )

    def train(self, model, *args, **kwargs):
        remote_model = RayRemoteModel(model)
        outputs = self.executor.execute(
            lambda trainer: trainer.train(
                remote_model.load(), *args, **kwargs))
        # First worker's result is (weights, *training stats).
        weights, *stats = outputs[0]
        model.set_weights(weights)
        return (model, *stats)

    def train_online(self, model, *args, **kwargs):
        remote_model = RayRemoteModel(model)
        outputs = self.executor.execute(
            lambda trainer: trainer.train_online(
                remote_model.load(), *args, **kwargs))
        # First worker's result carries the updated weights.
        weights = outputs[0]
        model.set_weights(weights)
        return model

    @property
    def validation_field(self):
        return self.executor.execute_single(
            lambda trainer: trainer.validation_field)

    @property
    def validation_metric(self):
        return self.executor.execute_single(
            lambda trainer: trainer.validation_metric)

    def shutdown(self):
        self.executor.shutdown()
class _HorovodTrainable(tune.Trainable):
    """Abstract Trainable class for Horovod.

    Subclasses configure the class attributes below; ``setup`` then launches
    a Horovod-on-Ray worker pool that runs ``_function`` on every worker,
    and ``step``/``save_checkpoint``/``load_checkpoint`` delegate the Tune
    Trainable lifecycle to those workers.
    """
    # Callable function for training.
    _function = None
    # Number of hosts (nodes) to allocate per trial
    _num_hosts: int = 1
    # Number of workers (slots) to place on each host.
    _num_slots: int = 1
    # Number of CPU resources to reserve for each worker.
    _num_cpus_per_slot: int = 1
    # Whether to reserve and pass GPU resources through.
    _use_gpu: bool = False
    # Whether the wrapped training function has completed.
    _finished: bool = False
    # Horovod settings (SSH credentials used by the executor to reach hosts).
    _ssh_str: str = None
    _ssh_identity_file: str = None
    _timeout_s: int = 30

    @property
    def num_workers(self):
        # Total worker count = hosts x slots-per-host.
        return self._num_hosts * self._num_slots

    def setup(self, config: Dict):
        """Start the Horovod worker pool and launch ``_function`` on it."""
        trainable = wrap_function(self.__class__._function)
        # We use a filelock here to ensure that the file-writing
        # process is safe across different trainables.
        if self._ssh_identity_file:
            with FileLock(self._ssh_identity_file + ".lock"):
                settings = RayExecutor.create_settings(self._timeout_s,
                                                       self._ssh_identity_file,
                                                       self._ssh_str)
        else:
            # No identity file: same call, but no lock is needed.
            settings = RayExecutor.create_settings(self._timeout_s,
                                                   self._ssh_identity_file,
                                                   self._ssh_str)
        self.executor = RayExecutor(settings,
                                    cpus_per_slot=self._num_cpus_per_slot,
                                    use_gpu=self._use_gpu,
                                    num_hosts=self._num_hosts,
                                    num_slots=self._num_slots)
        # We can't put `self` in the lambda closure, so we
        # resolve the variable ahead of time.
        logdir_ = str(self.logdir)
        # Starts the workers as specified by the resources above.
        self.executor.start(executable_cls=trainable,
                            executable_kwargs={
                                "config": config,
                                "logger_creator": lambda cfg: logger_creator(cfg, logdir_)
                            })

    def step(self) -> Dict:
        """Run one training step on all workers; return worker 0's result.

        Raises:
            RuntimeError: if called again after training finished.
        """
        if self._finished:
            raise RuntimeError("Training has already finished.")
        result = self.executor.execute(lambda w: w.step())[0]
        # RESULT_DUPLICATE in the result dict signals the wrapped function
        # has completed; latch _finished so further steps raise.
        if RESULT_DUPLICATE in result:
            self._finished = True
        return result

    def save_checkpoint(self, checkpoint_dir: str) -> str:
        """Serialize worker state to ``checkpoint_dir``; return its path."""
        # TODO: optimize if colocated
        save_obj = self.executor.execute_single(lambda w: w.save_to_object())
        checkpoint_path = TrainableUtil.create_from_pickle(
            save_obj, checkpoint_dir)
        return checkpoint_path

    def load_checkpoint(self, checkpoint_dir: str):
        """Restore all workers from a checkpoint directory."""
        checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
        # Put the checkpoint into the object store once so every worker
        # restores from the same shared object.
        x_id = ray.put(checkpoint_obj)
        return self.executor.execute(lambda w: w.restore_from_object(x_id))

    def stop(self):
        """Stop all workers, then tear down the executor."""
        self.executor.execute(lambda w: w.stop())
        self.executor.shutdown()