def test_scaling_config_validate_config_prohibited_class():
    """Updating a field outside the allow-list must raise ``ValueError``."""
    # Only ``trainer_resources`` may change, yet ``num_workers`` is set
    # explicitly, so validation is expected to reject the dataclass.
    allowed = ["trainer_resources"]
    with pytest.raises(ValueError):
        ensure_only_allowed_dataclass_keys_updated(
            ScalingConfigDataClass(num_workers=2),
            allowed,
        )
def test_scaling_config_validate_config_bad_allowed_keys():
    """An allow-list naming a key the dataclass lacks must raise ``ValueError``."""
    # "BAD_KEY" is expected not to be a field of ScalingConfigDataClass; the
    # error message should name the offending key and say it is missing.
    with pytest.raises(ValueError) as raised:
        ensure_only_allowed_dataclass_keys_updated(
            ScalingConfigDataClass(num_workers=2),
            ["BAD_KEY"],
        )
    message = str(raised.value)
    assert "BAD_KEY" in message
    assert "are not present in" in message
def _validate_and_get_scaling_config_data_class(
    cls, dataclass_or_dict: Union[ScalingConfigDataClass, Dict[str, Any]]
) -> ScalingConfigDataClass:
    """Return a scaling config dataclass after validating its updated keys.

    Args:
        dataclass_or_dict: Either a ``ScalingConfigDataClass`` instance or a
            plain dict of its constructor arguments. A dict is converted into
            a dataclass first; an instance is used as-is.

    Returns:
        The validated ``ScalingConfigDataClass`` instance.

    Raises:
        Whatever ``ensure_only_allowed_dataclass_keys_updated`` raises when a
        key outside ``cls._scaling_config_allowed_keys`` was updated
        (presumably ``ValueError`` — confirm against that helper).
    """
    config = (
        ScalingConfigDataClass(**dataclass_or_dict)
        if isinstance(dataclass_or_dict, dict)
        else dataclass_or_dict
    )
    ensure_only_allowed_dataclass_keys_updated(
        dataclass=config,
        allowed_keys=cls._scaling_config_allowed_keys,
    )
    return config
def _convert_scaling_config_to_ray_params(
    scaling_config: ScalingConfig,
    ray_params_cls: Type["xgboost_ray.RayParams"],
    default_ray_params: Optional[Dict[str, Any]] = None,
) -> "xgboost_ray.RayParams":
    """Build a ``ray_params_cls`` instance from a scaling config.

    Args:
        scaling_config: Mapping of ``ScalingConfigDataClass`` constructor
            arguments.
        ray_params_cls: Class to instantiate (e.g. ``xgboost_ray.RayParams``).
        default_ray_params: Optional extra keyword arguments forwarded to
            ``ray_params_cls``. Any of ``num_actors``, ``cpus_per_actor``,
            ``gpus_per_actor`` or ``resources_per_actor`` present here is
            overridden by the value derived from ``scaling_config``. The
            caller's dict is never mutated.

    Returns:
        The constructed ``ray_params_cls`` instance.
    """
    scaling_config_dataclass = ScalingConfigDataClass(**scaling_config)
    resources_per_worker = scaling_config_dataclass.additional_resources_per_worker
    num_workers = scaling_config_dataclass.num_workers
    cpus_per_worker = scaling_config_dataclass.num_cpus_per_worker
    gpus_per_worker = scaling_config_dataclass.num_gpus_per_worker

    # Merge everything into one copied dict instead of passing
    # ``**default_ray_params`` next to explicit keywords. The latter raises
    # ``TypeError: got multiple values for keyword argument`` whenever the
    # defaults contain one of the scaling-derived keys. Copying keeps the
    # caller's dict untouched.
    ray_params_kwargs = dict(default_ray_params or {})
    ray_params_kwargs.update(
        # NOTE(review): int() truncates fractional CPU/GPU counts (e.g. 0.5 -> 0);
        # kept as-is to preserve existing behavior.
        num_actors=int(num_workers),
        cpus_per_actor=int(cpus_per_worker),
        gpus_per_actor=int(gpus_per_worker),
        resources_per_actor=resources_per_worker,
    )
    return ray_params_cls(**ray_params_kwargs)
def test_scaling_config_validate_config_valid_class():
    """Updating only allow-listed fields must pass validation without error."""
    # ``num_workers`` is both the updated field and the allowed key, so the
    # call below is expected to return without raising.
    ensure_only_allowed_dataclass_keys_updated(
        ScalingConfigDataClass(num_workers=2),
        ["num_workers"],
    )