Example #1
0
 def _get_callee_identifier(self, node: ast.Invocation) -> Text:
   """Resolves the name used to refer to the function invoked by `node`.

   Args:
     node: Invocation whose callee is either a NameRef (looked up in
       self.module) or a ModRef (looked up in the corresponding imported
       module).

   Returns:
     The plain callee name when the owning module has no function by that
     name (get_function raised KeyError); otherwise the result of
     mangle_dslx_name, given the invocation's symbolic bindings if the
     function is parametric and None if it is not.

   Raises:
     NotImplementedError: If the callee is neither a NameRef nor a ModRef.
   """
   logging.vlog(3, 'Getting callee identifier for invocation: %s', node)
   callee = node.callee
   if isinstance(callee, ast.NameRef):
     target_name, owner = callee.identifier, self.module
   elif isinstance(callee, ast.ModRef):
     # The imports table is indexed by the ModRef's `mod`; element 0 of the
     # entry is the module object the function is looked up in.
     owner = self.type_info.get_imports()[callee.mod][0]
     target_name = callee.value
   else:
     raise NotImplementedError('Callee not currently supported @ {}'.format(
         node.span))

   try:
     callee_fn = owner.get_function(target_name)
   except KeyError:
     # For e.g. builtins that are not in the module, the name is handed back
     # as-is.
     return target_name

   if callee_fn.is_parametric():
     bindings = self._get_invocation_bindings(node)
     logging.vlog(2, 'Node %s @ %s symbolic bindings %r', node, node.span,
                  bindings)
     # A parametric callee is expected to come with non-empty bindings.
     assert bindings, node
   else:
     bindings = None
   return mangle_dslx_name(callee_fn.name.identifier,
                           callee_fn.get_free_parametric_keys(), owner,
                           bindings)
Example #2
0
  def _visit_map(self, node: ast.Invocation) -> BValue:
    """Converts a map invocation into a map operation on the function builder.

    Args:
      node: Invocation whose args[0] is the value being mapped over and whose
        args[1] is a NameRef or ModRef naming the function to apply.

    Returns:
      The result of self._def for `node` using self.fb.add_map, or the result
      of self._def_map_with_builtin when the mapped function's name is in
      dslx_builtins.PARAMETRIC_BUILTIN_NAMES.

    Raises:
      NotImplementedError: If args[1] is neither a NameRef nor a ModRef.
    """
    # Every argument except the last one is visited.
    for operand in node.args[:-1]:
      operand.accept(self)
    mapped_input = self._use(node.args[0])
    mapper = node.args[1]

    if isinstance(mapper, ast.NameRef):
      mapper_name = mapper.name_def.identifier
      if mapper_name in dslx_builtins.PARAMETRIC_BUILTIN_NAMES:
        return self._def_map_with_builtin(node, mapper, node.args[0],
                                          self._get_invocation_bindings(node))
      owner = self.module
    elif isinstance(mapper, ast.ModRef):
      mapper_name = mapper.value
      owner, _ = self.type_info.get_imports()[mapper.mod]
    else:
      raise NotImplementedError(
          'Unhandled function mapping: {!r}'.format(mapper))

    target = owner.get_function(mapper_name)
    bindings = self._get_invocation_bindings(node)
    # NOTE(review): `target.name` is passed here rather than
    # `target.name.identifier` -- confirm mangle_dslx_name accepts that form.
    ir_name = mangle_dslx_name(target.name, target.get_free_parametric_keys(),
                               owner, bindings)

    return self._def(node, self.fb.add_map, mapped_input,
                     self.package.get_function(ir_name))
Example #3
0
  def _visit_Function(
      self, node: ast.Function,
      symbolic_bindings: Optional[SymbolicBindings]) -> ir_function.Function:
    """Builds, verifies and returns the IR function for `node`.

    Args:
      node: Function AST node to convert.
      symbolic_bindings: Bindings used for this conversion; None is treated as
        no bindings. A dict copy is stored on self.symbolic_bindings, while the
        value as given is what gets passed to mangle_dslx_name.

    Returns:
      The function produced by the function builder, after it has been passed
      to verifier_mod.verify_function.
    """
    if symbolic_bindings is None:
      self.symbolic_bindings = {}
    else:
      self.symbolic_bindings = dict(symbolic_bindings)
    self._extract_module_level_constants(self.module)

    # self.fb holds the builder only while this ast.Function is being
    # converted; the `finally` clause below drops the reference again.
    ir_name = mangle_dslx_name(node.name.identifier,
                               node.get_free_parametric_keys(), self.module,
                               symbolic_bindings)
    self.fb = function_builder.FunctionBuilder(ir_name, self.package)
    try:
      for formal in node.params:
        formal.accept(self)

      for binding in node.parametric_bindings:
        logging.vlog(4, 'Resolving parametric binding %s', binding)

        resolved = self._resolve_dim(
            self.symbolic_bindings[binding.name.identifier])
        assert isinstance(resolved, int), \
            'Expect integral parametric binding; got {!r}'.format(resolved)
        bit_count = self._resolve_type(binding.type_).get_total_bit_count()
        # Define the binding as a constant, then alias it to its name.
        self._def_const(binding, resolved, bit_count)
        self._def_alias(binding, to=binding.name)

      # Visit the pending constant dependencies, then empty the list in place.
      for constant_dep in self._constant_deps:
        constant_dep.accept(self)
      del self._constant_deps[:]

      node.body.accept(self)

      # Fall back to the body when self.last_expression is falsy.
      result_expr = self.last_expression or node.body
      if isinstance(result_expr, ast.NameRef):
        # A bare name reference as the result gets an explicit identity op.
        self._def(result_expr, self.fb.add_identity, self._use(result_expr))
      built = self.fb.build()
      logging.vlog(3, 'Built function: %s', built.name)
      verifier_mod.verify_function(built)
      return built
    finally:
      self.fb = None
Example #4
0
  def _def_map_with_builtin(self, parent_node: ast.Invocation,
                            node: ast.NameRef, arg: ast.AstNode,
                            symbolic_bindings: SymbolicBindings) -> BValue:
    """Maps a unary builtin over `arg`, defining the builtin if needed.

    If the package has no function under the builtin's mangled name yet, a
    one-parameter function wrapping the builtin is built into the package
    first.

    Args:
      parent_node: The map invocation the resulting value is defined for.
      node: Name reference to the builtin being mapped.
      arg: AST node of the value being mapped over.
      symbolic_bindings: Bindings passed to mangle_dslx_name for the builtin.

    Returns:
      The result of self._def for `parent_node` using self.fb.add_map.
    """
    builtin_name = node.name_def.identifier
    ir_name = mangle_dslx_name(builtin_name, set(), self.module,
                               symbolic_bindings)

    mapped_input = self._use(arg)
    already_defined = ir_name in self.package.get_function_names()
    if not already_defined:
      wrapper = function_builder.FunctionBuilder(ir_name, self.package)
      # The wrapper's single parameter has the element type of the input.
      element = wrapper.add_param(
          'arg', mapped_input.get_type().get_element_type())
      assert builtin_name in dslx_builtins.UNARY_BUILTIN_NAMES, (
          dslx_builtins.UNARY_BUILTIN_NAMES)
      emitters = {'clz': wrapper.add_clz, 'ctz': wrapper.add_ctz}
      # The dispatch table must match the declared unary builtins exactly.
      assert set(emitters) == dslx_builtins.UNARY_BUILTIN_NAMES, set(emitters)
      emit = emitters[builtin_name]
      emit(element)
      wrapper.build()
    return self._def(parent_node, self.fb.add_map, mapped_input,
                     self.package.get_function(ir_name))