def __init__(self,
             *,
             model_creator,
             optimizer_creator,
             loss_creator=None,
             metrics=None,
             scheduler_creator=None,
             training_operator_cls=TrainingOperator,
             initialization_hook=None,
             config=None,
             scheduler_step_freq="batch",
             use_tqdm=False,
             backend="torch_distributed",
             workers_per_node=1):
    # TODO: remove ray_ctx so this can run on workers.
    ray_ctx = RayContext.get()
    # A torch model is itself callable, so require plain functions explicitly
    # rather than accepting any callable.
    if not (isinstance(model_creator, types.FunctionType)
            and isinstance(optimizer_creator, types.FunctionType)):
        raise ValueError(
            "Must provide a function for both model_creator and optimizer_creator")

    self.model_creator = model_creator
    self.optimizer_creator = optimizer_creator
    self.loss_creator = loss_creator
    self.scheduler_creator = scheduler_creator
    self.training_operator_cls = training_operator_cls
    self.scheduler_step_freq = scheduler_step_freq
    self.use_tqdm = use_tqdm

    if not training_operator_cls and not loss_creator:
        raise ValueError("If a loss_creator is not provided, you must "
                         "provide a custom training operator.")

    self.initialization_hook = initialization_hook
    self.config = {} if config is None else config
    worker_config = self.config.copy()
    params = dict(
        model_creator=self.model_creator,
        optimizer_creator=self.optimizer_creator,
        loss_creator=self.loss_creator,
        scheduler_creator=self.scheduler_creator,
        training_operator_cls=self.training_operator_cls,
        scheduler_step_freq=self.scheduler_step_freq,
        use_tqdm=self.use_tqdm,
        config=worker_config,
        metrics=metrics)

    if backend == "torch_distributed":
        # One TorchRunner actor per worker slot, each pinned to an equal
        # share of the node's CPU cores.
        cores_per_node = ray_ctx.ray_node_cpu_cores // workers_per_node
        num_workers = ray_ctx.num_ray_nodes * workers_per_node
        RemoteRunner = ray.remote(num_cpus=cores_per_node)(TorchRunner)
        self.remote_workers = [
            RemoteRunner.remote(**params) for _ in range(num_workers)
        ]
        ray.get([
            worker.setup.remote(cores_per_node)
            for worker in self.remote_workers
        ])

        # Worker 0 hosts the rendezvous address for the process group.
        head_worker = self.remote_workers[0]
        address = ray.get(head_worker.setup_address.remote())
        logger.info(f"initializing pytorch process group on {address}")

        ray.get([
            worker.setup_torch_distribute.remote(address, i, num_workers)
            for i, worker in enumerate(self.remote_workers)
        ])
    elif backend == "horovod":
        from zoo.orca.learn.horovod.horovod_ray_runner import HorovodRayRunner
        self.horovod_runner = HorovodRayRunner(
            ray_ctx,
            worker_cls=TorchRunner,
            worker_param=params,
            workers_per_node=workers_per_node)
        self.remote_workers = self.horovod_runner.remote_workers
        cores_per_node = self.horovod_runner.cores_per_node
        ray.get([
            worker.setup.remote(cores_per_node)
            for worker in self.remote_workers
        ])
        ray.get([
            worker.setup_horovod.remote()
            for worker in self.remote_workers
        ])
    else:
        raise ValueError('Only "torch_distributed" and "horovod" are supported '
                         "values of backend, but got {}".format(backend))
    self.num_workers = len(self.remote_workers)
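# --- Usage sketch (illustrative, not from the source) -----------------------
# A minimal example of the creator-function contract enforced above. The
# enclosing class name `PyTorchRayEstimator` is an assumption (the excerpt
# shows only __init__), and the creator signatures follow the Ray
# TorchTrainer convention that this code mirrors.
import torch
import torch.nn as nn

def model_creator(config):
    # Must be a plain function, not merely a callable: see the
    # types.FunctionType check above.
    return nn.Linear(config.get("in_features", 10), 1)

def optimizer_creator(model, config):
    return torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-2))

def loss_creator(config):
    return nn.MSELoss()

estimator = PyTorchRayEstimator(  # hypothetical class name
    model_creator=model_creator,
    optimizer_creator=optimizer_creator,
    loss_creator=loss_creator,
    config={"in_features": 10, "lr": 1e-2},
    backend="torch_distributed",
    workers_per_node=2)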
def __init__(self,
             model_creator,
             compile_args_creator=None,
             config=None,
             verbose=False,
             backend="tf2",
             workers_per_node=1):
    self.model_creator = model_creator
    self.compile_args_creator = compile_args_creator
    self.config = {} if config is None else config
    self.verbose = verbose

    ray_ctx = RayContext.get()
    if "batch_size" in self.config:
        raise ValueError(
            "Please do not specify batch_size in config. Pass batch_size to "
            "the fit/evaluate methods of the estimator instead.")

    # Default to one inter-op thread and split the node's cores evenly
    # across workers for intra-op parallelism.
    if "inter_op_parallelism" not in self.config:
        self.config["inter_op_parallelism"] = 1
    if "intra_op_parallelism" not in self.config:
        self.config["intra_op_parallelism"] = \
            ray_ctx.ray_node_cpu_cores // workers_per_node

    if backend == "horovod":
        assert compile_args_creator is not None, \
            "compile_args_creator should not be None when backend is set to horovod"

    params = {
        "model_creator": model_creator,
        "compile_args_creator": compile_args_creator,
        "config": self.config,
        "verbose": self.verbose,
    }

    if backend == "tf2":
        cores_per_node = ray_ctx.ray_node_cpu_cores // workers_per_node
        num_workers = ray_ctx.num_ray_nodes * workers_per_node
        worker_class = ray.remote(num_cpus=cores_per_node)(TFRunner)
        self.remote_workers = [
            worker_class.remote(**params) for _ in range(num_workers)
        ]
        # Collect each worker's address so every worker can build the full
        # cluster spec.
        ips = ray.get([
            worker.get_node_ip.remote() for worker in self.remote_workers
        ])
        ports = ray.get([
            worker.find_free_port.remote() for worker in self.remote_workers
        ])
        urls = ["{ip}:{port}".format(ip=ips[i], port=ports[i])
                for i in range(len(self.remote_workers))]

        # Block on the setup tasks so errors surface here on failure.
        ray.get([
            worker.setup_distributed.remote(urls, i, len(self.remote_workers))
            for i, worker in enumerate(self.remote_workers)
        ])
    elif backend == "horovod":
        # HorovodRayRunner.run must be called first to set up the horovod
        # environment on the workers.
        from zoo.orca.learn.horovod.horovod_ray_runner import HorovodRayRunner
        horovod_runner = HorovodRayRunner(
            ray_ctx,
            worker_cls=TFRunner,
            worker_param=params,
            workers_per_node=workers_per_node)
        horovod_runner.run(lambda: print("worker initialized"))
        self.remote_workers = horovod_runner.remote_workers
        ray.get([
            worker.setup_horovod.remote()
            for worker in self.remote_workers
        ])
    else:
        raise ValueError('Only "tf2" and "horovod" are legal '
                         "values of backend, but got {}".format(backend))
    self.num_workers = len(self.remote_workers)
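# --- Creator-function sketch (illustrative, not from the source) ------------
# What the two creators the constructor above expects might look like. The
# exact contract is an assumption inferred from the parameter names and the
# horovod assert: model_creator returns a compiled Keras model for the "tf2"
# backend, while compile_args_creator supplies the compile arguments
# separately (presumably so a horovod worker can wrap the optimizer itself).
# The "lr" config key is illustrative only.
import tensorflow as tf

def model_creator(config):
    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(1, input_shape=(10,))])
    model.compile(
        optimizer=tf.keras.optimizers.SGD(config.get("lr", 1e-2)),
        loss="mse")
    return model

def compile_args_creator(config):
    # Required when backend="horovod" (see the assert above).
    return dict(
        optimizer=tf.keras.optimizers.SGD(config.get("lr", 1e-2)),
        loss="mse")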
def __init__(self,
             model_creator,
             compile_args_creator=None,
             config=None,
             verbose=False,
             backend="tf2",
             workers_per_node=1):
    """Sets up the TensorFlow trainer.

    Args:
        model_creator (dict -> Model): Takes in the `config` dict and
            returns a compiled TF model.
        compile_args_creator (dict -> dict): Takes in the `config` dict and
            returns the arguments for `model.compile`. Required when
            `backend` is "horovod".
        config (dict): Configuration passed to `model_creator` and
            `compile_args_creator`.
        verbose (bool): Prints output of one model if true.
        backend (str): Distributed backend, either "tf2" or "horovod".
        workers_per_node (int): Number of workers launched on each Ray node.
    """
    self.model_creator = model_creator
    self.compile_args_creator = compile_args_creator
    self.config = {} if config is None else config
    self.verbose = verbose

    ray_ctx = RayContext.get()
    # Default to one inter-op thread and split the node's cores evenly
    # across workers for intra-op parallelism.
    if "inter_op_parallelism" not in self.config:
        self.config["inter_op_parallelism"] = 1
    if "intra_op_parallelism" not in self.config:
        self.config["intra_op_parallelism"] = \
            ray_ctx.ray_node_cpu_cores // workers_per_node

    if backend == "horovod":
        assert compile_args_creator is not None, \
            "compile_args_creator should not be None when backend is set to horovod"

    params = {
        "model_creator": model_creator,
        "compile_args_creator": compile_args_creator,
        "config": self.config,
        "verbose": self.verbose,
    }

    if backend == "tf2":
        cores_per_node = ray_ctx.ray_node_cpu_cores // workers_per_node
        num_workers = ray_ctx.num_ray_nodes * workers_per_node
        worker_class = ray.remote(num_cpus=cores_per_node)(TFRunner)
        self.remote_workers = [
            worker_class.remote(**params) for _ in range(num_workers)
        ]
        # Collect each worker's address so every worker can build the full
        # cluster spec.
        ips = ray.get([
            worker.get_node_ip.remote() for worker in self.remote_workers
        ])
        ports = ray.get([
            worker.find_free_port.remote() for worker in self.remote_workers
        ])
        urls = ["{ip}:{port}".format(ip=ips[i], port=ports[i])
                for i in range(len(self.remote_workers))]

        # Block on the setup tasks so errors surface here on failure.
        ray.get([
            worker.setup_distributed.remote(urls, i, len(self.remote_workers))
            for i, worker in enumerate(self.remote_workers)
        ])
    elif backend == "horovod":
        # HorovodRayRunner.run must be called first to set up the horovod
        # environment on the workers.
        from zoo.orca.learn.horovod.horovod_ray_runner import HorovodRayRunner
        horovod_runner = HorovodRayRunner(
            ray_ctx,
            worker_cls=TFRunner,
            worker_param=params,
            workers_per_node=workers_per_node)
        horovod_runner.run(lambda: print("worker initialized"))
        self.remote_workers = horovod_runner.remote_workers
        ray.get([
            worker.setup_horovod.remote()
            for worker in self.remote_workers
        ])
    else:
        raise ValueError('Only "tf2" and "horovod" are legal '
                         "values of backend, but got {}".format(backend))
    self.num_workers = len(self.remote_workers)
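# --- Usage sketch (illustrative, not from the source) -----------------------
# Constructing the estimator above; the enclosing class name `Estimator` is
# hypothetical, and model_creator/compile_args_creator are the sketches
# defined earlier. Note the worker math applied by __init__: with
# workers_per_node=2 on a 4-node Ray cluster it launches 8 TFRunner actors,
# each pinned to ray_node_cpu_cores // 2 CPUs.
estimator = Estimator(  # hypothetical class name
    model_creator=model_creator,
    compile_args_creator=compile_args_creator,
    config={"lr": 1e-2},
    verbose=True,
    backend="tf2",
    workers_per_node=2)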