예제 #1
0
    def visit_Call(self, node):
        """Rewrites a call node based on what is known about its target.

        When the called function carries a 'live_val' annotation, the call is
        dispatched on that value: compilable functions are renamed (or, when
        `_should_compile` declines, only have their children visited), known
        NumPy functions are wrapped via `_wrap_to_py_func_single_return`,
        builtins are returned untouched and namedtuples only have their
        children visited. Calls without a 'live_val' annotation receive a
        dynamic conversion, except for the preserved `super` forms below.

        Args:
          node: the call node being visited.

        Returns:
          The node that should take the place of `node`.

        Raises:
          NotImplementedError: if the target is annotated with a live value
            that matches none of the supported categories.
        """
        if not anno.hasanno(node.func, 'live_val'):
            # Special cases
            # TODO(mdan): These need a systematic review - there may be more.

            # 1. super() calls - these are preserved. The class conversion
            # mechanism will ensure that they return the correct value.
            super_call = parser.parse_expression('super(_)')
            if ast_util.matches(node, super_call):
                return node

            # 2. super().method calls - these are preserved as well, when the
            # conversion processes the entire class.
            super_method_call = parser.parse_expression('super(_)._(_)')
            if (ast_util.matches(node, super_method_call)
                    and self.ctx.info.owner_type is not None):
                return node

            return self._insert_dynamic_conversion(node)

        target_entity = anno.getanno(node.func, 'live_val')
        target_fqn = (
            anno.getanno(node.func, 'fqn')
            if anno.hasanno(node.func, 'fqn') else None)

        if self._function_is_compilable(target_entity):
            if self._should_compile(node, target_fqn):
                return self._rename_compilable_function(node)
            return self.generic_visit(node)

        if target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
            # TODO(mdan): Should we replace these with equivalent TF ops
            # instead?
            return self._wrap_to_py_func_single_return(
                node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)

        if inspect_utils.isbuiltin(target_entity):
            # Note: Any builtin that passed the builtins converter is assumed
            # to be safe for graph mode.
            return node

        if inspect_utils.isnamedtuple(target_entity):
            # Although not compilable, we assume they are safe for graph mode.
            return self.generic_visit(node)

        # TODO(mdan): Insert dynamic conversion here instead.
        raise NotImplementedError(
            'py_func with return values (unknown function)')
예제 #2
0
  def visit_Call(self, node):
    """Rewrites a call node based on what is known about its target.

    When the called function carries a 'live_val' annotation, the call is
    dispatched on that value: compilable functions are renamed (or, when
    `_should_compile` declines, only have their children visited), known
    NumPy functions are wrapped via `_wrap_to_py_func_single_return`,
    builtins are returned untouched and namedtuples only have their children
    visited. Calls without a 'live_val' annotation receive a dynamic
    conversion, except for the preserved `super` forms below.

    Args:
      node: the call node being visited.

    Returns:
      The node that should take the place of `node`.

    Raises:
      NotImplementedError: if the target is annotated with a live value that
        matches none of the supported categories.
    """
    if not anno.hasanno(node.func, 'live_val'):
      # Special cases
      # TODO(mdan): These need a systematic review - there may be more.

      # 1. super() calls - these are preserved. The class conversion mechanism
      # will ensure that they return the correct value.
      if ast_util.matches(node, 'super(_)'):
        return node

      # 2. super().method calls - these are preserved as well, when the
      # conversion processes the entire class.
      is_super_method_call = ast_util.matches(node, 'super(_)._(_)')
      if is_super_method_call and self.ctx.info.owner_type is not None:
        return node

      return self._insert_dynamic_conversion(node)

    target_entity = anno.getanno(node.func, 'live_val')
    target_fqn = (
        anno.getanno(node.func, 'fqn')
        if anno.hasanno(node.func, 'fqn') else None)

    if self._function_is_compilable(target_entity):
      if self._should_compile(node, target_fqn):
        return self._rename_compilable_function(node)
      return self.generic_visit(node)

    if target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      return self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)

    if inspect_utils.isbuiltin(target_entity):
      # Note: Any builtin that passed the builtins converter is assumed to be
      # safe for graph mode.
      return node

    if inspect_utils.isnamedtuple(target_entity):
      # Although not compilable, we assume they are safe for graph mode.
      return self.generic_visit(node)

    # TODO(mdan): Insert dynamic conversion here instead.
    raise NotImplementedError(
        'py_func with return values (unknown function)')
예제 #3
0
  def assertAstMatches(self, actual_node, expected_node_src):
    """Asserts that `actual_node` matches the AST parsed from a source string.

    Args:
      actual_node: the AST node under test.
      expected_node_src: str, source code; its first top-level statement is
        the node `actual_node` is compared against.
    """
    parsed_module = gast.parse(expected_node_src)
    expected_node = parsed_module.body[0]

    # The failure message is built up front so that both trees are shown.
    expected_repr = pretty_printer.fmt(expected_node)
    actual_repr = pretty_printer.fmt(actual_node)
    failure_msg = 'AST did not match expected:\n{}\nActual:\n{}'.format(
        expected_repr, actual_repr)

    self.assertTrue(ast_util.matches(actual_node, expected_node), failure_msg)
예제 #4
0
    def _visit_and_reindent(self, nodes):
        """Visits a sequence of nodes, redirecting the remainder when asked.

        Each node is visited in order and appended to the current destination
        list. When a visited node carries the
        `anno.Basic.INDENT_BLOCK_REMAINDER` annotation, the annotation supplies
        a new destination list and an alias map; all nodes that follow are
        renamed with the accumulated alias map and placed into that new
        destination instead.

        Args:
          nodes: iterable of AST nodes to visit.

        Returns:
          The list that received the nodes visited before any redirection was
          requested (redirected nodes land in the annotation-provided lists).

        Raises:
          ValueError: if a redirection was requested but the final destination
            is empty or holds only a single `return`, `return ()`,
            `return []` or `return {}` statement.
        """
        new_nodes = []
        current_dest = new_nodes
        alias_map = {}
        reindent_requested = False
        for n in nodes:
            n = self.visit(n)
            # NOTE: the order in which these statements execute is important; in
            # particular, watch out for ending up with cycles in the AST.
            if alias_map:
                n = ast_util.rename_symbols(n, alias_map)
            # A visit may yield several nodes; splice them in rather than
            # nesting.
            if isinstance(n, (list, tuple)):
                current_dest.extend(n)
            else:
                current_dest.append(n)
            if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
                reindent_requested = True
                new_dest, new_alias_map = anno.getanno(
                    n, anno.Basic.INDENT_BLOCK_REMAINDER)
                # The annotation is consumed here so it is not processed twice.
                anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
                # Aliases accumulated so far override same-keyed new entries.
                new_alias_map.update(alias_map)
                alias_map = new_alias_map
                current_dest = new_dest

        if reindent_requested:
            # Check whether the last destination has anything beyond an empty
            # return that the inserted statement could gate.
            no_controls_to_gate = False
            if not current_dest:
                no_controls_to_gate = True
            if len(current_dest) == 1:
                if ast_util.matches(current_dest[0], 'return'):
                    no_controls_to_gate = True
                if ast_util.matches(current_dest[0], 'return ()'):
                    no_controls_to_gate = True
                if ast_util.matches(current_dest[0], 'return []'):
                    no_controls_to_gate = True
                if ast_util.matches(current_dest[0], 'return {}'):
                    no_controls_to_gate = True
            if no_controls_to_gate:
                # TODO(mdan): There may still be something that could be done.
                raise ValueError(
                    'Unable to insert statement into the computation flow: it is not'
                    ' followed by any computation which the statement could gate.'
                )

        return new_nodes
예제 #5
0
  def _visit_and_reindent(self, nodes):
    """Visits a sequence of nodes, redirecting the remainder when asked.

    Each node is visited in order and appended to the current destination
    list. When a visited node carries the `anno.Basic.INDENT_BLOCK_REMAINDER`
    annotation, the annotation supplies a new destination list and an alias
    map; all nodes that follow are renamed with the accumulated alias map and
    placed into that new destination instead.

    Args:
      nodes: iterable of AST nodes to visit.

    Returns:
      The list that received the nodes visited before any redirection was
      requested (redirected nodes land in the annotation-provided lists).

    Raises:
      ValueError: if a redirection was requested but the final destination is
        empty or holds only a single `return`, `return ()`, `return []` or
        `return {}` statement.
    """
    new_nodes = []
    current_dest = new_nodes
    alias_map = {}
    reindent_requested = False
    for n in nodes:
      n = self.visit(n)
      # NOTE: the order in which these statements execute is important; in
      # particular, watch out for ending up with cycles in the AST.
      if alias_map:
        n = ast_util.rename_symbols(n, alias_map)
      # A visit may yield several nodes; splice them in rather than nesting.
      if isinstance(n, (list, tuple)):
        current_dest.extend(n)
      else:
        current_dest.append(n)
      if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
        reindent_requested = True
        new_dest, new_alias_map = anno.getanno(
            n, anno.Basic.INDENT_BLOCK_REMAINDER)
        # The annotation is consumed here so it is not processed twice.
        anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
        # Aliases accumulated so far override same-keyed new entries.
        new_alias_map.update(alias_map)
        alias_map = new_alias_map
        current_dest = new_dest

    if reindent_requested:
      # Check whether the last destination has anything beyond an empty
      # return that the inserted statement could gate.
      no_controls_to_gate = False
      if not current_dest:
        no_controls_to_gate = True
      if len(current_dest) == 1:
        if ast_util.matches(current_dest[0], 'return'):
          no_controls_to_gate = True
        if ast_util.matches(current_dest[0], 'return ()'):
          no_controls_to_gate = True
        if ast_util.matches(current_dest[0], 'return []'):
          no_controls_to_gate = True
        if ast_util.matches(current_dest[0], 'return {}'):
          no_controls_to_gate = True
      if no_controls_to_gate:
        # TODO(mdan): There may still be something that could be done.
        raise ValueError(
            'Unable to insert statement into the computation flow: it is not'
            ' followed by any computation which the statement could gate.')

    return new_nodes
예제 #6
0
  def assertAstMatches(self, actual_node, expected_node_src, expr=True):
    """Asserts that `actual_node` matches the AST parsed from a source string.

    Args:
      actual_node: the AST node under test.
      expected_node_src: str, source code describing the expected node.
      expr: bool, whether `expected_node_src` is an expression. If so, it is
        wrapped in parentheses before parsing and the value of the resulting
        statement is used; otherwise the first parsed statement is used.
    """
    if expr:
      # Ensure multi-line expressions parse.
      source = '({})'.format(expected_node_src)
    else:
      source = expected_node_src

    expected_node = gast.parse(source).body[0]
    if expr:
      expected_node = expected_node.value

    # The failure message is built up front so that both trees are shown.
    expected_repr = pretty_printer.fmt(expected_node)
    actual_repr = pretty_printer.fmt(actual_node)
    failure_msg = 'AST did not match expected:\n{}\nActual:\n{}'.format(
        expected_repr, actual_repr)

    self.assertTrue(ast_util.matches(actual_node, expected_node), failure_msg)
예제 #7
0
    def get_definition_directive(self, node, directive, arg, default):
        """Looks up a directive argument attached to a symbol's definitions.

        See lang/directives.py for details on directives.

        Example:
          # Given a directive in the code:
          ag.foo_directive(bar, baz=1)

          # One can write for an AST node Name(id='bar'):
          get_definition_directive(node, ag.foo_directive, 'baz', None)

        Args:
          node: ast.AST, the node representing the symbol for which the
            directive argument is needed.
          directive: Callable[..., Any], the directive to search.
          arg: str, the directive argument to return.
          default: Any, returned when none of the definitions recorded under
            `anno.Static.ORIG_DEFINITIONS` carries the requested argument.

        Returns:
          The argument value found on the definitions, or `default`.

        Raises:
          ValueError: if conflicting annotations have been found
        """
        definitions = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
        if not definitions:
            return default

        candidates = [
            d.directives[directive][arg]
            for d in definitions
            if directive in d.directives and arg in d.directives[directive]
        ]
        if not candidates:
            return default

        # If multiple annotations reach the symbol, they must all match the
        # first one. If they do, the first one is returned.
        reference = candidates[0]
        for candidate in candidates[1:]:
            if ast_util.matches(reference, candidate):
                continue
            qn = anno.getanno(node, anno.Basic.QN)
            raise ValueError(
                '%s has ambiguous annotations for %s(%s): %s, %s' %
                (qn, directive.__name__, arg,
                 compiler.ast_to_source(candidate).strip(),
                 compiler.ast_to_source(reference).strip()))
        return reference
예제 #8
0
  def visit_Call(self, node):
    """Converts a call node according to what is known about its target.

    Before descending, the first argument of a call to one of the
    `strip_decorators` entities is annotated as 'graph_ready'. After the
    children are visited, calls with a 'live_val' annotation are either
    renamed (compilable functions) or wrapped (known NumPy functions); calls
    without one are preserved for `dict` and `super(...)`, and otherwise get a
    dynamic conversion when the `recursive` option is set.

    Args:
      node: the call node being visited.

    Returns:
      The node that should take the place of `node`.

    Raises:
      ValueError: if a marker decorator is called without positional
        arguments.
      NotImplementedError: if the target has a live value that is neither
        compilable nor a known NumPy function.
    """
    # If the function call is wrapped by one of the marker decorators,
    # consider it graph ready.
    if anno.hasanno(node.func, 'live_val'):
      marker = anno.getanno(node.func, 'live_val')
      if marker in self.ctx.program.options.strip_decorators:
        if not node.args:
          raise ValueError(
              'Found call to decorator function "%s", but it had no arguments. '
              'A decorator needs at least one positional argument.' %
              marker)
        anno.setanno(node.args[0], 'graph_ready', True)

    self.generic_visit(node)

    # The annotations are re-read here on purpose: this runs after the
    # children have been visited.
    if not anno.hasanno(node.func, 'live_val'):
      if anno.hasanno(node.func, anno.Basic.QN):
        # Special-case a few builtins that otherwise go undetected. This
        # normally doesn't pose a problem, but the dict built-in doesn't
        # work with inspect.getargspec which is required for dynamic functions.
        # Note: expecting this is resilient to aliasing (e.g.
        # dict = an_evil_dict), because in those cases the regular mechanisms
        # process a simple user function.
        qn = anno.getanno(node.func, anno.Basic.QN)
        # Add items to this list as needed.
        if str(qn) in ('dict',):
          return node

      # super() calls are preserved. The class conversion mechanism will
      # ensure that they return the correct value.
      if ast_util.matches(node, 'super(_)'):
        return node

      if self.ctx.program.options.recursive:
        return self._insert_dynamic_conversion(node)
      return node

    target_entity = anno.getanno(node.func, 'live_val')
    target_fqn = (
        anno.getanno(node.func, 'fqn')
        if anno.hasanno(node.func, 'fqn') else None)

    if self._function_is_compilable(target_entity):
      return self._rename_compilable_function(node)

    if target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      return self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)

    raise NotImplementedError(
        'py_func with return values (unknown function)')
예제 #9
0
    def visit_Call(self, node):
        """Converts a call node according to what is known about its target.

        Before descending, the first argument of a call to one of the
        `autograph_decorators` entities is annotated as 'graph_ready'. After
        the children are visited, calls with a 'live_val' annotation are
        either renamed (compilable functions) or wrapped (known NumPy
        functions); calls without one are preserved for `dict` and
        `super(...)`, and otherwise get a dynamic conversion when the program
        is marked `recursive`.

        Args:
          node: the call node being visited.

        Returns:
          The node that should take the place of `node`.

        Raises:
          ValueError: if a marker decorator is called without positional
            arguments.
          NotImplementedError: if the target has a live value that is neither
            compilable nor a known NumPy function.
        """
        # If the function call is wrapped by one of the marker decorators,
        # consider it graph ready.
        if anno.hasanno(node.func, 'live_val'):
            marker = anno.getanno(node.func, 'live_val')
            if marker in self.ctx.program.autograph_decorators:
                if not node.args:
                    raise ValueError(
                        'Found call to decorator function "%s", but it had no arguments. '
                        'A decorator needs at least one positional argument.' %
                        marker)
                anno.setanno(node.args[0], 'graph_ready', True)

        self.generic_visit(node)

        # The annotations are re-read here on purpose: this runs after the
        # children have been visited.
        if not anno.hasanno(node.func, 'live_val'):
            if anno.hasanno(node.func, anno.Basic.QN):
                # Special-case a few builtins that otherwise go undetected.
                # This normally doesn't pose a problem, but the dict built-in
                # doesn't work with inspect.getargspec which is required for
                # dynamic functions.
                # Note: expecting this is resilient to aliasing (e.g.
                # dict = an_evil_dict), because in those cases the regular
                # mechanisms process a simple user function.
                qn = anno.getanno(node.func, anno.Basic.QN)
                # Add items to this list as needed.
                if str(qn) in ('dict', ):
                    return node

            # super() calls are preserved. The class conversion mechanism
            # will ensure that they return the correct value.
            if ast_util.matches(node, 'super(_)'):
                return node

            if self.ctx.program.recursive:
                return self._insert_dynamic_conversion(node)
            return node

        target_entity = anno.getanno(node.func, 'live_val')
        target_fqn = (
            anno.getanno(node.func, 'fqn')
            if anno.hasanno(node.func, 'fqn') else None)

        if self._function_is_compilable(target_entity):
            return self._rename_compilable_function(node)

        if target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
            # TODO(mdan): Should we replace these with equivalent TF ops
            # instead?
            return self._wrap_to_py_func_single_return(
                node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)

        raise NotImplementedError(
            'py_func with return values (unknown function)')
예제 #10
0
  def get_definition_directive(self, node, directive, arg, default):
    """Looks up a directive argument attached to a symbol's definitions.

    See lang/directives.py for details on directives.

    Example:
       # Given a directive in the code:
       ag.foo_directive(bar, baz=1)

       # One can write for an AST node Name(id='bar'):
       get_definition_directive(node, ag.foo_directive, 'baz', None)

    Args:
      node: ast.AST, the node representing the symbol for which the directive
        argument is needed.
      directive: Callable[..., Any], the directive to search.
      arg: str, the directive argument to return.
      default: Any, returned when none of the definitions recorded under
        `anno.Static.ORIG_DEFINITIONS` carries the requested argument.

    Returns:
      The argument value found on the definitions, or `default`.

    Raises:
      ValueError: if conflicting annotations have been found
    """
    definitions = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
    if not definitions:
      return default

    candidates = [
        d.directives[directive][arg]
        for d in definitions
        if directive in d.directives and arg in d.directives[directive]
    ]
    if not candidates:
      return default

    # If multiple annotations reach the symbol, they must all match the first
    # one. If they do, the first one is returned.
    reference = candidates[0]
    for candidate in candidates[1:]:
      if ast_util.matches(reference, candidate):
        continue
      qn = anno.getanno(node, anno.Basic.QN)
      raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' %
                       (qn, directive.__name__, arg,
                        compiler.ast_to_source(candidate).strip(),
                        compiler.ast_to_source(reference).strip()))
    return reference
예제 #11
0
  def visit_Call(self, node):
    """Converts a call node according to what is known about its target.

    Before descending, the first argument of a call to one of the
    `strip_decorators` entities is annotated as 'graph_ready'. After the
    children are visited, calls with a 'live_val' annotation are returned
    untouched (builtins), renamed (compilable functions) or wrapped (known
    NumPy functions); calls without one are preserved for `super(...)`, and
    otherwise get a dynamic conversion when the `recursive` option is set.

    Args:
      node: the call node being visited.

    Returns:
      The node that should take the place of `node`.

    Raises:
      ValueError: if a marker decorator is called without positional
        arguments.
      NotImplementedError: if the target has a live value that is not a
        builtin, not compilable and not a known NumPy function.
    """
    # If the function call is wrapped by one of the marker decorators,
    # consider it graph ready.
    if anno.hasanno(node.func, 'live_val'):
      marker = anno.getanno(node.func, 'live_val')
      if marker in self.ctx.program.options.strip_decorators:
        if not node.args:
          raise ValueError(
              'Found call to decorator function "%s", but it had no arguments. '
              'A decorator needs at least one positional argument.' %
              marker)
        anno.setanno(node.args[0], 'graph_ready', True)

    self.generic_visit(node)

    # The annotations are re-read here on purpose: this runs after the
    # children have been visited.
    if not anno.hasanno(node.func, 'live_val'):
      # super() calls are preserved. The class conversion mechanism will
      # ensure that they return the correct value.
      if ast_util.matches(node, 'super(_)'):
        return node

      if self.ctx.program.options.recursive:
        return self._insert_dynamic_conversion(node)
      return node

    target_entity = anno.getanno(node.func, 'live_val')
    target_fqn = (
        anno.getanno(node.func, 'fqn')
        if anno.hasanno(node.func, 'fqn') else None)

    if inspect_utils.isbuiltin(target_entity):
      # Note: Any builtin that passed the builtins converter is assumed to be
      # safe for graph mode.
      return node

    if self._function_is_compilable(target_entity):
      return self._rename_compilable_function(node)

    if target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      return self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)

    raise NotImplementedError(
        'py_func with return values (unknown function)')
예제 #12
0
    def get_definition_directive(self, node, directive, arg, default):
        """Returns the unique directive for a symbol, or a default if none exist.

        See lang/directives.py for details on directives.

        Args:
          node: ast.AST, the node whose `anno.Static.ORIG_DEFINITIONS`
            annotation is searched.
          directive: Callable[..., Any], the directive to look for.
          arg: str, the directive argument to return.
          default: Any, returned when no definition carries the argument.

        Returns:
          The value of `arg` for `directive` found on the node's definitions,
          or `default` when there is none. When several definitions carry the
          argument and all values match, the first one is returned.

        Raises:
          ValueError: if conflicting annotations have been found
        """
        defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
        if not defs:
            return default

        arg_values = []
        for def_ in defs:
            if (directive not in def_.directives
                    or arg not in def_.directives[directive]):
                continue
            arg_value = def_.directives[directive][arg]
            # Each new value must agree with every value collected so far.
            for prev_value in arg_values:
                if not ast_util.matches(arg_value, prev_value):
                    qn = anno.getanno(node, anno.Basic.QN)
                    raise ValueError(
                        '%s has ambiguous annotations for %s(%s): %s, %s' %
                        (qn, directive.__name__, arg,
                         compiler.ast_to_source(arg_value).strip(),
                         compiler.ast_to_source(prev_value).strip()))
            arg_values.append(arg_value)

        if not arg_values:
            return default

        # Several definitions may carry the same (matching) value. Unpacking
        # into a single name would raise "too many values to unpack" in that
        # case, so return the first value instead.
        return arg_values[0]
예제 #13
0
  def get_definition_directive(self, node, directive, arg, default):
    """Returns the unique directive for a symbol, or a default if none exist.

    See lang/directives.py for details on directives.

    Args:
      node: ast.AST, the node whose `anno.Static.ORIG_DEFINITIONS` annotation
        is searched.
      directive: Callable[..., Any], the directive to look for.
      arg: str, the directive argument to return.
      default: Any, returned when no definition carries the argument.

    Returns:
      The value of `arg` for `directive` found on the node's definitions, or
      `default` when there is none. When several definitions carry the
      argument and all values match, the first one is returned.

    Raises:
      ValueError: if conflicting annotations have been found
    """
    defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
    if not defs:
      return default

    arg_values = []
    for def_ in defs:
      if (directive not in def_.directives or
          arg not in def_.directives[directive]):
        continue
      arg_value = def_.directives[directive][arg]
      # Each new value must agree with every value collected so far.
      for prev_value in arg_values:
        if not ast_util.matches(arg_value, prev_value):
          qn = anno.getanno(node, anno.Basic.QN)
          raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' %
                           (qn, directive.__name__, arg,
                            compiler.ast_to_source(arg_value).strip(),
                            compiler.ast_to_source(prev_value).strip()))
      arg_values.append(arg_value)

    if not arg_values:
      return default

    # Several definitions may carry the same (matching) value. Unpacking into
    # a single name would raise "too many values to unpack" in that case, so
    # return the first value instead.
    return arg_values[0]
예제 #14
0
 def assertNoMatch(self, target_str, pattern_str):
     """Asserts that the parsed target expression does not match the pattern.

     Args:
       target_str: str, source of the expression under test.
       pattern_str: str, source of the pattern expression.
     """
     self.assertFalse(
         ast_util.matches(
             parser.parse_expression(target_str),
             parser.parse_expression(pattern_str)))
예제 #15
0
 def assertNoMatch(self, target_str, pattern_str):
   """Asserts that the parsed target expression does not match the pattern.

   Args:
     target_str: str, source of the expression under test.
     pattern_str: str, source of the pattern expression.
   """
   self.assertFalse(
       ast_util.matches(
           parser.parse_expression(target_str),
           parser.parse_expression(pattern_str)))