def binary_op_dtype_propagation(*, strict: bool) -> RootValidatorType:
    """Build a root validator that sets ``dtype`` for a binary operation.

    The returned validator passes ``values["left"]`` and ``values["right"]``
    to ``verify_and_get_common_dtype``. When that call returns a truthy dtype,
    ``values["dtype"]`` is set according to the kind of ``values["op"]``:

    - ``ArithmeticOperator``: the common dtype, which must not be
      ``DataType.BOOL``.
    - ``LogicalOperator``: ``DataType.BOOL``; the common dtype must be
      ``DataType.BOOL``.
    - ``ComparisonOperator``: always ``DataType.BOOL``.

    For any other operator kind, or when no common dtype is returned,
    ``values`` is passed through untouched.

    Args:
        strict: Forwarded as-is to ``verify_and_get_common_dtype``.

    Returns:
        A reusable pydantic root validator (skipped on prior failures).
        The validator raises ``ValueError`` for a boolean operand dtype under
        an arithmetic operator, or a non-boolean one under a logical operator.
    """

    def _impl(
        cls: Type[pydantic.BaseModel], values: RootValidatorValuesType
    ) -> RootValidatorValuesType:
        operands = [values["left"], values["right"]]
        common_dtype = verify_and_get_common_dtype(cls, operands, strict=strict)
        if not common_dtype:
            # Nothing to propagate.
            return values

        op = values["op"]
        if isinstance(op, ArithmeticOperator):
            if common_dtype is DataType.BOOL:
                raise ValueError(
                    "Boolean expression is not allowed with arithmetic operation."
                )
            values["dtype"] = common_dtype
        elif isinstance(op, LogicalOperator):
            if common_dtype is not DataType.BOOL:
                raise ValueError(
                    "Arithmetic expression is not allowed in boolean operation."
                )
            values["dtype"] = DataType.BOOL
        elif isinstance(op, ComparisonOperator):
            # Comparisons yield booleans whatever the operand dtype is.
            values["dtype"] = DataType.BOOL
        return values

    return root_validator(allow_reuse=True, skip_on_failure=True)(_impl)
def assign_stmt_dtype_validation(*, strict: bool) -> RootValidatorType:
    """Build a root validator that checks both sides of an assignment.

    The returned validator calls ``verify_and_get_common_dtype`` on
    ``values["left"]`` and ``values["right"]`` and discards the result; only
    an exception raised by that helper can make validation fail. ``values``
    is returned unchanged.

    Args:
        strict: Forwarded as-is to ``verify_and_get_common_dtype``.

    Returns:
        A reusable pydantic root validator (skipped on prior failures).
    """

    def _impl(
        cls: Type[pydantic.BaseModel], values: RootValidatorValuesType
    ) -> RootValidatorValuesType:
        both_sides = [values["left"], values["right"]]
        # Called for its checks only; the returned dtype is not stored.
        verify_and_get_common_dtype(cls, both_sides, strict=strict)
        return values

    decorate = root_validator(allow_reuse=True, skip_on_failure=True)
    return decorate(_impl)
def ternary_op_dtype_propagation(*, strict: bool) -> RootValidatorType:
    """Build a root validator that sets ``dtype`` for a ternary expression.

    The returned validator passes ``values["true_expr"]`` and
    ``values["false_expr"]`` to ``verify_and_get_common_dtype`` and, when the
    result is truthy, stores it in ``values["dtype"]``.

    Args:
        strict: Forwarded as-is to ``verify_and_get_common_dtype``.

    Returns:
        A reusable pydantic root validator (skipped on prior failures).
    """

    def _impl(
        cls: Type[pydantic.BaseModel], values: RootValidatorValuesType
    ) -> RootValidatorValuesType:
        branches = [values["true_expr"], values["false_expr"]]
        branch_dtype = verify_and_get_common_dtype(cls, branches, strict=strict)
        if not branch_dtype:
            return values
        values["dtype"] = branch_dtype
        return values

    decorate = root_validator(allow_reuse=True, skip_on_failure=True)
    return decorate(_impl)
def native_func_call_dtype_propagation(*, strict: bool = True) -> RootValidatorType:
    """Build a root validator that sets ``dtype`` for a native function call.

    The returned validator passes ``values["args"]`` to
    ``verify_and_get_common_dtype`` and, when the result is truthy, stores it
    in ``values["dtype"]``.

    Args:
        strict: Forwarded as-is to ``verify_and_get_common_dtype``.
            Defaults to ``True``.

    Returns:
        A reusable pydantic root validator (skipped on prior failures).
    """

    def _impl(
        cls: Type[pydantic.BaseModel], values: RootValidatorValuesType
    ) -> RootValidatorValuesType:
        # assumes all NativeFunction args have a common dtype
        args_dtype = verify_and_get_common_dtype(cls, values["args"], strict=strict)
        if not args_dtype:
            return values
        values["dtype"] = args_dtype
        return values

    decorate = root_validator(allow_reuse=True, skip_on_failure=True)
    return decorate(_impl)
def validate_dtype_is_set() -> RootValidatorType:
    """Build a root validator that rejects trees containing unset dtypes.

    The returned validator flattens the model's field values with
    ``flatten_list``, walks every ``Node`` among them with ``iter_tree()`` and
    collects the nodes that have a ``dtype`` attribute. If any collected node
    has a falsy ``dtype``, validation fails.

    Returns:
        A reusable pydantic root validator (skipped on prior failures).
        The validator raises ``ValueError`` listing the offending nodes.
    """

    def _impl(
        cls: Type[pydantic.BaseModel], values: RootValidatorValuesType
    ) -> RootValidatorValuesType:
        # Phase 1: gather every node carrying a ``dtype`` attribute.
        dtype_nodes: List[Node] = []
        for value in flatten_list(values.values()):
            if not isinstance(value, Node):
                continue
            dtype_nodes.extend(value.iter_tree().if_hasattr("dtype"))

        # Phase 2: keep only those whose dtype is falsy.
        nodes_without_dtype = [node for node in dtype_nodes if not node.dtype]
        if nodes_without_dtype:
            raise ValueError(
                "Nodes without dtype detected {}".format(nodes_without_dtype)
            )
        return values

    return root_validator(allow_reuse=True, skip_on_failure=True)(_impl)
class Model(BaseModel):
    # Model exercising a field validator and a root validator that are
    # optionally wrapped in ``classmethod`` before being registered.
    # ``validator_classmethod``, ``root_validator_classmethod`` and
    # ``root_val_values`` are read from the surrounding scope.
    a: int = 1
    b: str

    def repeat_b(cls, v):
        # Field validator for ``b``: the incoming value multiplied by 2.
        return v * 2

    # Wrap in ``classmethod`` only when requested, then register for ``b``.
    repeat_b = validator('b')(
        classmethod(repeat_b) if validator_classmethod else repeat_b
    )

    def example_root_validator(cls, values):
        # Record every values dict seen by the root validator.
        root_val_values.append(values)
        b_value = values.get('b', '')
        if 'snap' in b_value:
            raise ValueError('foobar')
        # Return a copy of ``values`` with ``b`` overwritten.
        return {**values, 'b': 'changed'}

    # Same optional ``classmethod`` wrapping, then register as root validator.
    example_root_validator = root_validator(
        classmethod(example_root_validator)
        if root_validator_classmethod
        else example_root_validator
    )
def validate_symbol_refs() -> RootValidatorType:
    """Validate that symbol refs are found in a symbol table valid at the current scope.

    The returned validator visits every field value of the model, starting
    from ``values["symtable_"]`` as the initial symbol table. Each child
    attribute declared with a ``SymbolRef`` subtype is looked up in the table
    visible at that node. The table is extended with ``node.symtable_`` only
    after the node's own refs were checked, and only for the node's subtree.

    Returns:
        A reusable pydantic root validator (skipped on prior failures).
        The validator raises ``ValueError`` listing all symbols not found.
    """

    def _impl(
        cls: Type[pydantic.BaseModel], values: RootValidatorValuesType
    ) -> RootValidatorValuesType:
        class _MissingSymbolCollector(NodeVisitor):
            """Collect symbol refs that are absent from the scoped symbol table."""

            def __init__(self) -> None:
                self.missing_symbols: List[str] = []

            def visit_Node(
                self, node: Node, *, symtable: Dict[str, Any], **kwargs: Any
            ) -> None:
                for name, metadata in node.__node_children__.items():
                    declared_type = metadata["definition"].type_
                    is_symbol_ref = isinstance(declared_type, type) and issubclass(
                        declared_type, SymbolRef
                    )
                    if not is_symbol_ref:
                        continue
                    ref = getattr(node, name)
                    # Empty/unset refs are ignored; only dangling ones count.
                    if ref and ref not in symtable:
                        self.missing_symbols.append(ref)

                if isinstance(node, SymbolTableTrait):
                    # New scope: children see the outer table plus this node's.
                    symtable = {**symtable, **node.symtable_}
                self.generic_visit(node, symtable=symtable, **kwargs)

            @classmethod
            def apply(cls, node: Node, *, symtable: Dict[str, Any]) -> List[str]:
                collector = cls()
                collector.visit(node, symtable=symtable)
                return collector.missing_symbols

        root_symtable = values["symtable_"]
        missing_symbols = []
        for field_value in values.values():
            missing_symbols.extend(
                _MissingSymbolCollector.apply(field_value, symtable=root_symtable)
            )

        if missing_symbols:
            raise ValueError("Symbols {} not found.".format(missing_symbols))
        return values

    return root_validator(allow_reuse=True, skip_on_failure=True)(_impl)
def validate_lvalue_dims(
    vertical_loop_type: Type[Node], decl_type: Type[Node]
) -> RootValidatorType:
    """
    Validate lvalue dimensions using the root node symbol table.

    The following tree structure is expected::

        Root(`SymTableTrait`)
        |- *
           |- `vertical_loop_type`
               |- loop_order: `LoopOrder`
               |- *
                  |- AssignStmt(`AssignStmt`)
                      |- left: `Node`, validated only if reference to `decl_type` in symtable
        |- symtable_: Symtable[name, Union[`decl_type`, *]]

        DeclType
        |- dimensions: `Tuple[bool, bool, bool]`

    Every field value of the validated model is visited by a fresh
    ``_LvalueDimsValidator`` with ``values["symtable_"]`` as symbol table.

    Parameters
    ----------
    vertical_loop_type:
        A node type with a `LoopOrder` attribute named `loop_order`.
    decl_type:
        A declaration type with field dimension information in the format
        `Tuple[bool, bool, bool]` in an attribute named `dimensions`.

    Returns
    -------
    A reusable pydantic root validator (skipped on prior failures).
    """

    def _impl(
        cls: Type[pydantic.BaseModel], values: RootValidatorValuesType
    ) -> RootValidatorValuesType:
        for field_value in values.values():
            # One visitor instance per field value, as in the original flow.
            checker = _LvalueDimsValidator(vertical_loop_type, decl_type)
            checker.visit(field_value, symtable=values["symtable_"])
        return values

    return root_validator(allow_reuse=True, skip_on_failure=True)(_impl)
def add_root_validator(
    cls: typing.Type["Model"],
    validator: typing.Union[AnyCallable, classmethod],
    *,
    pre: bool = False,
    skip_on_failure: bool = False,
    allow_reuse: bool = True,
    index: int = -1,
) -> None:
    """Attach a root validator to an already-defined model class.

    The callable is unwrapped if it is a ``classmethod`` or bound method,
    checked for name clashes, passed through ``root_validator`` and then
    registered in the class's pre or post root-validator list. Finally the
    wrapped validator is set as an attribute on ``cls`` under the function's
    name.

    Args:
        cls: Model class to modify in place.
        validator: Function (or classmethod / bound method) with the
            signature ``(cls, values)``.
        pre: Forwarded to ``root_validator``; selects the pre list when true.
        skip_on_failure: Forwarded to ``root_validator``; stored next to the
            function in the post list.
        allow_reuse: Forwarded to ``root_validator``.
        index: Position in the target list. ``-1`` (default) appends at the
            end; any other value is passed to ``list.insert`` unchanged.

    Raises:
        ConfigError: If ``cls`` or one of its bases already defines an
            attribute, method or field with the validator's name, or if the
            validator's signature is not ``(cls, values)``.
    """
    from inspect import ismethod, signature

    if isinstance(validator, classmethod) or ismethod(validator):
        validator = validator.__func__  # type:ignore

    func_name = validator.__name__

    # Reject names that would shadow something defined anywhere in the MRO.
    if any(func_name in cls_.__dict__ for cls_ in cls.mro()):
        raise ConfigError(
            f"{cls} already has same name '{func_name}' method or attribute!"
        )
    if func_name in cls.__fields__:
        raise ConfigError(
            f"{cls} already has same name '{func_name}' field!")

    # Wrap through root_validator so the usual config object gets attached.
    validator = root_validator(pre=pre,
                               allow_reuse=allow_reuse,
                               skip_on_failure=skip_on_failure)(validator)
    validator_config = getattr(validator, ROOT_VALIDATOR_CONFIG_KEY)

    # Check the function signature: exactly (cls, values).
    sig = signature(validator_config.func)
    arg_list = list(sig.parameters.keys())
    if len(arg_list) != 2:
        raise ConfigError(
            f"Invalid signature for root validator {func_name}: {sig}"
            ", should be: (cls, values).")
    if arg_list[0] != "cls":
        raise ConfigError(
            f"Invalid signature for root validator {func_name}: {sig}, "
            f'"{arg_list[0]}" not permitted as first argument, '
            "should be: (cls, values).")

    # Register: pre validators are stored as bare functions, post validators
    # as (skip_on_failure, function) pairs.
    if validator_config.pre:
        target = cls.__pre_root_validators__
        entry = validator_config.func
    else:
        target = cls.__post_root_validators__
        entry = (validator_config.skip_on_failure, validator_config.func)

    if index == -1:
        target.append(entry)
    else:
        target.insert(index, entry)

    # Inject into the class and mark the validator as manually added.
    setattr(validator, "__manually_injected__", True)  # noqa:B010
    setattr(cls, func_name, validator)
def make_list_length_root_validator(
    *field_names,
    length_name: str,
    length_incr: int = 0,
    list_required_with_length: bool = False,
    min_length: int = 0,
):
    """
    Create a root_validator that checks the length (and optionally the presence)
    of several list fields against a length field of the same object.

    The required length is ``max(<length value> + length_incr, min_length)``.
    When the length field is absent or None, nothing is checked.

    Args:
        *field_names (str): names of the list-valued fields to check.
        length_name (str): name of the field that stores the expected length.
        length_incr (int): Optional extra increment added to the length value
            (e.g., to have +1 extra value in lists).
        list_required_with_length (obj:`bool`, optional): Whether a list that
            is None is an error while the required length is > 0.
            Default: False. If False, only lists that are not None are checked.
        min_length (int): lower bound for the required length; it wins when the
            (incremented) length value is smaller. For example, to require
            list length 1 when length value is given as 0.

    Returns:
        A reusable pydantic root validator. It raises ValueError on a length
        mismatch or on a missing list that is required.
    """

    def _get_incorrect_length_validation_message() -> str:
        """Build the mismatch message template, to be format()ed with field name and length name."""
        suffix_parts = []
        if length_incr != 0:
            suffix_parts.append(f" + {length_incr}")
        if min_length > 0:
            suffix_parts.append(f" (and at least {min_length})")
        suffix_parts.append(".")
        return "Number of values for {} should be equal to the {} value" + "".join(
            suffix_parts
        )

    def _validate_listfield_length(
        field_name: str, field: Optional[List[Any]], requiredlength: int
    ):
        """Check one list field against the required length and hand it back."""
        if field is None:
            # A missing list is only an error when explicitly required.
            if list_required_with_length and requiredlength > 0:
                raise ValueError(
                    f"List {field_name} cannot be missing if {length_name} is given."
                )
        elif len(field) != requiredlength:
            template = _get_incorrect_length_validation_message()
            raise ValueError(template.format(field_name, length_name))

        return field

    def validate_correct_length(cls, values: dict):
        """The root validator proper: checks every configured field name."""
        length = values.get(length_name)
        if length is None:
            # length attribute not present, possibly defer validation to a subclass.
            return values

        requiredlength = max(length + length_incr, min_length)

        for field_name in field_names:
            values[field_name] = _validate_listfield_length(
                field_name, values.get(field_name), requiredlength
            )

        return values

    decorate = root_validator(allow_reuse=True)
    return decorate(validate_correct_length)
def validate_forbidden_fields(cls, values: dict): if (val := values.get(conditional_field_name)) is None or not comparison_func( val, conditional_value ): return values for field in field_names: if values.get(field) != None: raise ValueError( f"{field} is forbidden when {conditional_field_name} {operator_str(comparison_func)} {conditional_value}" ) return values return root_validator(allow_reuse=True)(validate_forbidden_fields) def get_required_fields_validator( *field_names, conditional_field_name: str, conditional_value: Any, comparison_func: Callable[[Any, Any], bool] = eq, ): """ Gets a validator that checks whether the fields are provided, if `conditional_field_name` is equal to `conditional_value`. The equality check can be overridden with another comparison operator function. Args: *field_names (str): Names of the instance variables that need to be validated. conditional_field_name (str): Name of the instance variable on which the fields are dependent.