Пример #1
0
  def _visit_map(self, node: ast.Invocation) -> BValue:
    """Lowers a `map(array, fn)` invocation to an IR map operation.

    Every argument except the trailing function reference is converted
    first. The mapped function may be a parametric builtin (delegated to
    `_def_map_with_builtin`), a function in this module (`ast.NameRef`), or a
    function in an imported module (`ast.ModRef`).

    Raises:
      NotImplementedError: If the function argument is any other node kind.
    """
    for operand in node.args[:-1]:
      visit(operand, self)
    mapped_array = self._use(node.args[0])
    callee = node.args[1]

    if isinstance(callee, ast.NameRef):
      target_name = callee.name_def.identifier
      if target_name in dslx_builtins.PARAMETRIC_BUILTIN_NAMES:
        builtin_bindings = self._get_invocation_bindings(node)
        return self._def_map_with_builtin(node, callee, node.args[0],
                                          builtin_bindings)
      lookup_module = self.module
    elif isinstance(callee, ast.ModRef):
      target_name = callee.value
      lookup_module, _ = self.type_info.get_imports()[callee.mod]
    else:
      raise NotImplementedError(
          'Unhandled function mapping: {!r}'.format(callee))
    fn = lookup_module.get_function(target_name)

    # The IR function is found in the package by its mangled name, which folds
    # in this invocation's parametric bindings.
    invocation_bindings = self._get_invocation_bindings(node)
    mangled_name = mangle_dslx_name(fn.name, fn.get_free_parametric_keys(),
                                    lookup_module, invocation_bindings)
    ir_fn = self.package.get_function(mangled_name)
    return self._def(node, self.fb.add_map, mapped_array, ir_fn)
Пример #2
0
 def _visit_matcher(self, matcher: ast.NameDefTree, index: Tuple[int, ...],
                    matched_value: BValue,
                    matched_type: ConcreteType) -> BValue:
   """Emits IR that tests whether `matched_value` matches `matcher`.

   Tuple patterns recurse element-wise and AND the per-element results
   together. Leaf patterns are handled directly. A name-binding leaf also
   records `matched_value` in `self.node_to_ir` as a side effect.

   Args:
     matcher: Pattern tree to match against.
     index: Position of this pattern; extended by one element per tuple
       level on recursion.
     matched_value: IR value being matched.
     matched_type: Concrete type of `matched_value`.

   Returns:
     A 1-bit IR value that is true when the pattern matches.
   """
   if not matcher.is_leaf():
     # Start from a literal `1` and AND in each element's match result.
     result = self.fb.add_literal_bits(bits_mod.UBits(value=1, bit_count=1))
     pairs = zip(matcher.tree, matched_type.get_unnamed_members())  # pytype: disable=attribute-error
     for position, (sub_matcher, sub_type) in enumerate(pairs):
       element = self.fb.add_tuple_index(matched_value, position)
       sub_result = self._visit_matcher(sub_matcher, index + (position,),
                                        element, sub_type)
       result = self.fb.add_and(result, sub_result)
     return result

   leaf = matcher.get_leaf()
   logging.vlog(5, 'Matcher is leaf: %s (%s)', leaf, leaf.__class__.__name__)

   if isinstance(leaf, ast.WildcardPattern):
     # A wildcard always matches.
     return self._def(matcher, self.fb.add_literal_bits,
                      bits_mod.UBits(1, 1))

   if isinstance(leaf, (ast.Number, ast.EnumRef)):
     # Compare against the converted literal / enum value.
     visit(leaf, self)
     return self._def(matcher, self.fb.add_eq, self._use(leaf),
                      matched_value)

   if isinstance(leaf, ast.NameRef):
     # Compare against the value already bound to the referenced name.
     outcome = self._def(matcher, self.fb.add_eq, self._use(leaf.name_def),
                         matched_value)
     self._def_alias(leaf.name_def, to=leaf)
     return outcome

   assert isinstance(
       leaf, ast.NameDef
   ), 'Expected leaf to be wildcard, number, or name; got: {!r}'.format(
       leaf)
   # A fresh name always matches and binds the matched value.
   always = self._def(leaf, self.fb.add_literal_bits, bits_mod.UBits(1, 1))
   self.node_to_ir[matcher] = matched_value
   self.node_to_ir[leaf] = matched_value
   return always
Пример #3
0
 def query_const_range_call() -> int:
   """Returns the trip count of a `for ... in range(0, CONST)` construct.

   Raises:
     ConversionError: If the loop iterable is not a two-argument `range`
       call whose first argument is the constant zero and whose second
       argument is a constant.
   """
   range_callee = (
       isinstance(node.iterable, ast.Invocation) and
       isinstance(node.iterable.callee, ast.NameRef) and
       node.iterable.callee.identifier == 'range')
   if not range_callee:
     raise ConversionError(
         'For-loop is of an unsupported form for IR conversion; only a '
         "'range(0, const)' call is supported, found non-range callee.",
         node.span)
   if len(node.iterable.args) != 2:
     raise ConversionError(
         'For-loop is of an unsupported form for IR conversion; only a '
         "'range(0, const)' call is supported, found inappropriate number "
         'of arguments.', node.span)
   # The start bound must be the constant zero; report that specifically
   # rather than reusing the argument-count message.
   if not self._is_constant_zero(node.iterable.args[0]):
     raise ConversionError(
         'For-loop is of an unsupported form for IR conversion; only a '
         "'range(0, const)' call is supported, found a start argument that "
         'is not the constant zero.', node.span)
   arg = node.iterable.args[1]
   visit(arg, self)
   if not self._is_const(arg):
     raise ConversionError(
         'For-loop is of an unsupported form for IR conversion; only a '
         "'range(0, const)' call is supported, did not find a const value "
         f'for {arg} ({arg!r}).', node.span)
   return self._get_const(arg)
Пример #4
0
 def test_simple_number(self):
     """A lone Number has no children and is the only node visited."""
     module = ast.Module('test')
     number = ast.Number(module, self.fake_span, '42')
     self.assertEmpty(number.children)

     visitor = Collector()
     cpp_ast_visitor.visit(number, visitor)
     self.assertEqual(visitor.visited, [number])
Пример #5
0
 def visit_StructInstance(self, node: ast.StructInstance) -> None:
   """Lowers a struct instance expression to an IR tuple.

   Members are converted one at a time, in the order given by
   `node.get_ordered_members(struct)`, and become the tuple elements.
   """
   struct = self._deref_struct(node.struct)

   def lower_member(expr):
     # Convert, then fetch the IR value, before moving to the next member.
     visit(expr, self)
     return self._use(expr)

   elements = tuple(
       lower_member(expr) for _, expr in node.get_ordered_members(struct))
   self._def(node, self.fb.add_tuple, elements)
Пример #6
0
 def test_array_of_numbers(self):
     """An Array's elements are visited before the Array node itself."""
     module = ast.Module('test')
     first, second = (ast.Number(module, self.fake_span, text)
                      for text in ('42', '64'))
     array = ast.Array(module, self.fake_span, [first, second], False)

     visitor = Collector()
     cpp_ast_visitor.visit(array, visitor)
     self.assertEqual(visitor.visited, [first, second, array])
Пример #7
0
    def test_visit_unop(self):
        """A unary op's operand is collected before the op itself."""
        span = self.fake_span
        i_def = ast.NameDef(self.m, span, 'i')
        i_ref = ast.NameRef(self.m, span, 'i', i_def)
        neg = ast.Unop(self.m, span, ast.UnopKind.NEG, i_ref)

        collector = _Collector()
        visit(neg, collector)
        self.assertEqual(collector.collected, [i_ref, neg])
Пример #8
0
 def visit_Array(self, node: ast.Array) -> None:
   """Lowers an array literal to an IR array.

   When the literal has a trailing ellipsis, its last element is repeated
   until the element count reaches the size of the resolved array type. The
   element type is taken from the first converted element.
   """
   array_type = self._resolve_type(node)
   elements = []
   for expr in node.members:
     visit(expr, self)
     elements.append(self._use(expr))
   if node.has_ellipsis:
     target_size = array_type.size  # pytype: disable=attribute-error
     while len(elements) < target_size:
       elements.append(elements[-1])
   element_type = elements[0].get_type()
   self._def(node, self.fb.add_array, elements, element_type)
Пример #9
0
    def test_visit_index(self):
        """For `t[i]`, the indexed value and index precede the Index node."""
        span = self.fake_span
        module = self.m

        def name_ref(identifier):
            return ast.NameRef(module, span, identifier,
                               ast.NameDef(module, span, identifier))

        # Build a `t[i]` index node.
        t = name_ref('t')
        i = name_ref('i')
        index = ast.Index(module, span, t, i)

        collector = _Collector()
        visit(index, collector)
        self.assertEqual(collector.collected, [t, i, index])
Пример #10
0
 def test_visit_match_multi_pattern(self):
     """The shared Number is collected once per pattern and once for the body."""
     module = self.m
     span = Span(self.fake_pos, self.fake_pos)
     number = ast.Number(module, span, '0xf00')
     patterns = (ast.NameDefTree(module, span, number),
                 ast.NameDefTree(module, span, number))
     arm = ast.MatchArm(module, span, patterns=patterns, expr=number)

     collector = _Collector()
     visit(arm, collector)
     self.assertEqual(collector.collected, [number] * 3)
Пример #11
0
    def test_visit_type(self):
        """A `uN[5]` annotation's dimension is collected before the type."""
        five = self.five
        bits_token = Token(value=Keyword.BITS, span=self.fake_span)
        # Build a uN[5] array type annotation.
        t = ast_helpers.make_builtin_type_annotation(
            self.m, self.fake_span, bits_token, dims=(five,))
        assert isinstance(t, ast.ArrayTypeAnnotation), t

        collector = _Collector()
        visit(t, collector)
        self.assertEqual(collector.collected, [five, t])
Пример #12
0
  def visit_SplatStructInstance(self, node: ast.SplatStructInstance) -> None:
    """Lowers a struct instance built by splatting another struct value.

    Members given explicitly take their new values. Every other member is
    copied out of the splatted value with a tuple index at that member's
    position in `struct.member_names`.
    """
    visit(node.splatted, self)
    base = self._use(node.splatted)

    overrides = {}
    for member_name, expr in node.members:
      visit(expr, self)
      overrides[member_name] = self._use(expr)

    struct = self._deref_struct(node.struct)
    fields = [
        overrides[member_name] if member_name in overrides else
        self.fb.add_tuple_index(base, position)
        for position, member_name in enumerate(struct.member_names)
    ]
    self._def(node, self.fb.add_tuple, fields)
Пример #13
0
 def visit_Cast(self, node: ast.Cast) -> None:
   """Lowers a cast expression to IR.

   Array outputs are delegated to `_cast_to_array` and array inputs to
   `_cast_from_array`. Otherwise the operand is cut down with a bit slice
   starting at bit 0 when the output is narrower. When the output is the same
   width or wider, the operand is sign- or zero-extended according to the
   input type's signedness.

   Raises:
     NotImplementedError: If the output type is not an array, bits, or enum
       type.
   """
   visit(node.expr, self)
   output_type = self._resolve_type(node)
   if isinstance(output_type, ArrayType):
     return self._cast_to_array(node, output_type)
   if not isinstance(output_type, (BitsType, EnumType)):
     raise NotImplementedError(
         'Cast can only handle bits output types; got: '
         f'{output_type} @ {node.span} ({output_type!r})')
   # Resolve the operand's type once; it is needed both for the array check
   # and for the width comparison below.
   input_type = self._resolve_type(node.expr)
   if isinstance(input_type, ArrayType):
     return self._cast_from_array(node, output_type)
   new_bit_count = output_type.get_total_bit_count()
   if new_bit_count < input_type.get_total_bit_count():
     # Narrowing: take `new_bit_count` bits starting at bit 0.
     self._def(node, self.fb.add_bit_slice, self._use(node.expr), 0,
               new_bit_count)
   else:
     # Same width or widening: extend according to the input's signedness.
     signed_input = input_type.get_signedness()
     f = self.fb.add_signext if signed_input else self.fb.add_zeroext
     self._def(node, f, self._use(node.expr), new_bit_count)
Пример #14
0
  def visit_Match(self, node: ast.Match):
    """Lowers a `match` expression to an IR `match_true` select.

    The last arm's (single) pattern must be irrefutable; that arm supplies
    the default value. Every earlier arm contributes one selector, formed by
    OR-ing its pattern matches when there are several, and one value.

    Raises:
      ConversionError: If there are no arms, or the last arm's first pattern
        is refutable.
    """
    if not (node.arms and node.arms[-1].patterns[0].is_irrefutable()):
      raise ConversionError(
          'Only matches with trailing irrefutable patterns are currently handled.',
          node.span)

    visit(node.matched, self)
    matched = self._use(node.matched)
    matched_type = self._resolve_type(node.matched)

    # Convert the default (last) arm's pattern and expression first.
    default_arm = node.arms[-1]
    assert len(default_arm.patterns) == 1, (
        'Multiple patterns in default arm is not yet implemented for IR '
        'conversion.')
    default_index = len(node.arms) - 1
    self._visit_matcher(default_arm.patterns[0], (default_index,), matched,
                        matched_type)
    visit(default_arm.expr, self)

    selectors = []
    values = []
    for arm_index, arm in enumerate(node.arms[:-1]):
      pattern_hits = [
          self._visit_matcher(pattern, (arm_index,), matched, matched_type)
          for pattern in arm.patterns
      ]
      # An arm with several patterns is selected if any of them matches.
      selectors.append(
          self.fb.add_nary_or(pattern_hits)
          if len(pattern_hits) > 1 else pattern_hits[0])
      visit(arm.expr, self)
      values.append(self._use(arm.expr))

    # For example:
    #   match x {
    #     42  => blah
    #     64  => snarf
    #     128 => yep
    #     _   => burp
    #   }
    # yields
    #   selectors:     [x==42, x==64, x==128]
    #   values:        [blah,  snarf,    yep]
    #   default_value: burp
    default_value = self._use(default_arm.expr)
    self.node_to_ir[node] = self.fb.add_match_true(selectors, values,
                                                   default_value)
    self.last_expression = node
Пример #15
0
  def _visit_Function(
      self, node: ast.Function,
      symbolic_bindings: Optional[SymbolicBindings]) -> ir_function.Function:
    """Converts a DSLX function definition into a verified IR function.

    The IR function is named by `mangle_dslx_name` applied to the function
    name, its free parametric keys, the module, and `symbolic_bindings`.

    Args:
      node: The function definition to convert.
      symbolic_bindings: Parametric bindings for this instantiation, or None.
        Stored as a dict on `self.symbolic_bindings`; every parametric binding
        of `node` is looked up there by identifier.

    Returns:
      The built IR function, after `verifier_mod.verify_function` has been
      run on it.
    """
    self.symbolic_bindings = {} if symbolic_bindings is None else dict(
        symbolic_bindings)
    self._extract_module_level_constants(self.module)
    # We use a function builder for the duration of converting this
    # ast.Function. When it's done being built, we drop the reference to it (by
    # setting self.fb to None).
    self.fb = function_builder.FunctionBuilder(
        mangle_dslx_name(node.name.identifier, node.get_free_parametric_keys(),
                         self.module, symbolic_bindings), self.package)
    try:
      for param in node.params:
        visit(param, self)

      # Materialize each parametric binding as an IR constant whose width is
      # the bit count of the binding's declared type, then alias the binding's
      # name to that constant.
      for parametric_binding in node.parametric_bindings:
        logging.vlog(4, 'Resolving parametric binding %s', parametric_binding)

        sb_value = self.symbolic_bindings[parametric_binding.name.identifier]
        value = self._resolve_dim(sb_value)
        assert isinstance(value, int), \
            'Expect integral parametric binding; got {!r}'.format(value)
        self._def_const(
            parametric_binding, value,
            self._resolve_type(parametric_binding.type_).get_total_bit_count())
        self._def_alias(parametric_binding, to=parametric_binding.name)

      # Convert pending constant dependencies, then clear the list in place.
      # NOTE(review): presumably populated by _extract_module_level_constants;
      # confirm.
      for dep in self._constant_deps:
        visit(dep, self)
      del self._constant_deps[:]

      visit(node.body, self)

      # If the result expression is a bare name reference, emit an explicit
      # identity node for it. NOTE(review): presumably so the builder's final
      # node is the function's result rather than an alias of an earlier
      # value; confirm against FunctionBuilder.build.
      last_expression = self.last_expression or node.body
      if isinstance(last_expression, ast.NameRef):
        self._def(last_expression, self.fb.add_identity,
                  self._use(last_expression))
      f = self.fb.build()
      logging.vlog(3, 'Built function: %s', f.name)
      verifier_mod.verify_function(f)
      return f
    finally:
      self.fb = None
Пример #16
0
  def visit_Index(self, node: ast.Index) -> None:
    """Lowers an index expression `lhs[index]` based on the type of `lhs`.

    * Tuples: `tuple_index` with the index resolved to a constant.
    * Bits: a width slice is delegated to `_visit_width_slice`; a plain slice
      becomes a `bit_slice` whose start and width come from type info.
    * Anything else: `array_index` with the converted index value.
    """
    visit(node.lhs, self)
    lhs_type = self.type_info[node.lhs]

    if isinstance(lhs_type, TupleType):
      visit(node.index, self)
      tuple_value = self._use(node.lhs)
      member_index = self._get_const(node.index)
      self._def(node, self.fb.add_tuple_index, tuple_value, member_index)
      return None

    if not isinstance(lhs_type, BitsType):
      visit(node.index, self)
      array_value = self._use(node.lhs)
      index_value = self._use(node.index)
      self._def(node, self.fb.add_array_index, array_value, index_value)
      return None

    slice_node = node.index
    if isinstance(slice_node, ast.WidthSlice):
      return self._visit_width_slice(node, slice_node, lhs_type)
    assert isinstance(slice_node, ast.Slice), slice_node
    start, width = self.type_info.get_slice_start_width(
        slice_node, self._get_symbolic_bindings_tuple())
    self._def(node, self.fb.add_bit_slice, self._use(node.lhs), start, width)
    return None
Пример #17
0
  def visit_Let(self, node: ast.Let):
    """Lowers a `let` binding followed by its body expression.

    The right-hand side is converted first. A single-name binding aliases the
    name to the RHS value. A destructuring binding emits one `tuple_index`
    operation per node of the name-def tree to extract each bound value. In
    both cases the body is then converted and the `let` node is aliased to
    the body's value.
    """
    visit(node.rhs, self)
    if node.name_def_tree.is_leaf():
      self._def_alias(node.rhs, to=node.name_def_tree.get_leaf())
      visit(node.body, self)
      self._def_alias(node.body, node)
    else:
      # Walk the tree performing tuple_index operations to get to the binding
      # levels desired.

      # Stack of IR values by tree depth: names[k] holds the value for the most
      # recently visited subtree at depth k; names[0] is the RHS value.
      names = [self._use(node.rhs)]  # List[BValue]

      def walk(x: ast.NameDefTree, level: int, index: int) -> None:
        """Invoked at each level of the name def tree.

        Binds the name in the name def tree to the corresponding value being
        pattern matched.

        Args:
          x: The current level of the NameDefTree.
          level: Level in the NameDefTree (root is 0).
          index: Index of node in the current tree level (e.g. leftmost is 0).
        """
        # Truncate to this depth so names[-1] is the parent's value, then push
        # this subtree's value extracted from it. NOTE(review): assumes
        # do_preorder only calls walk with level >= 1 (otherwise names[-1]
        # would fail on an empty list); confirm.
        del names[level:]
        names.append(
            self._def(
                x,
                self.fb.add_tuple_index,
                names[-1],
                index,
                span=(x.get_leaf().span if x.is_leaf() else x.span)))
        if x.is_leaf():
          self._def_alias(x, x.get_leaf())

      ast_helpers.do_preorder(node.name_def_tree, walk)
      visit(node.body, self)
      self._def_alias(node.body, to=node)

    # Record the body as the last expression only if none is recorded yet.
    if self.last_expression is None:
      self.last_expression = node.body
Пример #18
0
 def visit_EnumRef(self, node: ast.EnumRef) -> None:
   """Lowers an enum member reference to the IR value of the member."""
   member_value = self._deref_enum(node.enum).get_value(node.value)
   visit(member_value, self)
   # The reference shares the IR value of the member's value expression.
   self._def_alias(from_=member_value, to=node)
Пример #19
0
 def visit_XlsTuple(self, node: ast.XlsTuple) -> None:
   """Lowers a tuple expression to an IR tuple of its converted members."""
   members = node.members
   # Convert every member first, then gather their IR values in order.
   for member in members:
     visit(member, self)
   self._def(node, self.fb.add_tuple, tuple(map(self._use, members)))
Пример #20
0
def get_callees(func: Union[ast.Function, ast.Test], m: ast.Module,
                type_info: type_info_mod.TypeInfo,
                imports: Dict[ast.Import, ImportedInfo],
                bindings: SymbolicBindings) -> Tuple[Callee, ...]:
    """Traverses the definition of `func` to find callees.

    Args:
      func: Function/test construct to inspect for calls.
      m: Module that `func` resides in.
      type_info: Node to type mapping that should be used with `func`.
      imports: Mapping of modules imported by `m`.
      bindings: Bindings used in instantiation of `func`.

    Returns:
      Callee functions invoked by `func`, and the parametric bindings used in
      each of those invocations. Calls to parametric builtins (made directly,
      or as the function mapped by `map`) are omitted.

    Raises:
      KeyError: If a called (or mapped) non-builtin function cannot be found
        in its module.
      NotImplementedError: If an invocation's callee is neither an
        `ast.NameRef` nor an `ast.ColonRef`.
    """
    assert isinstance(bindings, SymbolicBindings), bindings
    callees = []

    class InvocationVisitor(cpp_ast_visitor.AstVisitor):
        """Visits invocation nodes to build up the callees list."""

        @cpp_ast_visitor.AstVisitor.no_auto_traverse
        def visit_ParametricBinding(self, node: ast.ParametricBinding) -> None:
            # Intentionally a no-op. NOTE(review): the decorator presumably
            # stops the visitor from descending into parametric bindings.
            pass

        def visit_Invocation(self, node: ast.Invocation) -> None:
            if isinstance(node.callee, ast.ColonRef):
                # The callee's subject names an import; look the function up
                # in that imported module.
                this_m, _ = imports[node.callee.subject.name_def.definer]
                f = this_m.get_function(node.callee.attr)
            elif isinstance(node.callee, ast.NameRef):
                this_m = m
                fn_identifier = node.callee.identifier
                if fn_identifier == 'map':
                    # We need to make sure we convert the mapped function!
                    fn_node = node.args[1]
                    if isinstance(fn_node, ast.ColonRef):
                        fn_identifier = fn_node.attr
                        import_node = fn_node.subject.name_def.definer
                        this_m = imports[import_node][0]
                    else:
                        fn_identifier = fn_node.name_def.identifier
                try:
                    f = this_m.get_function(fn_identifier)
                except KeyError:
                    # Parametric builtins have no definition to convert. Test
                    # the name that was actually looked up (for `map`, the
                    # mapped function) so that a genuinely missing
                    # user-defined function is reported instead of skipped.
                    if fn_identifier in dslx_builtins.PARAMETRIC_BUILTIN_NAMES:
                        return
                    raise
            else:
                raise NotImplementedError(
                    'Only calls to named functions are currently supported, got callee: {!r}'
                    .format(node.callee))

            node_symbolic_bindings = type_info.get_invocation_symbolic_bindings(
                node, bindings)

            # Either use the global type_info or the child type_info
            # chained off of this invocation.
            try:
                invocation_type_info = type_info.get_instantiation(
                    node, node_symbolic_bindings)
            except KeyError:
                invocation_type_info = type_info
            assert invocation_type_info is not None
            assert isinstance(node_symbolic_bindings,
                              SymbolicBindings), node_symbolic_bindings
            callees.append(
                Callee(f, this_m, invocation_type_info,
                       node_symbolic_bindings))

    cpp_ast_visitor.visit(func, InvocationVisitor())
    logging.vlog(3, 'Callees for %s: %s', func,
                 [(cr.m.name, cr.f.identifier, cr.sym_bindings)
                  for cr in callees])
    return tuple(callees)
Пример #21
0
  def visit_For(self, node: ast.For) -> None:
    """Lowers a counted `for` loop to an IR `counted_for` node.

    The loop body becomes a separate IR function, built by a nested
    `_IrConverterFb`. Its parameters are, in order:
      1. the induction variable;
      2. the loop carry;
      3. each non-function free variable of the body.
    The free variables' current values are passed as the invariant arguments.

    Raises:
      ConversionError: If the iterable is not of the form `range(0, CONST)`.
    """
    visit(node.init, self)

    def query_const_range_call() -> int:
      """Returns the trip count of a `for ... in range(0, CONST)` construct."""
      range_callee = (
          isinstance(node.iterable, ast.Invocation) and
          isinstance(node.iterable.callee, ast.NameRef) and
          node.iterable.callee.identifier == 'range')
      if not range_callee:
        raise ConversionError(
            'For-loop is of an unsupported form for IR conversion; only a '
            "'range(0, const)' call is supported, found non-range callee.",
            node.span)
      if len(node.iterable.args) != 2:
        raise ConversionError(
            'For-loop is of an unsupported form for IR conversion; only a '
            "'range(0, const)' call is supported, found inappropriate number "
            'of arguments.', node.span)
      # The start bound must be the constant zero; report that specifically
      # rather than reusing the argument-count message.
      if not self._is_constant_zero(node.iterable.args[0]):
        raise ConversionError(
            'For-loop is of an unsupported form for IR conversion; only a '
            "'range(0, const)' call is supported, found a start argument that "
            'is not the constant zero.', node.span)
      arg = node.iterable.args[1]
      visit(arg, self)
      if not self._is_const(arg):
        raise ConversionError(
            'For-loop is of an unsupported form for IR conversion; only a '
            "'range(0, const)' call is supported, did not find a const value "
            f'for {arg} ({arg!r}).', node.span)
      return self._get_const(arg)

    # TODO(leary): We currently only support counted loops of the form:
    #
    #   for (i, ...): (u32, ...) in range(N) {
    #      ...
    #   }
    trip_count = query_const_range_call()

    logging.vlog(3, 'Converting for-loop @ %s', node.span)
    body_converter = _IrConverterFb(
        self.package,
        self.module,
        self.type_info,
        emit_positions=self.emit_positions)
    body_converter.symbolic_bindings = dict(self.symbolic_bindings)
    body_fn_name = ('__' + self.fb.name + '_counted_for_{}_body').format(
        self._next_counted_for_ordinal()).replace('.', '_')
    body_converter.fb = function_builder.FunctionBuilder(
        body_fn_name, self.package)
    flat = node.names.flatten1()
    assert len(flat) == 2, (
        'Expect an induction binding and loop carry binding; got {!r}'.format(
            flat))

    # Add the induction value.
    assert isinstance(flat[0], ast.NameDef), (
        'Induction variable was not a NameDef: {0} ({0!r})'.format(flat[0]))
    body_converter.node_to_ir[flat[0]] = body_converter.fb.add_param(
        flat[0].identifier.encode('utf-8'), self._resolve_type_to_ir(flat[0]))

    # Add the loop carry value.
    if isinstance(flat[1], ast.NameDef):
      body_converter.node_to_ir[flat[1]] = body_converter.fb.add_param(
          flat[1].identifier.encode('utf-8'), self._resolve_type_to_ir(flat[1]))
    else:
      # For tuple loop carries we have to destructure names on entry.
      carry_type = self._resolve_type_to_ir(flat[1])
      carry = body_converter.node_to_ir[flat[1]] = body_converter.fb.add_param(
          '__loop_carry', carry_type)
      body_converter._visit_matcher(  # pylint: disable=protected-access
          flat[1], (), carry, self._resolve_type(flat[1]))

    # Free variables are suffixes on the function parameters. Compute the
    # non-function free variables once so the body's parameter order and the
    # invariant argument order below are guaranteed to agree.
    freevars = node.body.get_free_variables(node.span.start)
    freevars = freevars.drop_builtin_defs()
    invariant_name_defs = [
        name_def for name_def in freevars.get_name_defs(self.module)
        if not isinstance(self.type_info[name_def], FunctionType)
    ]
    for name_def in invariant_name_defs:
      logging.vlog(3, 'Converting freevar name: %s', name_def)
      body_converter.node_to_ir[name_def] = body_converter.fb.add_param(
          name_def.identifier.encode('utf-8'),
          self._resolve_type_to_ir(name_def))

    visit(node.body, body_converter)
    body_function = body_converter.fb.build()
    logging.vlog(3, 'Converted body function: %s', body_function.name)

    stride = 1
    invariant_args = tuple(
        self._use(name_def) for name_def in invariant_name_defs)
    self._def(node, self.fb.add_counted_for, self._use(node.init), trip_count,
              stride, body_function, invariant_args)
Пример #22
0
 def _visit_width_slice(self, node: ast.Index, width_slice: ast.WidthSlice,
                        lhs_type: ConcreteType) -> None:
   """Lowers a width-slice index expression to a dynamic bit slice.

   The start offset is the converted `width_slice.start`. The slice width is
   the total bit count of the index expression's resolved type. `lhs_type` is
   accepted for interface compatibility but is not read here.
   """
   visit(width_slice.start, self)
   subject = self._use(node.lhs)
   start = self._use(width_slice.start)
   width = self._resolve_type(node).get_total_bit_count()
   self._def(node, self.fb.add_dynamic_bit_slice, subject, start, width)
Пример #23
0
 def visit_Constant(self, node: ast.Constant) -> None:
   """Converts a constant's value expression and aliases its name to it."""
   value_expr = node.value
   visit(value_expr, self)
   logging.vlog(5, 'Aliasing NameDef for constant: %r', node.name)
   self._def_alias(value_expr, to=node.name)
Пример #24
0
 def accept_args() -> Tuple[BValue, ...]:
   """Converts every argument of `node`; returns their IR values in order."""
   args = node.args
   for expr in args:
     visit(expr, self)
   return tuple(map(self._use, args))