def test_scaling_config_validate_config_prohibited_class():
    # Check for prohibited keys
    scaling_config = {"num_workers": 2}
    with pytest.raises(ValueError):
        ensure_only_allowed_dataclass_keys_updated(
            ScalingConfigDataClass(**scaling_config),
            ["trainer_resources"],
        )
def _validate_and_get_scaling_config_data_class(
    cls, dataclass_or_dict: Union[ScalingConfigDataClass, Dict[str, Any]]
) -> ScalingConfigDataClass:
    """Return scaling config dataclass after validating updated keys."""
    if isinstance(dataclass_or_dict, dict):
        ensure_only_allowed_dict_keys_set(
            dataclass_or_dict, cls._scaling_config_allowed_keys
        )
        scaling_config_dataclass = ScalingConfigDataClass(**dataclass_or_dict)
        return scaling_config_dataclass

    ensure_only_allowed_dataclass_keys_updated(
        dataclass=dataclass_or_dict,
        allowed_keys=cls._scaling_config_allowed_keys,
    )
    return dataclass_or_dict
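# Illustrative sketch (hypothetical helper, not part of the trainer API): shows the
# two validation paths above in isolation, assuming "num_workers" is an allowed key
# and that ScalingConfigDataClass accepts it, as the tests in this section do.
def _example_scaling_config_validation():
    allowed_keys = ["num_workers"]

    # Dict path: the raw keys are validated before the dataclass is constructed.
    user_dict = {"num_workers": 2}
    ensure_only_allowed_dict_keys_set(user_dict, allowed_keys)
    from_dict = ScalingConfigDataClass(**user_dict)

    # Dataclass path: only fields the user changed from their defaults are validated,
    # and the dataclass is passed through unchanged.
    user_dataclass = ScalingConfigDataClass(num_workers=2)
    ensure_only_allowed_dataclass_keys_updated(
        dataclass=user_dataclass, allowed_keys=allowed_keys
    )
    return from_dict, user_dataclass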
def _convert_scaling_config_to_ray_params(
    scaling_config: ScalingConfig,
    ray_params_cls: Type["xgboost_ray.RayParams"],
    default_ray_params: Optional[Dict[str, Any]] = None,
) -> "xgboost_ray.RayParams":
    # Map Train's per-worker scaling config onto xgboost_ray actor parameters.
    default_ray_params = default_ray_params or {}
    scaling_config_dataclass = ScalingConfigDataClass(**scaling_config)
    resources_per_worker = scaling_config_dataclass.additional_resources_per_worker
    num_workers = scaling_config_dataclass.num_workers
    cpus_per_worker = scaling_config_dataclass.num_cpus_per_worker
    gpus_per_worker = scaling_config_dataclass.num_gpus_per_worker

    ray_params = ray_params_cls(
        num_actors=int(num_workers),
        cpus_per_actor=int(cpus_per_worker),
        gpus_per_actor=int(gpus_per_worker),
        resources_per_actor=resources_per_worker,
        **default_ray_params,
    )
    return ray_params
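# Illustrative sketch (hypothetical helper): assumes xgboost_ray is installed and that
# {"num_workers": 2} is a valid ScalingConfig dict, as in the tests in this section.
# Per-actor CPU/GPU counts fall back to the ScalingConfigDataClass defaults here.
def _example_scaling_config_to_ray_params():
    import xgboost_ray

    ray_params = _convert_scaling_config_to_ray_params(
        scaling_config={"num_workers": 2},
        ray_params_cls=xgboost_ray.RayParams,
    )
    # Each Train worker maps to one xgboost_ray actor.
    assert ray_params.num_actors == 2
    return ray_params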
def training_loop(self) -> None:
    scaling_config_dataclass = ScalingConfigDataClass(**self.scaling_config)

    train_loop_per_worker = construct_train_func(
        self.train_loop_per_worker,
        self.train_loop_config,
        fn_arg_name="train_loop_per_worker",
    )

    additional_resources_per_worker = (
        scaling_config_dataclass.additional_resources_per_worker
    )

    backend_executor = BackendExecutor(
        backend_config=self.backend_config,
        num_workers=scaling_config_dataclass.num_workers,
        num_cpus_per_worker=scaling_config_dataclass.num_cpus_per_worker,
        num_gpus_per_worker=scaling_config_dataclass.num_gpus_per_worker,
        additional_resources_per_worker=additional_resources_per_worker,
        max_retries=0,
    )

    checkpoint_manager = _DataParallelCheckpointManager()
    checkpoint_manager.on_init(preprocessor=self.preprocessor)

    # Start the remote actors.
    backend_executor.start(initialization_hook=None)

    if self.resume_from_checkpoint:
        resume_checkpoint_dict = self.resume_from_checkpoint.to_dict()
    else:
        resume_checkpoint_dict = None

    # Tell Ray Train to only shard the train dataset and not the other datasets.
    # This is purely an implementation detail and users do not need to know about
    # this.
    # TODO(amog): Refactor this to remove hack and make this more modular.
    # TrainingIterator should accept a generic custom_ingest_func that contains
    # the logic for how to split the Datasets.
    updated_dataset_dict = {}
    for key, value in self.datasets.items():
        if key == TRAIN_DATASET_KEY:
            updated_dataset_dict[key] = value
        else:
            # Ray Train will strip out the added string before exposing to users.
            updated_dataset_dict[key + "_NO-SHARD"] = value

    # TODO(amog): Have TrainingIterator also accept a checkpoint ObjectRef instead
    # of just a Dict.
    training_iterator = TrainingIterator(
        backend_executor=backend_executor,
        backend_config=self.backend_config,
        train_func=train_loop_per_worker,
        dataset=updated_dataset_dict if len(updated_dataset_dict) > 0 else None,
        checkpoint_manager=checkpoint_manager,
        checkpoint=resume_checkpoint_dict,
        checkpoint_strategy=None,
    )

    for results in training_iterator:
        # TODO(ml-team): add ability to report results from multiple workers.
        first_worker_results = results[0]
        tune.report(**first_worker_results)

    # Shutdown workers.
    backend_executor.shutdown()
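# Illustrative sketch of the "_NO-SHARD" tagging done in training_loop (hypothetical
# helper; assumes TRAIN_DATASET_KEY == "train"): only the train dataset keeps its key
# and gets sharded across workers, while every other dataset is tagged so Ray Train
# skips sharding it and strips the tag before exposing it to user code.
def _example_no_shard_tagging():
    datasets = {"train": "train_ds", "validation": "valid_ds"}
    updated = {
        (key if key == TRAIN_DATASET_KEY else key + "_NO-SHARD"): value
        for key, value in datasets.items()
    }
    # updated == {"train": "train_ds", "validation_NO-SHARD": "valid_ds"}
    return updated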
def default_resource_request(cls, config):
    # Prefer the scaling config from the Tune trial config; otherwise fall back to
    # the `scaling_config` captured from the enclosing scope.
    updated_scaling_config = config.get("scaling_config", scaling_config)
    scaling_config_dataclass = ScalingConfigDataClass(**updated_scaling_config)
    return scaling_config_dataclass.as_placement_group_factory()
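# Illustrative sketch of the override behavior above (hypothetical helper; assumes
# {"num_workers": 2} is a valid ScalingConfig dict): a "scaling_config" entry in the
# Tune trial config takes precedence, otherwise the captured scaling config is used
# to build the placement group request.
def _example_default_resource_request(trial_config):
    captured_scaling_config = {"num_workers": 2}  # stands in for the closure variable
    updated = trial_config.get("scaling_config", captured_scaling_config)
    return ScalingConfigDataClass(**updated).as_placement_group_factory()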
def test_scaling_config_validate_config_valid_class():
    # "num_workers" is an allowed key, so validation should pass without raising.
    scaling_config = {"num_workers": 2}
    ensure_only_allowed_dataclass_keys_updated(
        ScalingConfigDataClass(**scaling_config), ["num_workers"]
    )