def test_validate_accelerator_type(accelerator_type: str, expected: bool):
    """Check ``utils.validate_accelerator_type`` for a single input case.

    Args:
        accelerator_type: Accelerator type name handed to the validator.
        expected: Truthy when the name should be accepted; falsy when the
            validator should reject it with a ValueError.
    """
    if expected:
        # An accepted name must yield a truthy result.
        assert utils.validate_accelerator_type(accelerator_type)
        return

    # A rejected name must raise ValueError carrying the expected message.
    with pytest.raises(ValueError) as exc_info:
        utils.validate_accelerator_type(accelerator_type)
    assert exc_info.match(regexp=r"Given accelerator_type")
def _get_accelerator_type(self) -> Optional[str]:
    """Validate ``self.accelerator_type`` and return it when it names a real accelerator.

    Returns:
        The accelerator type name, or None when the name resolves to
        ``AcceleratorType.ACCELERATOR_TYPE_UNSPECIFIED`` (i.e. no accelerator).

    Raises:
        ValueError: Propagated from ``utils.validate_accelerator_type`` when
            the accelerator type is invalid.
    """
    requested = self.accelerator_type

    # Validation runs first so an invalid name fails before the enum lookup.
    utils.validate_accelerator_type(requested)

    accelerator_types = gca_accelerator_type_compat.AcceleratorType
    selected = getattr(accelerator_types, requested)

    # The "unspecified" member stands for "no accelerator": report it as None.
    if selected == accelerator_types.ACCELERATOR_TYPE_UNSPECIFIED:
        return None
    return requested