def normalize_shape(
        shape: ConvertibleToShape,
        ns: Optional[Namespace] = None
        ) -> ShapeType:
    """Normalize *shape* into a tuple with one validated component per axis.

    :param shape: either a single shape component or an iterable of
        components. A component may be a string (parsed with
        ``scalar_expr.parse``), an integer, or a ``prim.Expression``.
        If *shape* (after parsing, when it is a string) is a single
        :class:`numbers.Number` or ``prim.Expression``, it is treated as a
        one-dimensional shape and wrapped in a 1-tuple.
    :param ns: documented as enabling extra well-definedness checks on
        expressions. NOTE(review): the body below does not currently use
        *ns*, so no namespace-based checks are performed. Confirm whether
        this is intended.
    :returns: a :class:`tuple` of normalized shape components.
    :raises ValueError: if any integer component is negative. This covers
        NumPy integer scalars, which are checked via :class:`numbers.Integral`.
    """
    from numbers import Integral, Number

    def normalize_shape_component(
            s: ConvertibleToShapeComponent) -> ScalarExpression:
        # Parse string components first, so the checks below apply to the
        # parsed value rather than the raw text.
        if isinstance(s, str):
            s = scalar_expr.parse(s)

        # Check against Integral rather than int so that NumPy integer
        # scalars (e.g. np.int64) are also rejected when negative.
        if isinstance(s, Integral):
            if s < 0:
                raise ValueError(
                    f"size parameter must be nonnegative (got '{s}')")
        elif isinstance(s, prim.Expression):
            # TODO: check expression affine-ness
            # _ShapeChecker presumably raises on unsupported expression
            # forms; its exact rules are defined elsewhere. Verify there.
            _ShapeChecker()(s)

        return s

    if isinstance(shape, str):
        shape = scalar_expr.parse(shape)

    # A lone scalar or expression denotes a one-dimensional shape.
    if isinstance(shape, (Number, prim.Expression)):
        shape = (shape,)

    return tuple(normalize_shape_component(s) for s in shape)
def normalize_shape_component(
        s: ConvertibleToShapeComponent) -> ScalarExpression:
    """Validate and normalize a single shape component.

    :param s: a string (parsed with ``scalar_expr.parse``), an integer, or a
        ``prim.Expression``. Any other value is returned unchanged without
        checks.
    :returns: the (possibly parsed) component.
    :raises ValueError: if the component is a negative integer. This covers
        NumPy integer scalars, which are checked via :class:`numbers.Integral`.
    """
    from numbers import Integral

    # Parse strings first, so the checks below apply to the parsed value.
    if isinstance(s, str):
        s = scalar_expr.parse(s)

    # Check against Integral rather than int so that NumPy integer scalars
    # (e.g. np.int64) are also rejected when negative.
    if isinstance(s, Integral):
        if s < 0:
            raise ValueError(
                f"size parameter must be nonnegative (got '{s}')")
    elif isinstance(s, prim.Expression):
        # TODO: check expression affine-ness
        # _ShapeChecker presumably raises on unsupported expression forms;
        # its exact rules are defined elsewhere. Verify there.
        _ShapeChecker()(s)

    return s