Ejemplo n.º 1
0
def test_scaling_config_validate_config_prohibited_class():
    """Expect ValueError when an updated key is not in the allowed list.

    The dataclass is built with ``num_workers`` set, while the only allowed
    key is ``trainer_resources``.
    """
    allowed_keys = ["trainer_resources"]
    with pytest.raises(ValueError):
        # Construction stays inside the ``raises`` context, as before.
        config_dc = ScalingConfigDataClass(num_workers=2)
        ensure_only_allowed_dataclass_keys_updated(config_dc, allowed_keys)
Ejemplo n.º 2
0
def test_scaling_config_validate_config_bad_allowed_keys():
    """Expect ValueError when an allowed key does not exist on the dataclass.

    The error message must name the offending key and contain the phrase
    "are not present in".
    """
    unknown_keys = ["BAD_KEY"]
    with pytest.raises(ValueError) as exc_info:
        config_dc = ScalingConfigDataClass(num_workers=2)
        ensure_only_allowed_dataclass_keys_updated(config_dc, unknown_keys)

    message = str(exc_info.value)
    assert "BAD_KEY" in message
    assert "are not present in" in message
Ejemplo n.º 3
0
    def _validate_and_get_scaling_config_data_class(
        cls, dataclass_or_dict: Union[ScalingConfigDataClass, Dict[str, Any]]
    ) -> ScalingConfigDataClass:
        """Normalize to a ScalingConfigDataClass and validate its updated keys.

        A plain dict is unpacked into ``ScalingConfigDataClass``; any other
        value is used as-is. The result is checked against
        ``cls._scaling_config_allowed_keys`` via
        ``ensure_only_allowed_dataclass_keys_updated`` before being returned.
        """
        scaling_dc = (
            ScalingConfigDataClass(**dataclass_or_dict)
            if isinstance(dataclass_or_dict, dict)
            else dataclass_or_dict
        )
        ensure_only_allowed_dataclass_keys_updated(
            dataclass=scaling_dc,
            allowed_keys=cls._scaling_config_allowed_keys,
        )
        return scaling_dc
Ejemplo n.º 4
0
def _convert_scaling_config_to_ray_params(
    scaling_config: ScalingConfig,
    ray_params_cls: Type["xgboost_ray.RayParams"],
    default_ray_params: Optional[Dict[str, Any]] = None,
) -> "xgboost_ray.RayParams":
    """Build a ``ray_params_cls`` instance from a scaling config.

    Args:
        scaling_config: Mapping of scaling-config fields, unpacked into
            ``ScalingConfigDataClass``.
        ray_params_cls: Class to instantiate (e.g. ``xgboost_ray.RayParams``).
        default_ray_params: Extra keyword arguments for ``ray_params_cls``.
            Any value derived from ``scaling_config`` (``num_actors``,
            ``cpus_per_actor``, ``gpus_per_actor``, ``resources_per_actor``)
            takes precedence over an entry with the same key here. The dict
            passed in is not mutated.

    Returns:
        The constructed ``ray_params_cls`` instance.
    """
    scaling_config_dataclass = ScalingConfigDataClass(**scaling_config)

    # Start from a copy of the defaults, then overlay the scaling-config
    # values. Passing the defaults via ``**`` alongside explicit keywords
    # (as before) raised TypeError ("got multiple values for keyword
    # argument") whenever a default shared a key with the scaling config.
    ray_params_kwargs: Dict[str, Any] = dict(default_ray_params or {})
    ray_params_kwargs.update(
        num_actors=int(scaling_config_dataclass.num_workers),
        # NOTE(review): int() truncates, so a fractional per-worker
        # request (e.g. 0.5) becomes 0 — confirm this is intended.
        cpus_per_actor=int(scaling_config_dataclass.num_cpus_per_worker),
        gpus_per_actor=int(scaling_config_dataclass.num_gpus_per_worker),
        resources_per_actor=(
            scaling_config_dataclass.additional_resources_per_worker
        ),
    )
    return ray_params_cls(**ray_params_kwargs)
Ejemplo n.º 5
0
def test_scaling_config_validate_config_valid_class():
    """Updating only a key that is in the allowed list must not raise."""
    scaling_config = dict(num_workers=2)
    allowed_keys = ["num_workers"]
    ensure_only_allowed_dataclass_keys_updated(
        ScalingConfigDataClass(**scaling_config),
        allowed_keys,
    )