Beispiel #1
0
 def assert_same_ast(self, expected_node, node, msg=None):
     """Asserts that two ASTs unparse to the same source text.

     Both trees are rendered with `compiler.ast_to_source` using two-space
     indentation, then dedented and stripped, so common leading indentation
     and surrounding whitespace do not affect the comparison.

     Args:
         expected_node: the AST the test expects.
         node: the AST actually produced.
         msg: optional failure message forwarded to `assertEqual`.
     """
     def normalized(tree):
         rendered = compiler.ast_to_source(tree, indentation='  ')
         return textwrap.dedent(rendered).strip()

     self.assertEqual(normalized(expected_node), normalized(node), msg=msg)
Beispiel #2
0
 def __repr__(self):
   """Returns a short, human-readable rendering of the wrapped AST node.

   Function definitions are shown as `def <name>`; `withitem` nodes are shown
   by their context expression; any other node is unparsed in full.
   """
   node = self.ast_node
   if isinstance(node, gast.FunctionDef):
     return 'def %s' % node.name
   shown = node.context_expr if isinstance(node, gast.withitem) else node
   source, _ = compiler.ast_to_source(shown)
   return source.strip()
Beispiel #3
0
 def __repr__(self):
     """Renders the wrapped AST node compactly for debugging output.

     A `withitem` is rendered through its context expression, a function
     definition collapses to `def <name>`, and anything else is unparsed and
     stripped of surrounding whitespace.
     """
     target = self.ast_node
     if isinstance(target, gast.withitem):
         target = target.context_expr
     elif isinstance(target, gast.FunctionDef):
         return 'def %s' % target.name
     rendered, _ = compiler.ast_to_source(target)
     return rendered.strip()
Beispiel #4
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
  """Return the equivalent of an entity in TensorFlow code.

  See `to_graph` for more details.

  Args:
    e: A Python entity.
    recursive: See to_graph.
    arg_values: See to_graph.
    arg_types: See to_graph.
    partial_types: See to_graph.
    indentation: String, what to use for each level of indentation.

  Returns:
    String: the program context's required imports, a blank line, then the
    source of every converted dependency, in reverse dependency-cache order.
  """
  decorators = (convert, do_not_convert, converted_call)
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=decorators,
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)

  # Emit the converted dependencies in the reverse of their cache order.
  deps = list(six.itervalues(program_ctx.dependency_cache))
  deps.reverse()
  rendered = [compiler.ast_to_source(dep, indentation) for dep in deps]

  header = program_ctx.required_imports
  return header + '\n\n' + '\n'.join(rendered)
Beispiel #5
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
  """Return the equivalent of an entity in TensorFlow code.

  See `to_graph` for more details.

  Args:
    e: A Python entity.
    recursive: See to_graph.
    arg_values: See to_graph.
    arg_types: See to_graph.
    partial_types: See to_graph.
    indentation: String, what to use for each level of indentation.

  Returns:
    String: the compiled import statements, a blank line, then the source of
    every converted dependency, in reverse dependency-cache order.
  """
  conversion_map = conversion.ConversionMap(
      recursive=recursive,
      nocompile_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      api_module=tf_inspect.getmodule(to_graph))
  conversion.entity_to_graph(e, conversion_map, arg_values, arg_types)

  header = '\n'.join(config.COMPILED_IMPORT_STATEMENTS)
  cached = tuple(six.itervalues(conversion_map.dependency_cache))
  # Walk the cache back to front.
  body = '\n'.join(
      [compiler.ast_to_source(dep, indentation) for dep in cached[::-1]])

  return '%s\n\n%s' % (header, body)
Beispiel #6
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
    """Return the equivalent of an entity in TensorFlow code.

    See `to_graph` for more details.

    Args:
        e: A Python entity.
        recursive: See to_graph.
        arg_values: See to_graph.
        arg_types: See to_graph.
        partial_types: See to_graph.
        indentation: String, what to use for each level of indentation.

    Returns:
        String: the compiled import statements, a blank line, then the source
        of every converted dependency, in reverse dependency-cache order.
    """
    conversion_map = conversion.ConversionMap(
        recursive=recursive,
        nocompile_decorators=(convert, do_not_convert, converted_call),
        partial_types=partial_types,
        api_module=tf_inspect.getmodule(to_graph))
    conversion.entity_to_graph(e, conversion_map, arg_values, arg_types)

    header = '\n'.join(config.COMPILED_IMPORT_STATEMENTS)

    chunks = []
    cached = list(six.itervalues(conversion_map.dependency_cache))
    for dep in reversed(cached):
        chunks.append(compiler.ast_to_source(dep, indentation))

    return '%s\n\n%s' % (header, '\n'.join(chunks))
Beispiel #7
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
  """Return the equivalent of an entity in TensorFlow code.

  See `to_graph` for more details.

  Args:
    e: A Python entity.
    recursive: See to_graph.
    arg_values: See to_graph.
    arg_types: See to_graph.
    partial_types: See to_graph.
    indentation: String, what to use for each level of indentation.

  Returns:
    String: the required imports, a blank line, then each converted
    dependency's source, last-cached first.
  """
  ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  conversion.entity_to_graph(e, ctx, arg_values, arg_types)

  cached = tuple(six.itervalues(ctx.dependency_cache))
  sources = [compiler.ast_to_source(dep, indentation) for dep in cached[::-1]]
  code = '\n'.join(sources)

  return ctx.required_imports + '\n\n' + code
    def test_rename_symbols_attributes(self):
        """Renaming the qualified name `b.c` rewrites both of its uses."""
        tree = qual_names.resolve(parser.parse_str('b.c = b.c.d'))
        renames = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}
        tree = ast_util.rename_symbols(tree, renames)

        rendered, _ = compiler.ast_to_source(tree)
        self.assertEqual(rendered.strip(), 'renamed_b_c = renamed_b_c.d')
  def test_rename_symbols_attributes(self):
    """Renaming the qualified name `b.c` rewrites both of its uses."""
    parsed = parser.parse_str('b.c = b.c.d')
    resolved = qual_names.resolve(parsed)
    mapping = {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')}
    renamed = ast_util.rename_symbols(resolved, mapping)

    self.assertEqual(
        compiler.ast_to_source(renamed).strip(),
        'renamed_b_c = renamed_b_c.d')
    def test_rename_symbols_basic(self):
        """Renaming `a` touches only `a` and leaves the new id a plain str."""
        tree = qual_names.resolve(parser.parse_str('a + b'))
        tree = ast_util.rename_symbols(
            tree, {qual_names.QN('a'): qual_names.QN('renamed_a')})

        renamed_id = tree.body[0].value.left.id
        self.assertIsInstance(renamed_id, str)
        rendered = compiler.ast_to_source(tree)
        self.assertEqual(rendered.strip(), 'renamed_a + b')
Beispiel #11
0
 def _get_source(self, node):
     """Returns the source for `node`, or a placeholder if conversion fails.

     This function is used for error reporting. If an exception occurs here,
     it is suppressed in favor of emitting as informative a message about the
     original error as possible.
     """
     try:
         text, _ = compiler.ast_to_source(node)
     except Exception:  # pylint: disable=broad-except
         return '<could not convert AST to source>'
     return text
Beispiel #12
0
  def test_rename_symbols_basic(self):
    """Renaming `a` touches only `a` and leaves the new id a plain str."""
    resolved = qual_names.resolve(parser.parse_str('a + b'))
    mapping = {qual_names.QN('a'): qual_names.QN('renamed_a')}
    renamed = ast_util.rename_symbols(resolved, mapping)

    left_operand = renamed.body[0].value.left
    self.assertIsInstance(left_operand.id, str)
    rendered = compiler.ast_to_source(renamed)
    self.assertEqual(rendered.strip(), 'renamed_a + b')
Beispiel #13
0
 def _get_source(self, node):
   """Best-effort rendering of `node` as source, used for error reporting.

   Any exception raised while converting is suppressed, so that the message
   about the original error stays as informative as possible.
   """
   try:
     text, _ = compiler.ast_to_source(node)
   except Exception:  # pylint: disable=broad-except
     return '<could not convert AST to source>'
   return text
    def get_definition_directive(self, node, directive, arg, default):
        """Returns the unique directive for a symbol, or a default if none exist.

        See lang/directives.py for details on directives.

        The candidate definitions are read from the
        `anno.Static.ORIG_DEFINITIONS` annotation of `node`. When several of
        them specify `arg` for `directive`, all of the values must match
        (per `ast_util.matches`); any one of them is then returned.

        Args:
          node: ast.AST
          directive: Callable[..., Any], used as the key into each
            definition's `directives` mapping.
          arg: str, the directive argument to look up.
          default: Any, returned when no definition specifies `arg`.

        Returns:
          The value recorded for `arg`, or `default`.

        Raises:
          ValueError: if conflicting annotations have been found
        """
        defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
        if not defs:
            return default

        # TODO(mdan): Simplify this.
        arg_values = []
        for def_ in defs:
            if (directive not in def_.directives
                    or arg not in def_.directives[directive]):
                continue
            arg_value = def_.directives[directive][arg]
            for prev_value in arg_values:
                if not ast_util.matches(arg_value, prev_value):
                    qn = anno.getanno(node, anno.Basic.QN)
                    raise ValueError(
                        '%s has ambiguous annotations for %s(%s): %s, %s' %
                        (qn, directive.__name__, arg,
                         compiler.ast_to_source(arg_value).strip(),
                         compiler.ast_to_source(prev_value).strip()))
            arg_values.append(arg_value)

        if not arg_values:
            return default

        # Every collected value was checked against all earlier ones above, so
        # they all agree and any of them is the answer. Single-element
        # unpacking here would raise "too many values to unpack" whenever two
        # definitions carry the same, matching directive.
        return arg_values[0]
Beispiel #15
0
  def get_definition_directive(self, node, directive, arg, default):
    """Returns the unique directive for a symbol, or a default if none exist.

    See lang/directives.py for details on directives.

    The candidate definitions are read from the `anno.Static.ORIG_DEFINITIONS`
    annotation of `node`. When several of them specify `arg` for `directive`,
    all of the values must match (per `ast_util.matches`); any one of them is
    then returned.

    Args:
      node: ast.AST
      directive: Callable[..., Any], used as the key into each definition's
        `directives` mapping.
      arg: str, the directive argument to look up.
      default: Any, returned when no definition specifies `arg`.

    Returns:
      The value recorded for `arg`, or `default`.

    Raises:
      ValueError: if conflicting annotations have been found
    """
    defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
    if not defs:
      return default

    # TODO(mdan): Simplify this.
    arg_values = []
    for def_ in defs:
      if (directive not in def_.directives or
          arg not in def_.directives[directive]):
        continue
      arg_value = def_.directives[directive][arg]
      for prev_value in arg_values:
        if not ast_util.matches(arg_value, prev_value):
          qn = anno.getanno(node, anno.Basic.QN)
          raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' %
                           (qn, directive.__name__, arg,
                            compiler.ast_to_source(arg_value).strip(),
                            compiler.ast_to_source(prev_value).strip()))
      arg_values.append(arg_value)

    if not arg_values:
      return default

    # Every collected value was checked against all earlier ones above, so
    # they all agree and any of them is the answer. Single-element unpacking
    # here would raise "too many values to unpack" whenever two definitions
    # carry the same, matching directive.
    return arg_values[0]
  def test_source_map(self):
    """Checks source-map entries for traced and untraced inserted lines."""

    def test_fn(x):
      if x > 0:
        x += 1
      return x

    node, source = parser.parse_entity(test_fn)
    fn_node = node.body[0]
    origin_info.resolve(fn_node, source)

    # Prepend a statement that copies the origin of the `if`, so it is traced.
    traced_stmt = parser.parse_str('x = abs(x)').body[0]
    anno.copyanno(fn_node.body[0], traced_stmt, anno.Basic.ORIGIN)
    fn_node.body.insert(0, traced_stmt)

    # Prepend a statement with no origin annotation, so it is untraced.
    untraced_stmt = parser.parse_str('x = 0').body[0]
    fn_node.body.insert(0, untraced_stmt)

    modified_source = compiler.ast_to_source(fn_node)
    source_map = origin_info.source_map(fn_node, modified_source,
                                        'test_filename', [0])

    def origin_at(lineno):
      return source_map[origin_info.LineLocation('test_filename', lineno)]

    # Line 1 is the function header.
    origin = origin_at(1)
    self.assertEqual(origin.source_code_line, 'def test_fn(x):')
    self.assertEqual(origin.loc.lineno, 1)

    # Line 2 holds the untraced statement (inserted second): no map entry.
    self.assertNotIn(origin_info.LineLocation('test_filename', 2), source_map)

    # Line 3 (the traced statement, inserted first) and line 4 (the original
    # `if`) both map back to line 2 of the original function.
    for lineno in (3, 4):
      origin = origin_at(lineno)
      self.assertEqual(origin.source_code_line, '  if x > 0:')
      self.assertEqual(origin.loc.lineno, 2)
Beispiel #17
0
    def test_ast_to_source(self):
        """An If/else gast tree unparses to the expected two-space source."""
        def assign_a(value):
            return gast.Assign(
                targets=[gast.Name('a', gast.Store(), None)], value=value)

        node = gast.If(
            test=gast.Num(1),
            body=[assign_a(gast.Name('b', gast.Load(), None))],
            orelse=[assign_a(gast.Str('c'))])

        expected = textwrap.dedent("""
            if 1:
              a = b
            else:
              a = 'c'
        """).strip()
        rendered = compiler.ast_to_source(node, indentation='  ')
        self.assertEqual(expected, rendered.strip())
Beispiel #18
0
  def test_ast_to_source(self):
    """An If/else gast tree unparses to the expected two-space source."""
    store_a = lambda: gast.Name('a', gast.Store(), None)
    then_branch = gast.Assign(
        targets=[store_a()], value=gast.Name('b', gast.Load(), None))
    else_branch = gast.Assign(targets=[store_a()], value=gast.Str('c'))
    node = gast.If(
        test=gast.Num(1), body=[then_branch], orelse=[else_branch])

    rendered = compiler.ast_to_source(node, indentation='  ')
    expected = textwrap.dedent("""
        if 1:
          a = b
        else:
          a = 'c'
    """).strip()
    self.assertEqual(expected, rendered.strip())
Beispiel #19
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
  """Returns the equivalent code that uses TensorFlow ops.

  Also see: `to_graph`, `convert`

  Args:
    e: Union[Callable, Type], the Python entity to convert.
    recursive: bool, whether to recursively convert any functions that the
        converted function may call.
    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
        function arguments.
    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
        function arguments.
    partial_types: Set[Type], reserved for internal use.
    indentation: Text, what to use for each level of indentation.

  Returns:
    Text, the converted code: the required imports, a blank line, then each
    converted dependency in reverse dependency-cache order.
  """
  decorators = (convert, do_not_convert, converted_call)
  program_ctx = converter.ProgramContext(
      recursive=recursive,
      autograph_decorators=decorators,
      partial_types=partial_types,
      autograph_module=tf_inspect.getmodule(to_graph),
      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
  conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)

  deps = list(six.itervalues(program_ctx.dependency_cache))
  deps.reverse()
  code = '\n'.join([compiler.ast_to_source(dep, indentation) for dep in deps])

  imports = program_ctx.required_imports
  return imports + '\n\n' + code
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
    """Returns the equivalent code that uses TensorFlow ops.

    Also see: `to_graph`, `convert`

    Args:
        e: Union[Callable, Type], the Python entity to convert.
        recursive: bool, whether to recursively convert any functions that
            the converted function may call.
        arg_values: Optional[Dict[Text, Any]], value hints for symbols
            including function arguments.
        arg_types: Optional[Dict[Text, Type]], type hints for symbols
            including function arguments.
        partial_types: Set[Type], reserved for internal use.
        indentation: Text, what to use for each level of indentation.

    Returns:
        Text, the converted code: the required imports, a blank line, then
        each converted dependency in reverse dependency-cache order.
    """
    ctx = converter.ProgramContext(
        recursive=recursive,
        autograph_decorators=(convert, do_not_convert, converted_call),
        partial_types=partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    conversion.entity_to_graph(e, ctx, arg_values, arg_types)

    pieces = []
    for dep in reversed(tuple(six.itervalues(ctx.dependency_cache))):
        pieces.append(compiler.ast_to_source(dep, indentation))

    return ctx.required_imports + '\n\n' + '\n'.join(pieces)
 def _mock_apply_fn(self, target, source):
     """Counts one call, keyed by the stripped source of both AST nodes."""
     key = tuple(
         compiler.ast_to_source(tree).strip() for tree in (target, source))
     self._invocation_counts[key] += 1
Beispiel #22
0
def try_ast_to_source(node):
  """Converts `node` to source, returning a placeholder if conversion fails.

  Any exception raised by `compiler.ast_to_source` yields the placeholder,
  not only AssertionError. The placeholder text suggests this is a
  best-effort helper for diagnostics (TODO confirm with callers). In that
  role, an unexpected conversion error should not replace or mask whatever
  error is being reported.

  Args:
    node: the AST node to render.

  Returns:
    The result of `compiler.ast_to_source(node)`, or the string
    '<could not convert AST to source>' if it raised.
  """
  try:
    return compiler.ast_to_source(node)
  except Exception:  # pylint: disable=broad-except
    return '<could not convert AST to source>'
 def _mock_apply_fn(self, target, source):
     """Records a call by bumping the count for its (target, source) key.

     Both AST arguments are unparsed and stripped before forming the key.
     """
     parts = []
     for tree in (target, source):
         text, _ = compiler.ast_to_source(tree)
         parts.append(text.strip())
     self._invocation_counts[tuple(parts)] += 1
Beispiel #24
0
 def _mock_apply_fn(self, target, source):
   """Increments the invocation count for this (target, source) pair."""
   target_text = compiler.ast_to_source(target)
   source_text = compiler.ast_to_source(source)
   pair = (target_text.strip(), source_text.strip())
   self._invocation_counts[pair] += 1
Beispiel #25
0
 def assert_same_ast(self, expected_node, node, msg=None):
   """Fails unless both ASTs unparse to the same dedented, stripped source.

   Args:
     expected_node: the AST the test expects.
     node: the AST actually produced.
     msg: optional failure message forwarded to `assertEqual`.
   """
   expected_str, got_str = [
       textwrap.dedent(compiler.ast_to_source(tree, indentation='  ')).strip()
       for tree in (expected_node, node)
   ]
   self.assertEqual(expected_str, got_str, msg=msg)
Beispiel #26
0
 def __repr__(self):
     """Renders `self.ast_node` with `compiler.ast_to_source`, stripped."""
     rendered = compiler.ast_to_source(self.ast_node)
     return rendered.strip()
Beispiel #27
0
 def _get_source(self, node):
   """Returns the source for `node`, or a placeholder if conversion fails.

   The placeholder text suggests this is a best-effort helper for
   diagnostics (TODO confirm with callers). Any exception raised while
   converting is therefore suppressed, not only AssertionError, so that a
   secondary failure here cannot mask the original error being reported.

   Args:
     node: the AST node to render.

   Returns:
     The unparsed source, or '<could not convert AST to source>'.
   """
   try:
     source, _ = compiler.ast_to_source(node)
   except Exception:  # pylint: disable=broad-except
     return '<could not convert AST to source>'
   return source
Beispiel #28
0
 def __repr__(self):
   """Renders `self.ast_node` with `compiler.ast_to_source`, stripped."""
   text = compiler.ast_to_source(self.ast_node)
   return text.strip()