def map_size(t: ConcreteType, m: ast.Module,
             f: Callable[[Dim], Dim]) -> ConcreteType:
  """Rebuilds t with f applied to every dimension it contains.

  The mapping is applied transitively: element types of arrays, members of
  tuples and the parameter/return types of function types are all rebuilt.

  Args:
    t: The type whose dimensions should be mapped.
    m: The module; only checked to be an ast.Module and forwarded to the
      recursive calls.
    f: Callable applied to each dimension encountered.

  Returns:
    A newly constructed type of the same kind as t.

  Raises:
    NotImplementedError: If t is not an array, bits, tuple, enum or function
      type.
  """
  assert isinstance(m, ast.Module), m

  def recurse(sub_type: ConcreteType) -> ConcreteType:
    return map_size(sub_type, m=m, f=f)

  if isinstance(t, ArrayType):
    mapped_element = recurse(t.get_element_type())
    mapped_size = f(t.size)
    return ArrayType(mapped_element, mapped_size)

  if isinstance(t, BitsType):
    return BitsType(t.signed, f(t.size))

  if isinstance(t, TupleType):
    nominal = t.get_nominal_type()
    if t.named:
      named_members = tuple(
          (name, recurse(member_type)) for name, member_type in t.members)
      return TupleType(named_members, nominal)
    # Only named tuples are expected to carry a nominal type.
    assert nominal is None, nominal
    return TupleType(tuple(recurse(member) for member in t.members))

  if isinstance(t, EnumType):
    return EnumType(t.get_nominal_type(), f(t.size))

  if isinstance(t, FunctionType):
    new_params = tuple(recurse(param) for param in t.params)
    new_return = recurse(t.return_type)
    return FunctionType(new_params, new_return)

  raise NotImplementedError(t.__class__)
def _symbolic_bind_tuple(self, param_type: ConcreteType,
                         arg_type: ConcreteType) -> None:
  """Binds parametric symbols member-by-member for a tuple parameter.

  Args:
    param_type: The (tuple) parameter type.
    arg_type: The (tuple) argument type it is bound against.
  """
  assert isinstance(param_type, TupleType) and isinstance(arg_type, TupleType)
  member_pairs = zip(param_type.get_unnamed_members(),
                     arg_type.get_unnamed_members())
  # Members are bound pairwise, in positional order.
  for param_member, arg_member in member_pairs:
    self._symbolic_bind(param_member, arg_member)
def _symbolic_bind_array(self, param_type: ConcreteType,
                         arg_type: ConcreteType) -> None:
  """Binds parametric symbols for an array parameter.

  The element types are bound first, then the array types themselves are
  handed to _symbolic_bind_dims.

  Args:
    param_type: The (array) parameter type.
    arg_type: The (array) argument type it is bound against.
  """
  assert isinstance(param_type, ArrayType) and isinstance(arg_type, ArrayType)
  param_element = param_type.get_element_type()
  arg_element = arg_type.get_element_type()
  self._symbolic_bind(param_element, arg_element)
  self._symbolic_bind_dims(param_type, arg_type)
def _generate_unbiased_argument(concrete_type: ConcreteType,
                                rng: ast_generator.RngState) -> Value:
  """Generates a bits value for concrete_type without consulting prior values.

  Args:
    concrete_type: Type to generate a value for; only BitsType is supported.
    rng: Random number generator state forwarded to _generate_bit_value.

  Returns:
    The value produced by _generate_bit_value for the type's bit count and
    signedness.

  Raises:
    NotImplementedError: If concrete_type is not a BitsType.
  """
  if not isinstance(concrete_type, BitsType):
    raise NotImplementedError(
        'Generate argument for type: {}'.format(concrete_type))

  width = concrete_type.get_total_bit_count().value
  signedness = concrete_type.get_signedness()
  return _generate_bit_value(width, rng, signedness)
def concrete_type_accepts_value(module: ast.Module, type_: ConcreteType,
                                value: Value) -> bool:
  """Reports whether 'value' conforms to the concrete type 'type_'.

  Bits values must match a BitsType of the same signedness and bit count;
  array, tuple and enum values are delegated to _value_compatible_with_type.

  Args:
    module: Module forwarded to _value_compatible_with_type.
    type_: The type to check against.
    value: The value being checked.

  Returns:
    True if the value conforms to the type.

  Raises:
    NotImplementedError: If the value's tag is not one of the handled tags.
  """
  tag = value.tag

  if tag == Tag.UBITS:
    if not isinstance(type_, BitsType) or type_.signed:
      return False
    return value.bits_payload.bit_count == type_.get_total_bit_count()

  if tag == Tag.SBITS:
    if not isinstance(type_, BitsType) or not type_.signed:
      return False
    return value.bits_payload.bit_count == type_.get_total_bit_count()

  if tag in (Tag.ARRAY, Tag.TUPLE, Tag.ENUM):
    return _value_compatible_with_type(module, type_, value)

  raise NotImplementedError(type_, value)
def generate_argument(arg_type: ConcreteType, rng: ast_generator.RngState,
                      prior: Sequence[Value]) -> Value:
  """Generates a value whose shape matches the given concrete type.

  Tuples and arrays are generated element by element. For bits types the
  value is either generated from scratch or derived from a previously
  generated value by resizing it and flipping some of its bits.

  Args:
    arg_type: Type of the value to generate (tuple, array or bits).
    rng: Random number generator state.
    prior: Previously generated values that may be used as mutation seeds.

  Returns:
    The generated value.
  """
  if isinstance(arg_type, TupleType):
    members = [
        generate_argument(member_type, rng, prior)
        for member_type in arg_type.get_unnamed_members()
    ]
    return Value.make_tuple(tuple(members))

  if isinstance(arg_type, ArrayType):
    element_type = arg_type.get_element_type()
    elements = [
        generate_argument(element_type, rng, prior)
        for _ in range(arg_type.size.value)
    ]
    return Value.make_array(tuple(elements))

  assert isinstance(arg_type, BitsType)
  if not prior or rng.random() < 0.5:
    return _generate_unbiased_argument(arg_type, rng)

  # Attempt to mutate a prior argument; a non-bits pick falls back to an
  # unbiased value.
  candidate = prior[rng.randrange(len(prior))]
  if not candidate.is_bits():
    return _generate_unbiased_argument(arg_type, rng)

  seed_bits = candidate.get_bits()
  bit_count = arg_type.get_total_bit_count().value
  shortfall = bit_count - seed_bits.bit_count()
  if shortfall > 0:
    # The seed is too narrow: pad it with freshly generated bits.
    addendum = _generate_bit_value(shortfall, rng, signed=False)
    assert addendum.get_bit_count() + seed_bits.bit_count() == bit_count
    seed_bits = seed_bits.concat(addendum.get_bits())
  else:
    # The seed is wide enough: keep bit_count bits of it.
    seed_bits = seed_bits.slice(0, bit_count)
  assert seed_bits.bit_count() == bit_count, (seed_bits.bit_count(),
                                              bit_count)

  mutated = seed_bits.to_uint()
  flips = rng.randrange_biased_towards_zero(bit_count)
  for _ in range(flips):
    # Flip one randomly chosen bit.
    mutated ^= 1 << rng.randrange(bit_count)

  result_tag = Tag.SBITS if arg_type.get_signedness() else Tag.UBITS
  return Value.make_bits(
      result_tag, ir_bits.from_long(value=mutated, bit_count=bit_count))
def sign_convert_value(concrete_type: ConcreteType, value: Value) -> Value:
  """Re-tags bits inside 'value' so their signedness matches concrete_type.

  The value is walked in lockstep with the type: tuples and arrays are rebuilt
  from their converted elements, bits under a signed bits type are re-tagged
  as SBITS (the underlying bits are reused as-is), and bits under an unsigned
  bits type are returned unchanged.

  Examples:
    invocation: sign_convert_value(s8, u8:64)
    returns: s8:64

    invocation: sign_convert_value((s8, u8), (u8:42, u8:10))
    returns: (s8:42, u8:10)

  This conversion is needed because Values used in the DSLX may be signed
  while Values in IR interpretation and Verilog simulation are always
  unsigned.

  Args:
    concrete_type: ConcreteType to match.
    value: Input value.

  Returns:
    Sign-converted value.
  """
  if isinstance(concrete_type, concrete_type_mod.TupleType):
    assert value.is_tuple()
    members = value.get_elements()
    assert len(members) == concrete_type.get_tuple_length()
    member_types = concrete_type.get_unnamed_members()
    converted_members = [
        sign_convert_value(member_type, member)
        for member_type, member in zip(member_types, members)
    ]
    return Value.make_tuple(tuple(converted_members))

  if isinstance(concrete_type, concrete_type_mod.ArrayType):
    assert value.is_array()
    elements = value.get_elements()
    assert len(elements) == concrete_type.size
    element_type = concrete_type.get_element_type()
    converted_elements = [
        sign_convert_value(element_type, element) for element in elements
    ]
    return Value.make_array(tuple(converted_elements))

  if concrete_type_mod.is_sbits(concrete_type):
    return Value.make_bits(Tag.SBITS, value.get_bits())

  assert concrete_type_mod.is_ubits(concrete_type)
  return value
def _instantiate_one_arg(self, i: int, param_type: ConcreteType,
                         arg_type: ConcreteType) -> ConcreteType:
  """Binds one parameter type against its argument type and resolves it.

  Args:
    i: Index of the parameter (used in diagnostics and logging).
    param_type: Declared type of the parameter.
    arg_type: Type of the argument supplied for it.

  Returns:
    The result of resolving param_type after symbolic binding.

  Raises:
    XlsTypeError: If the parameter and argument are different kinds of type.
  """
  assert isinstance(param_type, ConcreteType), repr(param_type)
  assert isinstance(arg_type, ConcreteType), repr(arg_type)

  # The exact classes must agree; a subclass relationship is not enough.
  kinds_differ = type(param_type) != type(arg_type)  # pylint: disable=unidiomatic-typecheck
  if kinds_differ:
    message = ('Parameter {} and argument types are different kinds '
               '({} vs {}).').format(i, param_type.get_debug_type_name(),
                                     arg_type.get_debug_type_name())
    raise XlsTypeError(self.span, param_type, arg_type, suffix=message)

  logging.vlog(3, 'Symbolically binding param_type %d %s against arg_type %s',
               i, param_type, arg_type)
  self._symbolic_bind(param_type, arg_type)

  resolved_type = self._resolve(param_type)
  logging.vlog(3, 'Resolved param_type: %s', resolved_type)
  return resolved_type
def generate_argument(arg_type: ConcreteType, rng: Random,
                      prior: Sequence[Value]) -> Value:
  """Generates an argument value of the same type as the concrete type.

  Tuples and arrays are generated element by element. For bits types the
  value is either generated from scratch or derived from a randomly chosen
  prior value by resizing it to the requested width and flipping some bits.

  Args:
    arg_type: Type of the value to generate (tuple, array or bits).
    rng: Random number generator.
    prior: Previously generated values that may be used as mutation seeds.
      These may include non-bits (e.g. tuple/array) values.

  Returns:
    The generated value.
  """
  if isinstance(arg_type, TupleType):
    return Value.make_tuple(
        tuple(
            generate_argument(t, rng, prior)
            for t in arg_type.get_unnamed_members()))
  elif isinstance(arg_type, ArrayType):
    return Value.make_array(
        tuple(
            generate_argument(arg_type.get_element_type(), rng, prior)
            for _ in range(arg_type.size.value)))
  else:
    assert isinstance(arg_type, BitsType)
    if not prior or rng.random() < 0.5:
      return _generate_unbiased_argument(arg_type, rng)

    # Try to mutate a prior argument. The chosen value may be a tuple or an
    # array, which has no bits payload to mutate; in that case just generate
    # an unbiased argument instead of failing on the payload access below.
    candidate = rng.choice(prior)
    if not candidate.is_bits():
      return _generate_unbiased_argument(arg_type, rng)

    bit_count = arg_type.get_total_bit_count().value
    if bit_count > candidate.get_bit_count():
      # Seed is too narrow: extend it with freshly generated unsigned bits.
      to_mutate = candidate.bits_payload.concat(
          _generate_bit_value(
              bit_count - candidate.get_bit_count(), rng,
              signed=False).bits_payload)
    else:
      # Seed is wide enough: keep bit_count bits of it.
      to_mutate = candidate.bits_payload.slice(0, bit_count, lsb_is_0=False)
    assert to_mutate.bit_count == bit_count

    value = to_mutate.value
    mutation_count = randrange_biased_towards_zero(bit_count, rng)
    for _ in range(mutation_count):
      # Pick a random bit and flip it.
      bitno = rng.randrange(bit_count)
      value ^= 1 << bitno

    signed = arg_type.get_signedness()
    constructor = Value.make_sbits if signed else Value.make_ubits
    return constructor(value=value, bit_count=bit_count)
def ir_value_to_interpreter_value(value: ir_value.Value,
                                  dslx_type: ConcreteType) -> dslx_value.Value:
  """Translates an IR value into the interpreter's value representation.

  The IR value is walked together with the DSLX type, which supplies the
  signedness for bits values.

  Args:
    value: The IR value to translate (bits, array or tuple).
    dslx_type: The DSLX type describing the value; its kind must correspond
      to the kind of the IR value.

  Returns:
    The corresponding interpreter value.
  """
  make = dslx_value.Value

  if value.is_bits():
    assert isinstance(dslx_type, BitsType), dslx_type
    payload = value.get_bits()
    if dslx_type.get_signedness():
      return make.make_sbits(payload)
    return make.make_ubits(payload)

  if value.is_array():
    assert isinstance(dslx_type, ArrayType), dslx_type
    converted_elements = [
        ir_value_to_interpreter_value(element, dslx_type.element_type)
        for element in value.get_elements()
    ]
    return make.make_array(tuple(converted_elements))

  assert value.is_tuple()
  assert isinstance(dslx_type, TupleType), dslx_type
  element_type_pairs = zip(value.get_elements(),
                           dslx_type.get_unnamed_members())
  converted_members = [
      ir_value_to_interpreter_value(element, member_type)
      for element, member_type in element_type_pairs
  ]
  return make.make_tuple(tuple(converted_members))
def _symbolic_bind(self, param_type: ConcreteType,
                   arg_type: ConcreteType) -> None:
  """Dispatches symbolic binding according to the kind of param_type.

  Args:
    param_type: The parameter type, possibly containing parametric symbols.
    arg_type: The argument type providing the values to bind.

  Raises:
    XlsTypeError: If both are tuples but their nominal types differ.
    NotImplementedError: If param_type is of an unhandled kind.
  """
  assert isinstance(param_type, ConcreteType), repr(param_type)
  assert isinstance(arg_type, ConcreteType), repr(arg_type)

  if isinstance(param_type, BitsType):
    self._symbolic_bind_bits(param_type, arg_type)
    return

  if isinstance(param_type, EnumType):
    param_enum = param_type.get_nominal_type(self.ctx.module)
    arg_enum = arg_type.get_nominal_type(self.ctx.module)
    assert param_enum == arg_enum
    # With identical enums this is handled exactly like bits: the primitive
    # is ignored and only the dims are symbolically bound.
    self._symbolic_bind_bits(param_type, arg_type)
    return

  if isinstance(param_type, TupleType):
    param_nominal = param_type.get_nominal_type(self.ctx.module)
    arg_nominal = arg_type.get_nominal_type(self.ctx.module)
    logging.vlog(3, 'param nominal %s arg nominal %s', param_nominal,
                 arg_nominal)
    if param_nominal != arg_nominal:
      param_name = (
          repr(param_nominal.identifier) if param_nominal else '<none>')
      arg_name = repr(arg_nominal.identifier) if arg_nominal else '<none>'
      raise XlsTypeError(
          self.span,
          param_type,
          arg_type,
          suffix='parameter type name: {}; argument type name: {}.'.format(
              param_name, arg_name))
    self._symbolic_bind_tuple(param_type, arg_type)
    return

  if isinstance(param_type, ArrayType):
    self._symbolic_bind_array(param_type, arg_type)
    return

  if isinstance(param_type, FunctionType):
    self._symbolic_bind_function(param_type, arg_type)
    return

  raise NotImplementedError('Bind symbols in parameter type {} @ {}'.format(
      param_type, self.span))
def _value_compatible_with_type(module: ast.Module, type_: ConcreteType,
                                value: Value) -> bool:
  """Returns whether value is compatible with type_ (recursively).

  Args:
    module: Module used to look up the nominal type of enum types.
    type_: The type to check against.
    value: The value being checked.

  Returns:
    True if the value is compatible with the type, False otherwise.

  Raises:
    NotImplementedError: If the (type, value) combination is not one of the
      handled pairings.
  """
  assert isinstance(value, Value), value

  if isinstance(type_, TupleType) and value.is_tuple():
    member_types = tuple(type_.get_unnamed_members())
    members = tuple(value.tuple_members)
    # zip() silently stops at the shorter sequence, so without this check a
    # tuple with the wrong number of members would be reported as compatible.
    if len(member_types) != len(members):
      return False
    return all(
        _value_compatible_with_type(module, ct, m)
        for ct, m in zip(member_types, members))

  if isinstance(type_, ArrayType) and value.is_array():
    # NOTE(review): only the elements are checked here; the number of
    # elements is not compared against type_.size -- confirm whether callers
    # rely on that.
    et = type_.get_element_type()
    return all(
        _value_compatible_with_type(module, et, m)
        for m in value.array_payload.elements)

  if isinstance(type_, EnumType) and value.tag == Tag.ENUM:
    return type_.get_nominal_type(module) == value.type_

  if isinstance(type_,
                BitsType) and not type_.signed and value.tag == Tag.UBITS:
    return value.bits_payload.bit_count == type_.get_total_bit_count()

  if isinstance(type_, BitsType) and type_.signed and value.tag == Tag.SBITS:
    return value.bits_payload.bit_count == type_.get_total_bit_count()

  # An enum value is compatible with a bits type of the same signedness and
  # width.
  if value.tag == Tag.ENUM and isinstance(type_, BitsType):
    return (value.type_.get_signedness() == type_.signed and
            value.bits_payload.bit_count == type_.get_total_bit_count())

  # An array value is compatible with an unsigned bits type whose width
  # equals the array's flattened bit count.
  if value.tag == Tag.ARRAY and is_ubits(type_):
    flat_bit_count = value.array_payload.flatten().bits_payload.bit_count
    return flat_bit_count == type_.get_total_bit_count()

  # A bits value is compatible with an enum type of the same signedness and
  # width.
  if isinstance(type_, EnumType) and value.is_bits():
    return (type_.signed == (value.tag == Tag.SBITS) and
            type_.get_total_bit_count() == value.get_bit_count())

  raise NotImplementedError(type_, value)
def concrete_type_convert_value(module: ast.Module, type_: ConcreteType,
                                value: Value, span: Span,
                                enum_values: Optional[Tuple[Value, ...]],
                                enum_signed: Optional[bool]) -> Value:
  """Casts 'value' to the concrete type 'type_'.

  Handled conversions, in the order they are tried:
    * unsigned bits -> array: the bits are sliced into element-sized chunks.
    * bits/enum -> enum of equal width: the bits must equal one of
      enum_values.
    * enum -> bits of equal width: the payload is re-wrapped.
    * unsigned bits -> bits: masked to the target width.
    * signed bits -> bits: sign extended to the target width.
    * enum -> bits of different width: extended according to enum_signed.
    * array -> bits: the array payload is flattened.
    * anything the type already accepts is returned unchanged.

  Args:
    module: Module used for nominal type lookup and acceptance checks.
    type_: Target type of the conversion.
    value: Value to convert.
    span: Source span reported on failure.
    enum_values: Values of the target enum; iterated when converting to an
      enum type.
    enum_signed: Signedness of the source enum; must not be None when an enum
      value is extended to a bits type.

  Returns:
    The converted value.

  Raises:
    FailureError: If the bits are not a member of the target enum, or if no
      conversion applies.
  """
  logging.vlog(3, 'Converting value %s to type %s', value, type_)
  tag = value.tag

  if tag == Tag.UBITS and isinstance(type_, ArrayType):
    element_width = type_.get_element_type().get_total_bit_count().value
    payload = value.bits_payload
    elements = []
    for index in range(type_.size.value):
      start = index * element_width
      limit = start + element_width
      elements.append(
          Value(Tag.UBITS, payload.slice(start, limit, lsb_is_0=False)))
    return Value.make_array(tuple(elements))

  if (isinstance(type_, EnumType) and
      tag in (Tag.UBITS, Tag.SBITS, Tag.ENUM) and
      value.get_bit_count() == type_.get_total_bit_count()):
    # The bits being converted must correspond to a member of the target
    # enum.
    nominal_type = type_.get_nominal_type(module)
    is_member = any(
        value.bits_payload == member.bits_payload for member in enum_values)
    if not is_member:
      raise FailureError(
          span, 'Value is not valid for enum {}: {}'.format(
              nominal_type.identifier, value))
    return Value.make_enum(value.bits_payload, nominal_type)

  if (tag == Tag.ENUM and isinstance(type_, BitsType) and
      type_.get_total_bit_count() == value.get_bit_count()):
    make_bits = Value.make_sbits if type_.signed else Value.make_ubits
    width = type_.get_total_bit_count().value
    return make_bits(width, value.bits_payload.value)

  def extend(sign_extend: Optional[bool]) -> Value:
    """Builds a bits value of the target width by zero or sign extension."""
    assert isinstance(type_, BitsType)
    make_bits = Value.make_sbits if type_.signed else Value.make_ubits
    width = type_.get_total_bit_count().value
    if sign_extend:
      logging.vlog(3, 'Sign extending %s to %s', value, width)
      return make_bits(width, value.bits_payload.sign_ext(width).value)
    return make_bits(width,
                     value.get_bits_value() & bit_helpers.to_mask(width))

  if tag == Tag.UBITS:
    return extend(False)

  if tag == Tag.SBITS:
    return extend(True)

  if tag == Tag.ENUM:
    assert enum_signed is not None
    return extend(enum_signed)

  # Converting an array into bits flattens the array payload.
  if tag == Tag.ARRAY and isinstance(type_, BitsType):
    return value.array_payload.flatten()

  if concrete_type_accepts_value(module, type_, value):
    # Vacuous conversion.
    return value

  raise FailureError(
      span,
      'Interpreter failure: cannot convert value %s (of type %s) to type %s' %
      (value, concrete_type_from_value(value), type_))