def _deduce_Concat(self: ast.Binop, ctx: DeduceCtx) -> ConcreteType:
  """Deduces the type produced by a concatenation Binop.

  Both operands must agree on array-ness. Two arrays with equal element types
  concatenate into an array whose size is the sum of the operand sizes. Two
  non-array operands concatenate into an unsigned BitsType whose size is the
  sum of the operand sizes.

  Args:
    self: The concatenation Binop node.
    ctx: Deduction context used to deduce and resolve the operand types.

  Returns:
    The concrete type of the concatenated value.

  Raises:
    XlsTypeError: If exactly one operand is an array, or if two array operands
      have different element types.
  """
  lhs = resolve(deduce(self.lhs, ctx), ctx)
  rhs = resolve(deduce(self.rhs, ctx), ctx)
  lhs_is_array = isinstance(lhs, ArrayType)
  rhs_is_array = isinstance(rhs, ArrayType)

  if lhs_is_array != rhs_is_array:
    raise XlsTypeError(
        self.span, lhs, rhs,
        'Attempting to concatenate array/non-array values together.')

  if lhs_is_array:
    element_type = lhs.get_element_type()
    if element_type != rhs.get_element_type():
      raise XlsTypeError(
          self.span, lhs, rhs,
          'Array concatenation requires element types to be the same.')
    return ArrayType(element_type, lhs.size + rhs.size)  # pytype: disable=attribute-error

  return BitsType(signed=False, size=lhs.size + rhs.size)  # pytype: disable=attribute-error
def _deduce_Index(self: ast.Index, ctx: DeduceCtx) -> ConcreteType:  # pytype: disable=wrong-arg-types
  """Deduces the concrete type of an Index AST node.

  Three forms of indexing are handled:
    * slices (ast.Slice / ast.WidthSlice), delegated to `_deduce_slice_type`;
    * tuple indexing, which requires a literal number index within
      [0, tuple length);
    * array indexing, which requires a bits-typed, non-array index.

  Args:
    self: The Index node.
    ctx: Deduction context.

  Returns:
    The slice type, the selected tuple member type, or the array element type.

  Raises:
    XlsTypeError: For a non-literal or out-of-range tuple index, or an array
      index that is not scalar bits.
    TypeInferenceError: If the indexed value is neither a tuple nor an array.
  """
  lhs_type = deduce(self.lhs, ctx)

  # Check whether this is a slice-based indexing operation.
  if isinstance(self.index, (ast.Slice, ast.WidthSlice)):
    return _deduce_slice_type(self, ctx, lhs_type)

  index_type = deduce(self.index, ctx)

  if isinstance(lhs_type, TupleType):
    if not isinstance(self.index, ast.Number):
      raise XlsTypeError(self.index.span, index_type, None,
                         'Tuple index is not a literal number.')
    index_value = self.index.get_value_as_int()
    # Negative literals must be rejected too: Python's negative list indexing
    # below would otherwise silently select a member from the end.
    if not 0 <= index_value < lhs_type.get_tuple_length():
      raise XlsTypeError(
          self.index.span, lhs_type, None,
          'Tuple index {} is out of range for this tuple type.'.format(
              index_value))
    return lhs_type.get_unnamed_members()[index_value]

  if not isinstance(lhs_type, ArrayType):
    raise TypeInferenceError(self.lhs.span, lhs_type,
                             'Value to index is not an array.')

  index_ok = (isinstance(index_type, BitsType) and
              not isinstance(index_type, ArrayType))
  if not index_ok:
    raise XlsTypeError(self.index.span, index_type, None,
                       'Index type is not scalar bits.')
  return lhs_type.get_element_type()
def _deduce_Binop(self: ast.Binop, ctx: DeduceCtx) -> ConcreteType:  # pytype: disable=wrong-arg-types
  """Deduces the concrete type of a Binop AST node.

  Concatenation is delegated to `_deduce_Concat`. Every other binary operator
  requires both operands to resolve to the same type. Enum-typed operands are
  limited to `ast.BinopKind.ENUM_OK_KINDS`. Comparisons produce `U1`; all
  other operators produce the (resolved) operand type.

  Raises:
    XlsTypeError: If the operand types differ, or if the operator is not
      permitted on enum values.
  """
  if self.kind == ast.BinopKind.CONCAT:
    # Concatenation has its own typing rules (result size != operand size).
    return _deduce_Concat(self, ctx)

  lhs_type = deduce(self.lhs, ctx)
  rhs_type = deduce(self.rhs, ctx)
  resolved_lhs, resolved_rhs = resolve(lhs_type, ctx), resolve(rhs_type, ctx)

  if resolved_lhs != resolved_rhs:
    raise XlsTypeError(
        self.span, resolved_lhs, resolved_rhs,
        'Could not deduce type for binary operation {0} ({0!r}).'.format(
            self.kind))

  enum_op_disallowed = (isinstance(lhs_type, EnumType) and
                        self.kind not in ast.BinopKind.ENUM_OK_KINDS)
  if enum_op_disallowed:
    raise XlsTypeError(
        self.span, resolved_lhs, None,
        "Cannot use '{}' on values with enum type {}".format(
            self.kind.value, lhs_type.nominal_type.identifier))

  return (ConcreteType.U1 if self.kind in ast.BinopKind.COMPARISON_KINDS
          else resolved_lhs)
def instantiate(self) -> Tuple[ConcreteType, SymbolicBindings]:
  """Updates symbolic bindings for the member types according to arg_types.

  Instantiates the parametric members of struct_type against the presented
  arg_types. For example, when a bits[3,4] argument is passed for a bits[N,M]
  member, we note that N=3 and M=4 for use when resolving the struct type.
  Members and arguments are paired with `zip`, so pairing stops at the shorter
  of the two sequences.

  Returns:
    A tuple of (struct_type resolved under the collected bindings, the
    collected symbolic bindings as a sorted tuple of (name, value) pairs).

  Raises:
    XlsTypeError: If an instantiated member type does not equal its argument
      type.
  """
  # Walk through all the members/args to collect symbolic bindings.
  for i, (member_type, arg_type) in enumerate(
      zip(self.member_types, self.arg_types)):
    member_type = self._instantiate_one_arg(i, member_type, arg_type)
    # The third value logged is the argument type, so label it as such.
    logging.vlog(
        3, 'Post-instantiation; memno: %d; member_type: %s; arg_type: %s', i,
        member_type, arg_type)
    if member_type != arg_type:
      message = 'Mismatch between member and argument types.'
      if str(member_type) == str(arg_type):
        # Identical printed forms: include reprs so the difference is visible.
        message += ' {!r} vs {!r}'.format(member_type, arg_type)
      raise XlsTypeError(self.span, member_type, arg_type, suffix=message)

  # Resolve the struct type according to the bindings we collected.
  resolved = self._resolve(self.struct_type)
  logging.vlog(3, 'Resolved struct type from %s to %s', self.struct_type,
               resolved)
  return resolved, tuple(sorted(self.symbolic_bindings.items()))
def _bind_names(name_def_tree: ast.NameDefTree, type_: ConcreteType,
                ctx: DeduceCtx) -> None:
  """Records `type_` for the names in `name_def_tree`, destructuring tuples.

  A leaf binds its NameDef directly to `type_`. A non-leaf must be matched
  against a TupleType of the same arity. Each subtree is recorded in
  ctx.type_info with its corresponding member type and then bound recursively.

  Raises:
    XlsTypeError: If a non-leaf is matched against a non-tuple type.
    TypeInferenceError: If the number of names at a tuple level differs from
      the tuple's length.
  """
  if name_def_tree.is_leaf():
    ctx.type_info[name_def_tree.get_leaf()] = type_
    return

  if not isinstance(type_, TupleType):
    raise XlsTypeError(
        name_def_tree.span,
        type_,
        rhs_type=None,
        suffix='Expected a tuple type for these names, but got {}.'.format(
            type_))

  subtrees = name_def_tree.tree
  arity = type_.get_tuple_length()
  if len(subtrees) != arity:
    raise TypeInferenceError(
        name_def_tree.span, type_,
        'Could not bind names, names are mismatched in number vs type; at '
        'this level of the tuple: {} names, {} types.'.format(
            len(subtrees), arity))

  for subtree, member_type in zip(subtrees, type_.get_unnamed_members()):
    ctx.type_info[subtree] = member_type
    _bind_names(subtree, member_type, ctx)
def _check_function_params(f: ast.Function,
                           ctx: deduce.DeduceCtx) -> List[ConcreteType]:
  """Type-checks the parametric bindings and parameters of `f`.

  Each parametric's annotated type is recorded in ctx.type_info under its
  name. When a parametric has a derivation expression, that expression's
  deduced type must equal the annotated type. Each parameter is then deduced
  and recorded under its name.

  Returns:
    The deduced parameter types, in declaration order.

  Raises:
    XlsTypeError: If a parametric's expression type differs from its
      annotated type.
  """
  for binding in f.parametric_bindings:
    annotated = deduce.deduce(binding.type_, ctx)
    assert isinstance(annotated, ConcreteType)
    if binding.expr:
      # TODO(hjmontero): 2020-07-06 fully document the behavior of parametric
      # function calls in parametric expressions.
      inferred = deduce.deduce(binding.expr, ctx)
      if inferred != annotated:
        raise XlsTypeError(
            binding.span,
            annotated,
            inferred,
            suffix='Annotated type of derived parametric '
            'value did not match inferred type.')
    ctx.type_info[binding.name] = annotated

  checked: List[ConcreteType] = []
  for param in f.params:
    logging.vlog(2, 'Checking param: %s', param)
    deduced = deduce.deduce(param, ctx)
    assert isinstance(deduced, ConcreteType), deduced
    ctx.type_info[param.name] = deduced
    checked.append(deduced)
  return checked
def _deduce_Invocation(self: ast.Invocation, ctx: DeduceCtx) -> ConcreteType:  # pytype: disable=wrong-arg-types
  """Deduces the concrete type of an Invocation AST node.

  Steps:
    1. Deduce and resolve every argument type. If an argument's type is
       missing, the callee is a NameRef named 'map', and the argument names a
       parametric builtin, the argument's type is instead deduced from an
       invocation built by `_create_element_invocation` on the first argument.
    2. Deduce the callee's signature, which must be a FunctionType.
    3. Instantiate the callee against the argument types, and record the
       caller's and callee's symbolic bindings for this invocation.
    4. For a parametric callee, type-check its body under those bindings.

  Returns:
    The callee's return type as instantiated for these arguments.

  Raises:
    TypeMissingError: Re-raised when an argument or callee type is unknown.
      For the callee, span/user are first set to this invocation.
    XlsTypeError: If the callee does not have a function type.
  """
  logging.vlog(5, 'Deducing type for invocation: %s', self)
  _, caller_bindings = ctx.fn_stack[-1]

  arg_types = []
  for arg in self.args:
    try:
      arg_type = resolve(deduce(arg, ctx), ctx)
    except TypeMissingError:
      # These nodes could be ModRefs or NameRefs.
      callee_is_map = (
          isinstance(self.callee, ast.NameRef) and
          self.callee.name_def.identifier == 'map')
      arg_is_builtin = (
          isinstance(arg, ast.NameRef) and arg.name_def.identifier in
          dslx_builtins.PARAMETRIC_BUILTIN_NAMES)
      if not (callee_is_map and arg_is_builtin):
        raise
      element_invocation = _create_element_invocation(
          ctx.module, self.span, arg, self.args[0])
      arg_type = resolve(deduce(element_invocation, ctx), ctx)
    arg_types.append(arg_type)

  try:
    # Only the signature is needed here; a parametric callee's body is not
    # checked until symbolic bindings for it are known.
    callee_type = deduce(self.callee, ctx)
  except TypeMissingError as e:
    e.span = self.span
    e.user = self
    raise

  if not isinstance(callee_type, FunctionType):
    raise XlsTypeError(self.callee.span, callee_type, None,
                       'Callee does not have a function type.')

  if isinstance(self.callee, ast.ModRef):
    imported_module, _ = ctx.type_info.get_imported(self.callee.mod)
    callee_fn = imported_module.get_function(self.callee.value)
  else:
    assert isinstance(self.callee, ast.NameRef), self.callee
    callee_fn = ctx.module.get_function(self.callee.identifier)

  self_type, callee_sym_bindings = parametric_instantiator.instantiate_function(
      self.span, callee_type, tuple(arg_types), ctx,
      callee_fn.parametric_bindings)
  ctx.type_info.add_invocation_symbolic_bindings(
      self, tuple(caller_bindings.items()), callee_sym_bindings)

  if callee_fn.is_parametric():
    # With concrete bindings in hand, type-check the callee body to make sure
    # these values actually work.
    _check_parametric_invocation(callee_fn, self, callee_sym_bindings, ctx)

  return self_type
def instantiate(self) -> Tuple[ConcreteType, SymbolicBindings]:
  """Instantiates function_type's parametric signature against arg_types.

  Each parameter is bound against its argument, which records symbolic
  bindings. For example, passing a bits[3,4] argument for a bits[N,M]
  parameter records N=3 and M=4. The return type is then resolved under the
  collected bindings.

  Returns:
    A tuple of (the resolved return type, the symbolic bindings as a sorted
    tuple of (name, value) pairs).

  Raises:
    XlsTypeError: If an instantiated parameter type differs from its argument
      type.
  """
  params = self.function_type.get_function_params()  # pytype: disable=attribute-error
  for index, (declared, actual) in enumerate(zip(params, self.arg_types)):
    declared = self._instantiate_one_arg(index, declared, actual)
    logging.vlog(
        3, 'Post-instantiation; paramno: %d; param_type: %s; arg_type: %s',
        index, declared, actual)
    if declared != actual:
      message = 'Mismatch between parameter and argument types.'
      if str(declared) == str(actual):
        # Same printed form but unequal: add reprs to expose the difference.
        message += ' {!r} vs {!r}'.format(declared, actual)
      raise XlsTypeError(self.span, declared, actual, suffix=message)

  return_type = self.function_type.get_function_return_type()  # pytype: disable=attribute-error
  resolved = self._resolve(return_type)
  logging.vlog(2, 'Resolved return type from %s to %s', return_type, resolved)
  return resolved, tuple(sorted(self.symbolic_bindings.items()))
def _symbolic_bind(self, param_type: ConcreteType,
                   arg_type: ConcreteType) -> None:
  """Binds symbols present in param_type according to the value of arg_type.

  Dispatches on the kind of param_type:
    * BitsType and EnumType: bind via `_symbolic_bind_bits`;
    * TupleType: nominal types must match, then `_symbolic_bind_tuple`;
    * ArrayType: `_symbolic_bind_array`;
    * FunctionType: `_symbolic_bind_function`.

  Args:
    param_type: The declared (possibly parametric) parameter type.
    arg_type: The argument type presented at the call site.

  Raises:
    XlsTypeError: If enum or tuple nominal types differ between parameter
      and argument.
    NotImplementedError: If param_type is of a kind not handled here.
  """
  assert isinstance(param_type, ConcreteType), repr(param_type)
  assert isinstance(arg_type, ConcreteType), repr(arg_type)

  def _nominal_name(t: ConcreteType) -> str:
    # Printable name of t's nominal type, or '<none>' when it has none.
    return repr(t.nominal_type.identifier) if t.nominal_type else '<none>'

  def _nominal_mismatch() -> XlsTypeError:
    return XlsTypeError(
        self.span,
        param_type,
        arg_type,
        suffix='parameter type name: {}; argument type name: {}.'.format(
            _nominal_name(param_type), _nominal_name(arg_type)))

  if isinstance(param_type, BitsType):
    self._symbolic_bind_bits(param_type, arg_type)
  elif isinstance(param_type, EnumType):
    # Report differing enums as a user-facing type error. An `assert` here
    # would be stripped under `python -O`, and otherwise surfaces as an
    # internal AssertionError.
    if param_type.nominal_type != arg_type.nominal_type:
      raise _nominal_mismatch()
    # If the enums are the same, we do the same thing as we do with bits
    # (ignore the primitive and symbolic bind the dims).
    self._symbolic_bind_bits(param_type, arg_type)
  elif isinstance(param_type, TupleType):
    if param_type.nominal_type != arg_type.nominal_type:
      raise _nominal_mismatch()
    self._symbolic_bind_tuple(param_type, arg_type)
  elif isinstance(param_type, ArrayType):
    self._symbolic_bind_array(param_type, arg_type)
  elif isinstance(param_type, FunctionType):
    self._symbolic_bind_function(param_type, arg_type)
  else:
    raise NotImplementedError(
        'Bind symbols in parameter type {} @ {}'.format(param_type, self.span))
def test_stringify(self):
  """An XlsTypeError without a suffix ends with its rendered span."""
  bits3 = BitsType(signed=False, size=3)
  # The rendered position is one greater than the Pos fields: (9, 10) -> 10:11.
  pos = Pos('<fake>', 9, 10)
  error = XlsTypeError(Span(pos, pos), bits3, bits3)
  self.assertEndsWith(str(error), '@ <fake>:10:11-10:11')
def _check_function(f: Function, ctx: deduce.DeduceCtx) -> None:
  """Validates type annotations on parameters/return type of f are consistent.

  For a parametric function that is not the function currently on top of
  ctx.fn_stack, only the signature is computed and recorded; the body is
  checked later, per invocation. Otherwise the body is deduced and compared
  against the annotated return type (NIL when there is no annotation).

  Args:
    f: The function to type check.
    ctx: Wraps a node_to_type, a mapping of AST node to its deduced type;
      (free-variable) references are resolved via this dictionary.

  Raises:
    XlsTypeError: When the return type deduced is inconsistent with the return
      type annotation on "f".
  """
  fn_name, _ = ctx.fn_stack[-1]

  # First, get the types of the function's parametrics and args.
  if f.is_parametric() and f.name.identifier == fn_name:
    # Parametric functions are evaluated per invocation. If we're currently
    # inside of this function, its signature must already be recorded in
    # ctx.node_to_type (the `else` branch below records it), so only the body
    # remains. The annotated return type is re-deduced below, so only the
    # param types are read back here.
    assert f in ctx.node_to_type, f
    param_types = list(ctx.node_to_type[f].params)  # pytype: disable=attribute-error
  else:
    logging.vlog(1, 'Type-checking sig for function: %s', f)
    param_types = _check_function_params(f, ctx)
    if f.is_parametric():
      # We just needed the type signature so that we can instantiate this
      # invocation. Record it and typecheck the body once we have symbolic
      # bindings.
      annotated_return_type = (
          deduce.deduce(f.return_type, ctx)
          if f.return_type else ConcreteType.NIL)
      ctx.node_to_type[f.name] = ctx.node_to_type[f] = FunctionType(
          tuple(param_types), annotated_return_type)
      return

  logging.vlog(1, 'Type-checking body for function: %s', f)

  # Second, typecheck the return type of the function.
  # NOTE: if there is no annotated return type, we assume NIL.
  annotated_return_type = (
      deduce.deduce(f.return_type, ctx) if f.return_type else ConcreteType.NIL)
  resolved_return_type = deduce.resolve(annotated_return_type, ctx)

  # Third, typecheck the body of the function.
  body_return_type = deduce.deduce(f.body, ctx)
  resolved_body_type = deduce.resolve(body_return_type, ctx)

  # Finally, assert type consistency between body and annotated return type.
  if resolved_return_type != resolved_body_type:
    raise XlsTypeError(
        f.body.span,
        resolved_body_type,
        resolved_return_type,
        suffix='Return type of function body for "{}" did not match the '
        'annotated return type.'.format(f.name.identifier))

  ctx.node_to_type[f.name] = ctx.node_to_type[f] = FunctionType(
      tuple(param_types), body_return_type)
def _deduce_Ternary(self: ast.Ternary, ctx: DeduceCtx) -> ConcreteType:  # pytype: disable=wrong-arg-types
  """Deduces the concrete type of a Ternary AST node.

  The test must resolve to `U1`, and both branches must resolve to the same
  type, which becomes the type of the whole expression.

  Raises:
    XlsTypeError: If the test is not `U1` or the branch types differ.
  """
  test = resolve(deduce(self.test, ctx), ctx)
  if test != ConcreteType.U1:
    raise XlsTypeError(
        self.span, test, ConcreteType.U1,
        'Test type for conditional expression is not "bool"')

  consequent = resolve(deduce(self.consequent, ctx), ctx)
  alternate = resolve(deduce(self.alternate, ctx), ctx)
  if consequent != alternate:
    raise XlsTypeError(
        self.span, consequent, alternate,
        'Ternary consequent type (in the "then" clause) did not match '
        'alternate type (in the "else" clause)')
  return consequent
def check_test(t: ast.Test, ctx: deduce.DeduceCtx) -> None:
  """Type-checks a test body within a module; it must evaluate to nil.

  Raises:
    XlsTypeError: If the deduced body type is not ConcreteType.NIL.
  """
  actual = deduce.deduce(t.body, ctx)
  expected = ConcreteType.NIL
  if actual != expected:
    raise XlsTypeError(
        t.body.span,
        actual,
        expected,
        suffix='Return type of test body for "{}" did not match the '
        'expected test return type (nil).'.format(t.name.identifier))
def _deduce_While(self: ast.While, ctx: DeduceCtx) -> ConcreteType:  # pytype: disable=wrong-arg-types
  """Deduces the concrete type of a While AST node.

  The loop test must resolve to `U1`, and the body's result type must equal
  the init value's type; that type is the type of the whole loop.

  Raises:
    XlsTypeError: If the test is not `U1`, or the init and body types differ.
      The error reports the resolved types that were actually compared.
  """
  init_type = deduce(self.init, ctx)
  test_type = deduce(self.test, ctx)
  resolved_init_type = resolve(init_type, ctx)
  resolved_test_type = resolve(test_type, ctx)
  if resolved_test_type != ConcreteType.U1:
    raise XlsTypeError(self.test.span, resolved_test_type, ConcreteType.U1,
                       'Expect while-loop test to be a bool value.')

  body_type = deduce(self.body, ctx)
  resolved_body_type = resolve(body_type, ctx)
  if resolved_init_type != resolved_body_type:
    raise XlsTypeError(
        self.span, resolved_init_type, resolved_body_type,
        "While-loop init value type did not match while-loop body's "
        'result type.')
  return resolved_init_type
def _symbolic_bind_dims(self, param_type: ConcreteType,
                        arg_type: ConcreteType) -> None:
  """Binds the parametric symbol in param_type's size to arg_type's size.

  Only applies when param_type's size value is a ParametricSymbol; any other
  size is left untouched. Re-binding an already-bound symbol to a different
  value is an error.

  Raises:
    XlsTypeError: On a conflicting re-binding. The message is worded as a
      constraint violation when the symbol has a (truthy) constraint, and as
      conflicting argument types otherwise.
  """
  param_dim = param_type.size.value
  arg_dim = arg_type.size.value
  if not isinstance(param_dim, parametric_expression.ParametricSymbol):
    return

  name = param_dim.identifier
  if name in self.symbolic_bindings:
    bound = self.symbolic_bindings[name]
    if bound != arg_dim:
      constraint = self.constraints[name]
      if constraint:
        # Error on violated constraint.
        raise XlsTypeError(
            self.span,
            BitsType(signed=False, size=bound),
            arg_type,
            suffix=f'Parametric constraint violated, saw {name} '
            f'= {constraint} = {bound}; then {name} = {arg_dim}')
      # Error on conflicting argument types.
      raise XlsTypeError(
          self.span,
          param_type,
          arg_type,
          suffix='Parametric value {} was bound to different values at '
          'different places in invocation; saw: {!r}; then: {!r}'.format(
              name, bound, arg_dim))

  logging.vlog(2, 'Binding %r to %s', name, arg_dim)
  self.symbolic_bindings[name] = arg_dim
def _deduce_Array(self: ast.Array, ctx: DeduceCtx) -> ConcreteType:  # pytype: disable=wrong-arg-types
  """Deduces the concrete type of an Array AST node.

  All members must resolve to the same element type. Without an annotation,
  the result is an ArrayType of that element type sized by the member count.
  With an annotation, it must be an ArrayType with a matching element type.
  With an ellipsis the annotated type is returned; otherwise the annotated
  size must equal the member count.

  Raises:
    XlsTypeError: On mismatched member types, a non-array annotation, a
      mismatched annotated element type, or a mismatched annotated size.
  """
  member_types = [deduce(member, ctx) for member in self.members]
  element_type = resolve(member_types[0], ctx)
  others = zip(self.members[1:], member_types[1:])
  for position, (member, member_type) in enumerate(others, start=1):
    resolved = resolve(member_type, ctx)
    logging.vlog(5, 'array member type %d: %s', position, resolved)
    if resolved != element_type:
      raise XlsTypeError(
          member.span, element_type, resolved,
          'Array member did not have same type as other members.')

  member_count = len(member_types)
  inferred = ArrayType(element_type, member_count)
  if not self.type_:
    return inferred

  annotated = deduce(self.type_, ctx)
  if not isinstance(annotated, ArrayType):
    raise XlsTypeError(self.span, annotated, None,
                       'Array was not annotated with an array type.')

  annotated_element_type = resolve(annotated.get_element_type(), ctx)
  if annotated_element_type != element_type:
    raise XlsTypeError(
        self.span, annotated_element_type, element_type,
        'Annotated element type did not match inferred element type.')

  if self.has_ellipsis:
    # With an ellipsis the size comes from the annotated type; element types
    # were already checked above.
    return annotated

  if annotated.size != member_count:
    raise XlsTypeError(
        self.span, annotated, inferred,
        'Annotated array size {!r} does not match inferred array size {!r}.'
        .format(annotated.size, member_count))
  return inferred
def _deduce_Cast(self: ast.Cast, ctx: DeduceCtx) -> ConcreteType:  # pytype: disable=wrong-arg-types
  """Deduces the concrete type of a Cast AST node.

  The node's type is the (resolved) target type `self.type_`. The operand
  expression is cast *from* its own type *to* that target, and the pair must
  be accepted by `_is_acceptable_cast`.

  Raises:
    XlsTypeError: If the cast from the expression type to the target type is
      not acceptable. The error reports the resolved types.
  """
  type_result = deduce(self.type_, ctx)
  expr_type = deduce(self.expr, ctx)
  resolved_type_result = resolve(type_result, ctx)
  resolved_expr_type = resolve(expr_type, ctx)
  # Direction matters: the source is the expression, the destination is the
  # annotated target type (the original call had these swapped).
  if not _is_acceptable_cast(from_=resolved_expr_type, to=resolved_type_result):
    raise XlsTypeError(
        self.span, resolved_expr_type, resolved_type_result,
        'Cannot cast from expression type {} to {}.'.format(
            resolved_expr_type, resolved_type_result))
  return resolved_type_result
def _deduce_Enum(self: ast.Enum, ctx: DeduceCtx) -> ConcreteType:  # pytype: disable=wrong-arg-types
  """Deduces the concrete type of an Enum AST node.

  The underlying type must resolve to a BitsType. The enum takes that type's
  signedness and bit count. The resulting EnumType is recorded for every
  member name and value, and for the enum's name and node.

  Raises:
    XlsTypeError: If the underlying type is not a bits type.
  """
  underlying = resolve(deduce(self.type_, ctx), ctx)
  if not isinstance(underlying, BitsType):
    raise XlsTypeError(self.span, underlying, None,
                       'Underlying type for an enum must be a bits type.')

  width = underlying.get_total_bit_count()
  self.set_signedness(underlying.get_signedness())
  enum_type = EnumType(self, width)

  for member_name, member_value in self.values:
    # The parser places the enum's type_ on numeric values, so deducing the
    # value flags inappropriate numbers.
    deduce(member_value, ctx)
    ctx.type_info[member_name] = ctx.type_info[member_value] = enum_type

  ctx.type_info[self.name] = ctx.type_info[self] = enum_type
  return enum_type
def _verify_constraints(self) -> None:
  """Verifies that all parametrics adhere to signature constraints.

  Example signature:

    fn [X: u32, Y: u32 = X + X] f(x: bits[X], y: bits[Y]) -> bits[Y]

  Y is constrained twice by the signature alone: it must match the bit width
  of argument y, and it must equal X + X. For each parametric with a
  constraint expression, this evaluates the expression with the bindings
  known so far. If the parametric is already bound, the two values must
  agree; otherwise the computed value becomes its binding. Constraints that
  cannot be evaluated yet (KeyError from the interpreter) are skipped.

  Raises:
    XlsTypeError: If a computed constraint value disagrees with an existing
      binding.
  """
  for binding, constraint in self.constraints.items():
    if constraint is None:
      continue  # No derivation expression, e.g. [X: u32].

    try:
      fn_name, fn_symbolic_bindings = self.ctx.fn_stack[-1]
      fn_ctx = (self.ctx.module.name, fn_name,
                tuple(fn_symbolic_bindings.items()))
      result = self.ctx.interpret_expr(
          self.ctx.module,
          self.ctx.type_info,
          self.symbolic_bindings,
          self.bit_widths,
          constraint,
          fn_ctx=fn_ctx)
    except KeyError:
      # We haven't seen enough bindings to evaluate this constraint.
      continue

    if binding not in self.symbolic_bindings:
      self.symbolic_bindings[binding] = result
      continue

    bound = self.symbolic_bindings[binding]
    if result != bound:
      raise XlsTypeError(
          self.span,
          BitsType(signed=False, size=bound),
          BitsType(signed=False, size=result),
          suffix=f'Parametric constraint violated, saw {binding} = '
          f'{constraint} = {result}; then {binding} = {bound}')
def _deduce_Match(self: ast.Match, ctx: DeduceCtx) -> ConcreteType:  # pytype: disable=wrong-arg-types
  """Deduces the concrete type of a Match AST node.

  Every arm pattern is unified against the matched value's type. All arms
  must then resolve to the same type, which is the type of the whole match.

  Raises:
    XlsTypeError: If an arm's type differs from the first arm's type.
  """
  matched_type = deduce(self.matched, ctx)
  for arm in self.arms:
    for pattern in arm.patterns:
      _unify(pattern, matched_type, ctx)

  arm_types = tuple(deduce(arm, ctx) for arm in self.arms)
  expected = resolve(arm_types[0], ctx)
  for arm, arm_type in zip(self.arms[1:], arm_types[1:]):
    actual = resolve(arm_type, ctx)
    if actual != expected:
      raise XlsTypeError(
          arm.span, actual, expected,
          'This match arm did not have the same type as preceding match arms.'
      )
  return expected
def _deduce_For(self: ast.For, ctx: DeduceCtx) -> ConcreteType:  # pytype: disable=wrong-arg-types
  """Deduces the concrete type of a For AST node.

  The loop names are bound to the annotated type before the body is deduced.
  The iterable is deduced, but its type is not otherwise used here. The body's
  result type must equal the init value's type, which becomes the type of the
  whole loop.

  Raises:
    XlsTypeError: If the init and body types differ.
  """
  init_type = deduce(self.init, ctx)
  _bind_names(self.names, deduce(self.type_, ctx), ctx)
  body_type = deduce(self.body, ctx)
  deduce(self.iterable, ctx)

  init, body = resolve(init_type, ctx), resolve(body_type, ctx)
  if init != body:
    raise XlsTypeError(
        self.span, init, body,
        "For-loop init value type did not match for-loop body's result type."
    )

  # TODO(leary): 2019-02-19 Type check annotated_type (the bound names each
  # iteration) against init_type/body_type -- this requires us to understand
  # how iterables turn into induction values.
  return init
def _instantiate_one_arg(self, i: int, param_type: ConcreteType,
                         arg_type: ConcreteType) -> ConcreteType:
  """Binds param_type against arg_type and returns the resolved param type.

  Args:
    i: Position of the parameter (used in messages/logging only).
    param_type: Declared (possibly parametric) parameter type.
    arg_type: Argument type presented for that parameter.

  Returns:
    param_type resolved under the symbolic bindings after binding.

  Raises:
    XlsTypeError: If param_type and arg_type are different ConcreteType
      classes.
  """
  assert isinstance(param_type, ConcreteType), repr(param_type)
  assert isinstance(arg_type, ConcreteType), repr(arg_type)

  # Exact class comparison (not isinstance): the two must be the same kind.
  if type(param_type) != type(arg_type):  # pylint: disable=unidiomatic-typecheck
    raise XlsTypeError(
        self.span,
        param_type,
        arg_type,
        suffix='Parameter {} and argument types are different kinds '
        '({} vs {}).'.format(i, param_type.get_debug_type_name(),
                             arg_type.get_debug_type_name()))

  logging.vlog(3, 'Symbolically binding param_type %d %s against arg_type %s',
               i, param_type, arg_type)
  self._symbolic_bind(param_type, arg_type)

  resolved_param = self._resolve(param_type)
  logging.vlog(3, 'Resolved param_type: %s', resolved_param)
  return resolved_param
def _deduce_Let(self: ast.Let, ctx: DeduceCtx) -> ConcreteType:  # pytype: disable=wrong-arg-types
  """Deduces the concrete type of a Let AST node.

  The right-hand side's resolved type must match the annotation, if present.
  That type is bound to the let's names, any `const` is deduced, and the
  type of the let's body is returned.

  Raises:
    XlsTypeError: If the annotated type does not match the right-hand side.
  """
  rhs_type = resolve(deduce(self.rhs, ctx), ctx)

  if self.type_ is not None:
    annotated = resolve(deduce(self.type_, ctx), ctx)
    if rhs_type != annotated:
      raise XlsTypeError(
          self.rhs.span, annotated, rhs_type,
          'Annotated type did not match inferred type of right hand side.'
      )

  _bind_names(self.name_def_tree, rhs_type, ctx)
  if self.const:
    deduce(self.const, ctx)
  return deduce(self.body, ctx)
def _deduce_Struct(self: ast.Struct, ctx: DeduceCtx) -> ConcreteType:  # pytype: disable=wrong-arg-types
  """Returns the concrete type for a (potentially parametric) struct.

  Parametric bindings are checked the same way as function parametrics: an
  expression's type must match the annotation, and the annotated type is
  recorded under the parametric's name. The struct type is a TupleType of
  (member name, resolved member type) pairs, with this struct as its nominal
  type. It is recorded under the struct's name.

  Raises:
    XlsTypeError: If a parametric's expression type differs from its
      annotated type.
  """
  for parametric in self.parametric_bindings:
    annotated = deduce(parametric.type_, ctx)
    assert isinstance(annotated, ConcreteType)
    if parametric.expr:
      inferred = deduce(parametric.expr, ctx)
      if inferred != annotated:
        raise XlsTypeError(
            parametric.span,
            annotated,
            inferred,
            suffix='Annotated type of derived parametric '
            'value did not match inferred type.')
    ctx.type_info[parametric.name] = annotated

  member_entries = []
  for member_name, member_annotation in self.members:
    member_entries.append(
        (member_name.identifier, resolve(deduce(member_annotation, ctx), ctx)))

  struct_type = TupleType(tuple(member_entries), self)
  ctx.type_info[self.name] = struct_type
  logging.vlog(5, 'Deduced type for struct %s => %s; type_info: %r', self,
               struct_type, ctx.type_info)
  return struct_type
def _deduce_slice_type(self: ast.Index, ctx: DeduceCtx,
                       lhs_type: ConcreteType) -> ConcreteType:
  """Deduces the concrete type of an Index AST node with a slice spec.

  Only bits-typed values can be sliced.
    * ast.WidthSlice: the start must be unsigned. An unannotated literal start
      is given lhs_type's unsigned bits type and must fit in it. The width
      type must be a BitsType no wider than lhs_type, and it is the result
      type.
    * ast.Slice: start/limit are literal numbers or absent. The resolved
      (start, width) is recorded in ctx.type_info, keyed by the current
      function's symbolic bindings. The result is an unsigned bits type of
      that width.

  Raises:
    XlsTypeError: If lhs_type is not bits, or the width type is wider than
      lhs_type.
    TypeInferenceError: If a literal start doesn't fit, the start is signed,
      or the width type is not bits-based.
  """
  index_slice = self.index
  assert isinstance(index_slice, (ast.Slice, ast.WidthSlice)), index_slice

  # TODO(leary): 2019-10-28 Only slicing bits types for now, and only with
  # number ast nodes, generalize to arrays and constant expressions.
  if not isinstance(lhs_type, BitsType):
    raise XlsTypeError(self.span, lhs_type, None,
                       'Value to slice is not of "bits" type.')

  bit_count = lhs_type.get_total_bit_count()

  if isinstance(index_slice, ast.WidthSlice):
    start = index_slice.start
    if isinstance(start, ast.Number) and start.type_ is None:
      # An unannotated literal start takes the unsigned type of the value
      # being sliced; check that the literal fits in it.
      start_type = lhs_type.to_ubits()
      resolved_start_type = resolve(start_type, ctx)
      if not bit_helpers.fits_in_bits(
          start.get_value_as_int(), resolved_start_type.get_total_bit_count()):
        raise TypeInferenceError(
            start.span, resolved_start_type,
            'Cannot fit {} in {} bits (inferred from bits to slice).'.format(
                start.get_value_as_int(),
                resolved_start_type.get_total_bit_count()))
      ctx.type_info[start] = start_type
    else:
      start_type = deduce(start, ctx)

    # Check the start is unsigned.
    if start_type.signed:
      raise TypeInferenceError(
          start.span,
          type_=start_type,
          suffix='Start index for width-based slice must be unsigned.')

    width_type = deduce(index_slice.width, ctx)

    # Check the width type is bits-based (no enums, since value could be out
    # of range of the enum values). This runs before the size comparison so a
    # non-bits width gets this diagnostic rather than a size error.
    if not isinstance(width_type, BitsType):
      raise TypeInferenceError(
          self.span,
          type_=width_type,
          suffix='Require a bits-based type for width-based slice.')

    width_bits = width_type.get_total_bit_count()
    # Only comparable when both sizes are concrete ints (not parametric).
    if (isinstance(width_bits, int) and isinstance(bit_count, int) and
        width_bits > bit_count):
      # Attribute the error to the whole index expression: it concerns the
      # width, not the start index.
      raise XlsTypeError(
          self.span, lhs_type, width_type,
          'Slice type must have <= original number of bits; attempted slice '
          'from {} to {} bits.'.format(bit_count, width_bits))

    # The width type is the thing returned from the width-slice.
    return width_type

  assert isinstance(index_slice, ast.Slice), index_slice

  limit = index_slice.limit.get_value_as_int() if index_slice.limit else None
  # PyType has trouble figuring out that start is definitely an Number at this
  # point.
  start = index_slice.start
  assert isinstance(start, (ast.Number, type(None)))
  start = start.get_value_as_int() if start else None

  _, fn_symbolic_bindings = ctx.fn_stack[-1]
  if isinstance(bit_count, ParametricExpression):
    # Evaluate a symbolic bit count under the current function's bindings.
    bit_count = bit_count.evaluate(fn_symbolic_bindings)
  start, width = bit_helpers.resolve_bit_slice_indices(bit_count, start, limit)
  key = tuple(fn_symbolic_bindings.items())
  ctx.type_info.add_slice_start_width(index_slice, key, (start, width))
  return BitsType(signed=False, size=width)
def check_module(module: ast.Module,
                 f_import: Optional[ImportFn]) -> type_info.TypeInfo:
  """Validates type annotations on all functions within "module".

  Type-checking proceeds in phases, in this order:
    1. Imports (resolved via f_import), constants and enums at top level.
    2. QuickCheck functions, whose bodies must have type U1.
    3. Structs.
    4. Typedefs.
    5. Non-parametric functions. Parametric functions are skipped here and
       typechecked per invocation.
    6. Tests. New-style TestFunctions must take no arguments and must not be
       parametric.

  Args:
    module: The module to type check functions for.
    f_import: Callback to import a module (a la a import statement). This may
      be None e.g. in unit testing situations where it's guaranteed there will
      be no import statements.

  Returns:
    The TypeInfo (mapping from AST node to its deduced/checked type) built
    for the module.

  Raises:
    XlsTypeError: If any of the functions in the module have typecheck errors,
      or a QuickCheck function does not return a bool.
    PositionalError: If a QuickCheck function is parametric, or a test
      function takes arguments or is parametric.
  """
  ti = type_info.TypeInfo(module)
  interpreter_callback = functools.partial(interpret_expr, f_import=f_import)
  ctx = deduce.DeduceCtx(ti, module, interpreter_callback,
                         check_top_node_in_module)

  # First populate type_info with constants, enums, and resolved imports.
  ctx.fn_stack.append(
      ('top', dict()))  # No sym bindings in the global scope.
  for member in ctx.module.top:
    if isinstance(member, ast.Import):
      assert isinstance(member.name, tuple), member.name
      # NOTE(review): f_import is called unconditionally here; if it is None
      # and the module has an import, this raises TypeError.
      imported_module, imported_type_info = f_import(member.name)
      ctx.type_info.add_import(member, (imported_module, imported_type_info))
    elif isinstance(member, (ast.Constant, ast.Enum)):
      deduce.deduce(member, ctx)
    else:
      assert isinstance(member, (ast.Function, ast.Test, ast.Struct,
                                 ast.QuickCheck, ast.TypeDef)), member
  ctx.fn_stack.pop()

  # NOTE(review): each phase below pushes onto ctx.fn_stack without a
  # matching pop in this function; presumably check_top_node_in_module pops
  # it -- confirm against that function.

  # Keyed by name, so a later entry with the same name replaces an earlier
  # one.
  quickcheck_map = {
      qc.f.name.identifier: qc for qc in ctx.module.get_quickchecks()
  }
  for qc in quickcheck_map.values():
    assert isinstance(qc, ast.QuickCheck), qc

    f = qc.f
    assert isinstance(f, ast.Function), f
    if f.is_parametric():
      # TODO(cdleary): 2020-08-09 See https://github.com/google/xls/issues/81
      raise PositionalError(
          'Quickchecking parametric '
          'functions is unsupported.', f.span)

    logging.vlog(2, 'Typechecking function: %s', f)
    ctx.fn_stack.append(
        (f.name.identifier, dict()))  # No symbolic bindings.
    check_top_node_in_module(f, ctx)

    # The QuickCheck property's body must evaluate to a bool (U1).
    quickcheck_f_body_type = ctx.type_info[f.body]
    if quickcheck_f_body_type != ConcreteType.U1:
      raise XlsTypeError(
          f.span,
          quickcheck_f_body_type,
          ConcreteType.U1,
          suffix='QuickCheck functions must return a bool.')

    logging.vlog(2, 'Finished typechecking function: %s', f)

  # We typecheck struct definitions using check_top_node_in_module() so that
  # we can typecheck function calls in parametric bindings, if any.
  struct_map = {s.name.identifier: s for s in ctx.module.get_structs()}
  for s in struct_map.values():
    assert isinstance(s, ast.Struct), s
    logging.vlog(2, 'Typechecking struct %s', s)
    ctx.fn_stack.append(('top', dict()))  # No symbolic bindings.
    check_top_node_in_module(s, ctx)
    logging.vlog(2, 'Finished typechecking struct: %s', s)

  typedef_map = {
      t.name.identifier: t
      for t in ctx.module.top
      if isinstance(t, ast.TypeDef)
  }
  for t in typedef_map.values():
    assert isinstance(t, ast.TypeDef), t
    logging.vlog(2, 'Typechecking typedef %s', t)
    ctx.fn_stack.append(('top', dict()))  # No symbolic bindings.
    check_top_node_in_module(t, ctx)
    logging.vlog(2, 'Finished typechecking typedef: %s', t)

  function_map = {f.name.identifier: f for f in ctx.module.get_functions()}
  for f in function_map.values():
    assert isinstance(f, ast.Function), f
    if f.is_parametric():
      # Let's typecheck parametric functions per invocation.
      continue

    logging.vlog(2, 'Typechecking function: %s', f)
    ctx.fn_stack.append(
        (f.name.identifier, dict()))  # No symbolic bindings.
    check_top_node_in_module(f, ctx)
    logging.vlog(2, 'Finished typechecking function: %s', f)

  test_map = {t.name.identifier: t for t in ctx.module.get_tests()}
  for t in test_map.values():
    assert isinstance(t, ast.Test), t

    if isinstance(t, ast.TestFunction):
      # New-style test constructs are specified using a function.
      # This function shouldn't be parametric and shouldn't take any arguments.
      if t.fn.params:
        raise PositionalError(
            "Test functions shouldn't take arguments.", t.fn.span)

      if t.fn.is_parametric():
        raise PositionalError(
            "Test functions shouldn't be parametric.", t.fn.span)

    # No symbolic bindings inside of a test.
    ctx.fn_stack.append(('{}_test'.format(t.name.identifier), dict()))
    logging.vlog(2, 'Typechecking test: %s', t)
    if isinstance(t, ast.TestFunction):
      # New-style tests are wrapped in a function.
      check_top_node_in_module(t.fn, ctx)
    else:
      # Old-style tests are specified in a construct with a body
      # (see check_test()).
      check_top_node_in_module(t, ctx)

    logging.vlog(2, 'Finished typechecking test: %s', t)

  return ctx.type_info