Exemplo n.º 1
0
def replace(template, **replacements):
  """Replaces placeholders in a Python template.

  AST Name and Tuple nodes always receive the context inferred from the
  template. However, when replacing more complex nodes (that can potentially
  contain Name children), the caller is responsible for setting the
  appropriate context.

  Args:
    template: A string representing Python code. Any symbol name that appears
        in the template code can be used as a placeholder.
    **replacements: A mapping from placeholder names to (lists of) AST nodes
        that these placeholders will be replaced by. String values are also
        supported as a shorthand for AST Name nodes with the respective ID.

  Returns:
    The `body` of the transformed parsed template with qualified names
    resolved. This is normally the list of top-level statement nodes; if
    `body` is not a list, the single resolved node is returned instead.

  Raises:
    ValueError: if the arguments are incorrect, e.g. `template` is not a
        string.
  """
  if not isinstance(template, str):
    raise ValueError('Expected string template, got %s' % type(template))
  tree = parser.parse_str(textwrap.dedent(template))
  # Build a fresh mapping. String shorthands are converted to AST nodes here.
  converted = {name: _convert_to_ast(value)
               for name, value in replacements.items()}
  results = ReplaceTransformer(converted).visit(tree).body
  if isinstance(results, list):
    return [qual_names.resolve(r) for r in results]
  return qual_names.resolve(results)
Exemplo n.º 2
0
def replace(template, **replacements):
    """Replaces placeholders in a Python template.

    AST Name and Tuple nodes always receive the context inferred from the
    template. However, when replacing more complex nodes (that can potentially
    contain Name children), the caller is responsible for setting the
    appropriate context.

    Args:
      template: A string representing Python code. Any symbol name that
          appears in the template code can be used as a placeholder.
      **replacements: A mapping from placeholder names to (lists of) AST nodes
          that these placeholders will be replaced by. String values are also
          supported as a shorthand for AST Name nodes with the respective ID.

    Returns:
      The `body` of the transformed parsed template with qualified names
      resolved. This is normally the list of top-level statement nodes; if
      `body` is not a list, the single resolved node is returned instead.

    Raises:
      ValueError: if the arguments are incorrect, e.g. `template` is not a
          string.
    """
    if not isinstance(template, str):
        raise ValueError('Expected string template, got %s' % type(template))
    tree = parser.parse_str(textwrap.dedent(template))
    # Build a fresh mapping. String shorthands are converted to AST nodes here.
    converted = {name: _convert_to_ast(value)
                 for name, value in replacements.items()}
    results = ReplaceTransformer(converted).visit(tree).body
    if isinstance(results, list):
        return [qual_names.resolve(r) for r in results]
    return qual_names.resolve(results)
Exemplo n.º 3
0
def transform(node, ctx, default_to_null_return=True):
  """Ensure a function has only a single return, at the end.

  Two separate rewrites run: conditional returns are rewritten first, then
  return statements. Both are preceded by a fresh qualified-name and activity
  analysis.
  """
  def _reanalyze(tree):
    # Refresh the annotations the rewriters rely on.
    tree = qual_names.resolve(tree)
    return activity.resolve(tree, ctx, None)

  # Note: Technically, these two could be merged into a single walk, but
  # keeping them separate helps with readability.
  tree = ConditionalReturnRewriter(ctx).visit(_reanalyze(node))
  tree = _reanalyze(tree)
  rewriter = ReturnStatementsTransformer(
      ctx, allow_missing_return=default_to_null_return)
  return rewriter.visit(tree)
def transform(node, ctx):
    """Applies BreakTransformer after refreshing name and activity analysis."""
    annotated = activity.resolve(qual_names.resolve(node), ctx, None)
    return BreakTransformer(ctx).visit(annotated)
Exemplo n.º 5
0
  def prepare(self, test_fn, namespace, recursive=True):
    """Parses `test_fn` and runs the initial static analyses on it.

    Note that `namespace` is modified in place: a `ConversionOptions` entry is
    added to it.

    Returns:
      A (node, ctx) tuple with the analyzed AST and its transformer context.
    """
    namespace['ConversionOptions'] = converter.ConversionOptions

    features = ('print_function', 'division')
    tree, source = parser.parse_entity(test_fn, future_features=features)
    namer = naming.Namer(namespace)
    options = converter.ConversionOptions(recursive=recursive)
    program_ctx = converter.ProgramContext(options=options,
                                           autograph_module=None)
    info = transformer.EntityInfo(name=test_fn.__name__,
                                  source_code=source,
                                  source_file='<fragment>',
                                  future_features=features,
                                  namespace=namespace)
    ctx = transformer.Context(info, namer, program_ctx)
    origin_info.resolve_entity(tree, source, test_fn)

    cfgs = cfg.build(tree)
    tree = qual_names.resolve(tree)
    tree = activity.resolve(tree, ctx, None)
    tree = reaching_definitions.resolve(tree, ctx, cfgs)
    # Duplicate DEFINITIONS under ORIG_DEFINITIONS (presumably so later
    # passes can still see the initial definitions).
    anno.dup(tree, {anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS})
    return tree, ctx
Exemplo n.º 6
0
    def transform_ast(self, node, ctx):
        """Runs the initial static analysis, then every conversion pass.

        Assert statements and lists/slices are converted only when
        `ctx.user.options` enables the matching `converter.Feature`.

        Returns:
          The converted AST.
        """
        # TODO(mdan): Insert list_comprehensions somewhere.
        unsupported_features_checker.verify(node)

        # Initial analysis.
        cfgs = cfg.build(node)
        tree = qual_names.resolve(node)
        tree = activity.resolve(tree, ctx, None)
        tree = reaching_definitions.resolve(tree, ctx, cfgs)
        anno.dup(tree, {anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS})

        for pass_module in (functions, directives, break_statements):
            tree = pass_module.transform(tree, ctx)
        if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
            tree = asserts.transform(tree, ctx)
        # Note: sequencing continue canonicalization before for loop one avoids
        # dealing with the extra loop increment operation that the for
        # canonicalization creates.
        for pass_module in (continue_statements, return_statements):
            tree = pass_module.transform(tree, ctx)
        if ctx.user.options.uses(converter.Feature.LISTS):
            for pass_module in (lists, slices):
                tree = pass_module.transform(tree, ctx)
        for pass_module in (call_trees, control_flow, conditional_expressions,
                            logical_expressions):
            tree = pass_module.transform(tree, ctx)
        return tree
Exemplo n.º 7
0
  def test_subscript_resolve(self):
    """Each subscript sample resolves to a QN that prints as its source."""
    samples = """
      x[i]
      x[i.b]
      a.b[c]
      a.b[x.y]
      a[z[c]]
      a[b[c[d]]]
      a[b].c
      a.b.c[d].e.f
      a.b[c[d]].e.f
      a.b[c[d.e.f].g].h
    """
    module = resolve(parser.parse_str(textwrap.dedent(samples)))
    exprs = tuple(stmt.value for stmt in module.body)

    expected = ('x[i]', 'x[i.b]', 'a.b[c]', 'a.b[x.y]', 'a[z[c]]',
                'a[b[c[d]]]', 'a[b].c', 'a.b.c[d].e.f', 'a.b[c[d]].e.f',
                'a.b[c[d.e.f].g].h')
    for index, qn_string in enumerate(expected):
      self.assertQNStringIs(exprs[index], qn_string)
Exemplo n.º 8
0
def standard_analysis(node, context, is_initial=False):
    """Performs a complete static analysis of the given code.

    Args:
      node: ast.AST
      context: converter.EntityContext
      is_initial: bool, whether this is the initial analysis done on the input
        source code

    Returns:
      ast.AST, same as node, with the static analysis annotations added
    """
    # TODO(mdan): Clear static analysis here.
    # TODO(mdan): Consider not running all analyses every time.
    # TODO(mdan): Don't return a node because it's modified by reference.
    cfgs = cfg.build(node)
    tree = qual_names.resolve(node)
    tree = activity.resolve(tree, context, None)
    tree = reaching_definitions.resolve(tree, context, cfgs, AnnotatedDef)
    tree = liveness.resolve(tree, context, cfgs)
    tree = live_values.resolve(tree, context, config.PYTHON_LITERALS)
    tree = type_info.resolve(tree, context)
    # A second live_values pass, after type_info, allows resolving
    # first-order class attributes.
    tree = live_values.resolve(tree, context, config.PYTHON_LITERALS)
    if not is_initial:
        return tree
    anno.dup(tree, {anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS})
    return tree
Exemplo n.º 9
0
    def test_subscript_resolve(self):
        """Each subscript sample resolves to a QN that prints as its source."""
        samples = """
      x[i]
      x[i.b]
      a.b[c]
      a.b[x.y]
      a[z[c]]
      a[b[c[d]]]
      a[b].c
      a.b.c[d].e.f
      a.b[c[d]].e.f
      a.b[c[d.e.f].g].h
    """
        statements = parser.parse(textwrap.dedent(samples), single_node=False)
        exprs = tuple(resolve(stmt).value for stmt in statements)

        expected = ('x[i]', 'x[i.b]', 'a.b[c]', 'a.b[x.y]', 'a[z[c]]',
                    'a[b[c[d]]]', 'a[b].c', 'a.b.c[d].e.f', 'a.b[c[d]].e.f',
                    'a.b[c[d.e.f].g].h')
        for index, qn_string in enumerate(expected):
            self.assertQNStringIs(exprs[index], qn_string)
Exemplo n.º 10
0
def standard_analysis(node, context, is_initial=False):
  """Performs a complete static analysis of the given code.

  Every analysis pass is given `context.info` rather than `context` itself.

  Args:
    node: ast.AST
    context: converter.EntityContext
    is_initial: bool, whether this is the initial analysis done on the input
      source code

  Returns:
    ast.AST, same as node, with the static analysis annotations added
  """
  # TODO(mdan): Clear static analysis here.
  # TODO(mdan): Consider not running all analyses every time.
  # TODO(mdan): Don't return a node because it's modified by reference.
  cfgs = cfg.build(node)
  info = context.info
  tree = qual_names.resolve(node)
  tree = activity.resolve(tree, info, None)
  tree = reaching_definitions.resolve(tree, info, cfgs, AnnotatedDef)
  tree = liveness.resolve(tree, info, cfgs)
  tree = live_values.resolve(tree, info, config.PYTHON_LITERALS)
  tree = type_info.resolve(tree, info)
  # A second live_values pass, after type_info, allows resolving first-order
  # class attributes.
  tree = live_values.resolve(tree, info, config.PYTHON_LITERALS)
  if not is_initial:
    return tree
  anno.dup(tree, {anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS})
  return tree
Exemplo n.º 11
0
 def _parse_and_analyze(self, test_fn):
   """Parses `test_fn` and runs qualified-name and activity analysis.

   Returns:
     A (node, entity_info) tuple.
   """
   node, source = parser.parse_entity(test_fn, future_features=())
   info = transformer.EntityInfo(source_code=source,
                                 source_file=None,
                                 future_features=(),
                                 namespace={})
   tree = qual_names.resolve(node)
   tree = activity.resolve(tree, transformer.Context(info))
   return tree, info
Exemplo n.º 12
0
 def _parse_and_analyze(self, test_fn):
   """Parses `test_fn` and runs qualified-name and activity analysis.

   Returns:
     A (node, entity_info) tuple.
   """
   node, source = parser.parse_entity(test_fn, future_features=())
   info = transformer.EntityInfo(source_code=source,
                                 source_file=None,
                                 future_features=(),
                                 namespace={})
   tree = qual_names.resolve(node)
   tree = activity.resolve(tree, transformer.Context(info))
   return tree, info
Exemplo n.º 13
0
 def transform_ast(self, node, ctx):
     """Annotates `node` with static analyses and type inference results."""
     tree = activity.resolve(qual_names.resolve(node), ctx)
     cfgs = cfg.build(tree)
     for analysis in (reaching_definitions, reaching_fndefs):
         tree = analysis.resolve(tree, ctx, cfgs)
     return type_inference.resolve(tree, ctx, cfgs, self.resolver)
Exemplo n.º 14
0
def transform(node, ctx):
  """Runs static analyses on `node`, then applies ControlFlowTransformer."""
  cfgs = cfg.build(node)
  tree = qual_names.resolve(node)
  tree = activity.resolve(tree, ctx, None)
  tree = reaching_definitions.resolve(tree, ctx, cfgs, AnnotatedDef)
  tree = liveness.resolve(tree, ctx, cfgs)
  return ControlFlowTransformer(ctx).visit(tree)
Exemplo n.º 15
0
    def test_rename_symbols_attributes(self):
        """Renaming the attribute chain `b.c` rewrites both of its uses."""
        tree = qual_names.resolve(parser.parse('b.c = b.c.d'))
        tree = ast_util.rename_symbols(
            tree, {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')})

        rendered = parser.unparse(tree, include_encoding_marker=False)
        self.assertEqual(rendered.strip(), 'renamed_b_c = renamed_b_c.d')
Exemplo n.º 16
0
    def test_rename_symbols_global(self):
        """Renaming `b` rewrites it inside a `global` statement."""
        tree = qual_names.resolve(parser.parse('global a, b, c'))
        tree = ast_util.rename_symbols(
            tree, {qual_names.from_str('b'): qual_names.QN('renamed_b')})

        rendered = parser.unparse(tree, include_encoding_marker=False)
        self.assertEqual(rendered.strip(), 'global a, renamed_b, c')
Exemplo n.º 17
0
    def test_rename_symbols_attributes(self):
        """Renaming the attribute chain `b.c` rewrites both of its uses."""
        tree = qual_names.resolve(parser.parse_str('b.c = b.c.d'))
        tree = ast_util.rename_symbols(
            tree, {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')})

        rendered = compiler.ast_to_source(tree)
        self.assertEqual(rendered.strip(), 'renamed_b_c = renamed_b_c.d')
Exemplo n.º 18
0
  def test_rename_symbols_attributes(self):
    """Renaming the attribute chain `b.c` rewrites both of its uses."""
    tree = qual_names.resolve(parser.parse_str('b.c = b.c.d'))
    tree = ast_util.rename_symbols(
        tree, {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')})

    rendered = compiler.ast_to_source(tree)
    self.assertEqual(rendered.strip(), 'renamed_b_c = renamed_b_c.d')
Exemplo n.º 19
0
    def test_rename_symbols_annotations(self):
        """Checks that rename_symbols keeps the very same annotation object."""
        tree = qual_names.resolve(parser.parse_str('a[i]'))
        anno.setanno(tree, 'foo', 'bar')
        annotation_before = anno.getanno(tree, 'foo')

        renamed = ast_util.rename_symbols(
            tree, {qual_names.QN('a'): qual_names.QN('b')})

        self.assertIs(anno.getanno(renamed, 'foo'), annotation_before)
Exemplo n.º 20
0
  def test_rename_symbols_annotations(self):
    """Checks that rename_symbols keeps the very same annotation object."""
    tree = qual_names.resolve(parser.parse_str('a[i]'))
    anno.setanno(tree, 'foo', 'bar')
    annotation_before = anno.getanno(tree, 'foo')

    renamed = ast_util.rename_symbols(
        tree, {qual_names.QN('a'): qual_names.QN('b')})

    self.assertIs(anno.getanno(renamed, 'foo'), annotation_before)
Exemplo n.º 21
0
    def test_rename_symbols_basic(self):
        """Renaming `a` in `a + b` yields a str id and the expected source."""
        tree = qual_names.resolve(parser.parse_str('a + b'))
        tree = ast_util.rename_symbols(
            tree, {qual_names.QN('a'): qual_names.QN('renamed_a')})

        left_operand = tree.body[0].value.left
        self.assertIsInstance(left_operand.id, str)
        rendered = compiler.ast_to_source(tree)
        self.assertEqual(rendered.strip(), 'renamed_a + b')
Exemplo n.º 22
0
def convert2():
    """Prints the symbols read and modified in the body scope of `f2`."""
    tree, ctx = get_node_and_ctx(f2)
    tree = activity.resolve(qual_names.resolve(tree), ctx)

    # Note: tag will be changed soon.
    body_scope = anno.getanno(tree, annos.NodeAnno.BODY_SCOPE)

    print('read:', body_scope.read)
    print('modified:', body_scope.modified)
Exemplo n.º 23
0
    def test_rename_symbols_basic(self):
        """Renaming `a` in `a + b` yields a str id and the expected source."""
        tree = qual_names.resolve(parser.parse('a + b'))
        tree = ast_util.rename_symbols(
            tree, {qual_names.QN('a'): qual_names.QN('renamed_a')})

        left_operand = tree.value.left
        self.assertIsInstance(left_operand.id, str)
        rendered = parser.unparse(tree, include_encoding_marker=False)
        self.assertEqual(rendered.strip(), 'renamed_a + b')
Exemplo n.º 24
0
  def test_rename_symbols_basic(self):
    """Renaming `a` in `a + b` yields a str id and the expected source."""
    tree = qual_names.resolve(parser.parse_str('a + b'))
    tree = ast_util.rename_symbols(
        tree, {qual_names.QN('a'): qual_names.QN('renamed_a')})

    left_operand = tree.body[0].value.left
    self.assertIsInstance(left_operand.id, str)
    rendered = compiler.ast_to_source(tree)
    self.assertEqual(rendered.strip(), 'renamed_a + b')
Exemplo n.º 25
0
 def _parse_and_analyze(self, test_fn):
   """Parses `test_fn` and annotates it up to reaching definitions.

   Returns:
     The annotated AST.
   """
   node, source = parser.parse_entity(test_fn, future_features=())
   info = transformer.EntityInfo(source_code=source,
                                 source_file=None,
                                 future_features=(),
                                 namespace={})
   tree = qual_names.resolve(node)
   ctx = transformer.Context(info)
   tree = activity.resolve(tree, ctx)
   cfgs = cfg.build(tree)
   return reaching_definitions.resolve(tree, ctx, cfgs,
                                       reaching_definitions.Definition)
Exemplo n.º 26
0
 def _parse_and_analyze(self, test_fn):
     """Parses `test_fn` and runs qualified-name and activity analysis.

     Returns:
       A (node, entity_info) tuple.
     """
     node, source, _ = parser.parse_entity(test_fn)
     info = transformer.EntityInfo(
         source_code=source,
         source_file=None,
         namespace={},
         arg_values=None,
         arg_types=None)
     tree = qual_names.resolve(node)
     tree = activity.resolve(tree, transformer.Context(info))
     return tree, info
Exemplo n.º 27
0
 def _parse_and_analyze(self, test_fn):
   """Parses `test_fn` and annotates it up to reaching definitions.

   Returns:
     The annotated AST.
   """
   node, source = parser.parse_entity(test_fn, future_features=())
   info = transformer.EntityInfo(source_code=source,
                                 source_file=None,
                                 future_features=(),
                                 namespace={})
   tree = qual_names.resolve(node)
   ctx = transformer.Context(info)
   tree = activity.resolve(tree, ctx)
   cfgs = cfg.build(tree)
   return reaching_definitions.resolve(tree, ctx, cfgs,
                                       reaching_definitions.Definition)
Exemplo n.º 28
0
 def initial_analysis(self, node, ctx):
     """Runs the initial static analyses and snapshots DEFINITIONS.

     The DEFINITIONS annotation is duplicated under ORIG_DEFINITIONS.
     """
     cfgs = cfg.build(node)
     tree = activity.resolve(qual_names.resolve(node), ctx, None)
     tree = reaching_definitions.resolve(tree, ctx, cfgs)
     anno.dup(tree, {anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS})
     return tree
Exemplo n.º 29
0
    def test_rename_symbols_basic(self):
        """Renaming `a` in `a + b` matches both round-trip and expected code."""
        tree = qual_names.resolve(parser.parse('a + b'))
        tree = ast_util.rename_symbols(
            tree, {qual_names.QN('a'): qual_names.QN('renamed_a')})
        round_tripped = parser.unparse(tree, include_encoding_marker=False)

        self.assertIsInstance(tree.value.left.id, str)
        for reference_src in (round_tripped, 'renamed_a + b'):
            self.assertAstMatches(tree, reference_src)
Exemplo n.º 30
0
 def _parse_and_analyze(self, test_fn):
   """Parses `test_fn` and runs qualified-name and activity analysis.

   Activity analysis receives the EntityInfo directly, not a Context.

   Returns:
     A (node, entity_info) tuple.
   """
   node, source = parser.parse_entity(test_fn)
   info = transformer.EntityInfo(source_code=source,
                                 source_file=None,
                                 namespace={},
                                 arg_values=None,
                                 arg_types=None,
                                 owner_type=None)
   tree = activity.resolve(qual_names.resolve(node), info)
   return tree, info
Exemplo n.º 31
0
def convert4():
    """Prints the live-in variables of the first two statements of `f4`."""
    tree, ctx = get_node_and_ctx(f4)

    tree = qual_names.resolve(tree)
    cfgs = cfg.build(tree)
    tree = activity.resolve(tree, ctx)
    for analysis in (reaching_definitions, reaching_fndefs, liveness):
        tree = analysis.resolve(tree, ctx, cfgs)

    statements = tree.body
    print('live into `b = a + 1`:', anno.getanno(statements[0], anno.Static.LIVE_VARS_IN))
    print('live into `return b`:', anno.getanno(statements[1], anno.Static.LIVE_VARS_IN))
Exemplo n.º 32
0
 def _parse_and_analyze(self, test_fn):
     """Parses `test_fn` and runs analyses up to liveness.

     Returns:
       The analyzed AST.
     """
     node, source = parser.parse_entity(test_fn)
     info = transformer.EntityInfo(
         source_code=source,
         source_file=None,
         namespace={},
         arg_values=None,
         arg_types=None,
         owner_type=None)
     tree = activity.resolve(qual_names.resolve(node), info)
     # The liveness result is discarded; the tree is expected to be annotated
     # in place.
     liveness.resolve(tree, info, cfg.build(tree))
     return tree
Exemplo n.º 33
0
 def _parse_and_analyze(self, test_fn):
     """Parses `test_fn` and runs analyses up to liveness.

     Returns:
       The analyzed AST.
     """
     node, _, source = parser.parse_entity(test_fn, future_imports=())
     info = transformer.EntityInfo(
         source_code=source,
         source_file=None,
         namespace={},
         arg_values=None,
         arg_types=None)
     tree = qual_names.resolve(node)
     ctx = transformer.Context(info)
     tree = activity.resolve(tree, ctx)
     # The liveness result is discarded; the tree is expected to be annotated
     # in place.
     liveness.resolve(tree, ctx, cfg.build(tree))
     return tree
Exemplo n.º 34
0
def mlir_gen_internal(node, entity_info):
  """Returns mlir module for unprocessed node `node`.

  Runs name, activity, reaching-definition, reaching-fndef and liveness
  analyses, then visits the tree with MLIRGen and returns its `prog`.
  """
  namer = naming.Namer({})
  cfgs = cfg.build(node)
  ctx = transformer.Context(entity_info, namer, None)
  tree = activity.resolve(qual_names.resolve(node), ctx)
  for analysis in (reaching_definitions, reaching_fndefs, liveness):
    tree = analysis.resolve(tree, ctx, cfgs)
  generator = MLIRGen(ctx)
  generator.visit(tree)
  return generator.prog
Exemplo n.º 35
0
 def _parse_and_analyze(self, test_fn):
     """Parses `test_fn` and runs qualified-name and activity analysis.

     Returns:
       A (node, entity_info) tuple.
     """
     # TODO(mdan): Use a custom FunctionTransformer here.
     node, source = parser.parse_entity(test_fn, future_features=())
     info = transformer.EntityInfo(
         name=test_fn.__name__,
         source_code=source,
         source_file=None,
         future_features=(),
         namespace={})
     tree = qual_names.resolve(node)
     ctx = transformer.Context(info, naming.Namer({}), None)
     tree = activity.resolve(tree, ctx)
     return tree, info
 def _parse_and_analyze(self, test_fn):
   """Parses `test_fn` and annotates it up to reaching definitions.

   Returns:
     The annotated AST.
   """
   node, source = parser.parse_entity(test_fn)
   info = transformer.EntityInfo(source_code=source,
                                 source_file=None,
                                 namespace={},
                                 arg_values=None,
                                 arg_types=None,
                                 owner_type=None)
   tree = activity.resolve(qual_names.resolve(node), info)
   cfgs = cfg.build(tree)
   return reaching_definitions.resolve(tree, info, cfgs,
                                       reaching_definitions.Definition)
Exemplo n.º 37
0
def transform(node, ctx):
    """Transforms function calls into their compiled counterparts.

    Qualified names are resolved first, then CallTreeTransformer is applied.

    Args:
      node: AST
      ctx: EntityContext

    Returns:
      The transformed AST. Only the node is returned; no set of new names is
      produced.
    """
    node = qual_names.resolve(node)

    node = CallTreeTransformer(ctx).visit(node)
    return node
Exemplo n.º 38
0
 def _parse_and_analyze(self, test_fn):
   """Parses `test_fn` and runs analyses up to liveness.

   Returns:
     The analyzed AST.
   """
   node, source = parser.parse_entity(test_fn)
   info = transformer.EntityInfo(source_code=source,
                                 source_file=None,
                                 namespace={},
                                 arg_values=None,
                                 arg_types=None,
                                 owner_type=None)
   tree = qual_names.resolve(node)
   ctx = transformer.Context(info)
   tree = activity.resolve(tree, ctx)
   # The liveness result is discarded; the tree is expected to be annotated
   # in place.
   liveness.resolve(tree, ctx, cfg.build(tree))
   return tree
Exemplo n.º 39
0
  def transform_ast(self, node, ctx):
    """Analyzes `node` and returns the code buffer generated by TFRGen."""
    tree = _apply_py_to_tf_passes(node, ctx)
    # TODO(mdan): Enable this.
    # node = anf.transform(node, ctx)

    cfgs = cfg.build(tree)
    tree = activity.resolve(qual_names.resolve(tree), ctx)
    for analysis in (reaching_definitions, reaching_fndefs):
      tree = analysis.resolve(tree, ctx, cfgs)
    tree = type_inference.resolve(tree, ctx, cfgs,
                                  TFRTypeResolver(self._op_defs))

    generator = TFRGen(ctx, self._op_defs)
    generator.visit(tree)
    return generator.code_buffer
Exemplo n.º 40
0
 def test_function_calls(self):
   """QNs resolve through call and subscript expressions."""
   samples = """
     a.b
     a.b()
     a().b
     z[i]
     z[i]()
     z()[i]
   """
   module = resolve(parser.parse_str(textwrap.dedent(samples)))
   exprs = tuple(stmt.value for stmt in module.body)
   self.assertQNStringIs(exprs[0], 'a.b')
   self.assertQNStringIs(exprs[1].func, 'a.b')
   self.assertQNStringIs(exprs[2].value.func, 'a')
   self.assertQNStringIs(exprs[3], 'z[i]')
   self.assertQNStringIs(exprs[4].func, 'z[i]')
   self.assertQNStringIs(exprs[5].value.func, 'z')
Exemplo n.º 41
0
 def test_function_calls(self):
     """QNs resolve through call and subscript expressions."""
     samples = """
   a.b
   a.b()
   a().b
   z[i]
   z[i]()
   z()[i]
 """
     statements = parser.parse(textwrap.dedent(samples), single_node=False)
     exprs = tuple(resolve(stmt).value for stmt in statements)
     self.assertQNStringIs(exprs[0], 'a.b')
     self.assertQNStringIs(exprs[1].func, 'a.b')
     self.assertQNStringIs(exprs[2].value.func, 'a')
     self.assertQNStringIs(exprs[3], 'z[i]')
     self.assertQNStringIs(exprs[4].func, 'z[i]')
     self.assertQNStringIs(exprs[5].value.func, 'z')
Exemplo n.º 42
0
def replace_as_expression(template, **replacements):
  """Variant of replace that generates expressions, instead of code blocks.

  Returns:
    The expression wrapped by the single generated `Expr` statement, or the
    generated `Name` node itself.

  Raises:
    ValueError: if the template does not generate exactly one node, or if
      that node is neither a `gast.Expr` nor a `gast.Name`.
  """
  generated = replace(template, **replacements)
  if len(generated) != 1:
    raise ValueError(
        'single expression expected; for more general templates use replace')
  (only_node,) = generated
  resolved = qual_names.resolve(only_node)

  if isinstance(resolved, gast.Name):
    return resolved
  if isinstance(resolved, gast.Expr):
    return resolved.value

  raise ValueError(
      'the template is expected to generate an expression or a name node;'
      ' instead found %s' % resolved)
Exemplo n.º 43
0
def replace_as_expression(template, **replacements):
  """Variant of replace that generates expressions, instead of code blocks.

  Returns:
    The expression wrapped by the single generated `Expr` statement, or the
    generated `Name` node itself.

  Raises:
    ValueError: if the template does not generate exactly one node, or if
      that node is neither a `gast.Expr` nor a `gast.Name`.
  """
  generated = replace(template, **replacements)
  if len(generated) != 1:
    raise ValueError(
        'single expression expected; for more general templates use replace')
  (only_node,) = generated
  resolved = qual_names.resolve(only_node)

  if isinstance(resolved, gast.Name):
    return resolved
  if isinstance(resolved, gast.Expr):
    return resolved.value

  raise ValueError(
      'the template is expected to generate an expression or a name node;'
      ' instead found %s' % resolved)
Exemplo n.º 44
0
  def test_resolve(self):
    """QNs resolve for names, attributes, tuple/list elements and calls."""
    samples = """
      a
      a.b
      (c, d.e)
      [f, (g.h.i)]
      j(k, l)
    """
    module = resolve(parser.parse_str(textwrap.dedent(samples)))
    exprs = tuple(stmt.value for stmt in module.body)

    self.assertQNStringIs(exprs[0], 'a')
    self.assertQNStringIs(exprs[1], 'a.b')
    tuple_elts = exprs[2].elts
    self.assertQNStringIs(tuple_elts[0], 'c')
    self.assertQNStringIs(tuple_elts[1], 'd.e')
    list_elts = exprs[3].elts
    self.assertQNStringIs(list_elts[0], 'f')
    self.assertQNStringIs(list_elts[1], 'g.h.i')
    call = exprs[4]
    self.assertQNStringIs(call.func, 'j')
    self.assertQNStringIs(call.args[0], 'k')
    self.assertQNStringIs(call.args[1], 'l')
Exemplo n.º 45
0
 def _parse_and_analyze(self,
                        test_fn,
                        namespace,
                        arg_types=None):
   """Parses `test_fn` and runs the live-value and type-info analyses.

   Returns:
     The analyzed AST.
   """
   node, source = parser.parse_entity(test_fn)
   info = transformer.EntityInfo(source_code=source,
                                 source_file=None,
                                 namespace=namespace,
                                 arg_values=None,
                                 arg_types=arg_types,
                                 owner_type=None)
   tree = qual_names.resolve(node)
   cfgs = cfg.build(tree)
   ctx = transformer.Context(info)
   tree = activity.resolve(tree, ctx)
   tree = reaching_definitions.resolve(tree, ctx, cfgs,
                                       reaching_definitions.Definition)
   # live_values runs both before and after type_info.
   tree = live_values.resolve(tree, ctx, {})
   tree = type_info.resolve(tree, ctx)
   return live_values.resolve(tree, ctx, {})
Exemplo n.º 46
0
def class_to_graph(c, program_ctx):
  """Specialization of `entity_to_graph` for classes.

  Converts every method defined directly on `c` with `function_to_graph` and
  emits a new class definition that holds the converted methods.

  Args:
    c: The class to convert.
    program_ctx: The program context. Its namer picks the compiled class name,
      and its name map is updated with the names that were chosen.

  Returns:
    A tuple (output_nodes, class_name, class_namespace). `output_nodes` holds
    the import statements for whitelisted bases followed by the class
    definition. `class_name` is the name of the converted class.
    `class_namespace` is the merged namespace of the converted methods.

  Raises:
    ValueError: if `c` has no member methods.
  """
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  class_namespace = {}
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    node, _, namespace = function_to_graph(
        m,
        program_ctx=program_ctx,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        owner_type=c)
    # All methods share one namespace; later methods overwrite earlier keys.
    class_namespace.update(namespace)
    converted_members[m] = node[0]
  namer = program_ctx.new_namer(class_namespace)
  class_name = namer.compiled_class_name(c.__name__, c)

  # TODO(mdan): This needs to be explained more thoroughly.
  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated. Otherwise, it is marked for conversion
  # (as a side effect of the call to namer.compiled_class_name() followed by
  # program_ctx.update_name_map(namer)).
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # Use an identity check. `isinstance(object, base)` is also True for bases
    # such as `type` or ABCs like `collections.abc.Hashable`/`Callable`,
    # because the class `object` is an instance of those. That silently
    # replaced such bases with `object`.
    if base is object:
      base_names.append('object')
      continue
    if is_whitelisted_for_graph(base):
      alias = namer.new_symbol(base.__name__, ())
      output_nodes.append(
          gast.ImportFrom(
              module=base.__module__,
              names=[gast.alias(name=base.__name__, asname=alias)],
              level=0))
    else:
      # This will trigger a conversion into a class with this name.
      alias = namer.compiled_class_name(base.__name__, base)
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
  program_ctx.update_name_map(namer)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  return output_nodes, class_name, class_namespace
Exemplo n.º 47
0
def convert_class_to_ast(c, program_ctx):
  """Specialization of `convert_entity_to_ast` for classes.

  Converts every method defined directly on `c` with `convert_func_to_ast` and
  emits a new class definition that holds the converted methods.

  Args:
    c: The class to convert.
    program_ctx: The program context, forwarded to `convert_func_to_ast`.

  Returns:
    A tuple (output_nodes, class_name, entity_info). `output_nodes` holds the
    import statements for whitelisted bases followed by the class definition.
    `class_name` is the name of the converted class. `entity_info` is a
    synthesized EntityInfo carrying the merged namespace and the shared future
    features.

  Raises:
    ValueError: if `c` has no member methods, or if its methods were built
      with different future features.
    NotImplementedError: if a base class (other than `object`) is not
      whitelisted.
  """
  # TODO(mdan): Revisit this altogether. Not sure we still need it.
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('cannot convert %s: no member methods' % c)

  # TODO(mdan): Don't clobber namespaces for each method in one class namespace.
  # The assumption that one namespace suffices for all methods only holds if
  # all methods were defined in the same module.
  # If, instead, functions are imported from multiple modules and then spliced
  # into the class, then each function has its own globals and __future__
  # imports that need to stay separate.

  # For example, C's methods could both have `global x` statements referring to
  # mod1.x and mod2.x, but using one namespace for C would cause a conflict.
  # from mod1 import f1
  # from mod2 import f2
  # class C(object):
  #   method1 = f1
  #   method2 = f2

  class_namespace = {}
  # NOTE(review): stays None if no member is defined directly on `c`.
  future_features = None
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    (node,), _, entity_info = convert_func_to_ast(
        m,
        program_ctx=program_ctx,
        do_rename=False)
    class_namespace.update(entity_info.namespace)
    converted_members[m] = node

    # TODO(mdan): Similarly check the globals.
    if future_features is None:
      future_features = entity_info.future_features
    elif frozenset(future_features) ^ frozenset(entity_info.future_features):
      # Note: we can support this case if ever needed.
      raise ValueError(
          'cannot convert {}: it has methods built with mismatched future'
          ' features: {} and {}'.format(c, future_features,
                                        entity_info.future_features))
  namer = naming.Namer(class_namespace)
  class_name = namer.class_name(c.__name__)

  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated.
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # Use an identity check. `isinstance(object, base)` is also True for bases
    # such as `type` or ABCs like `collections.abc.Hashable`/`Callable`,
    # because the class `object` is an instance of those. That silently
    # replaced such bases with `object`.
    if base is object:
      base_names.append('object')
      continue
    if is_whitelisted_for_graph(base):
      alias = namer.new_symbol(base.__name__, ())
      output_nodes.append(
          gast.ImportFrom(
              module=base.__module__,
              names=[gast.alias(name=base.__name__, asname=alias)],
              level=0))
    else:
      raise NotImplementedError(
          'Conversion of classes that do not directly extend classes from'
          ' whitelisted modules is temporarily suspended. If this breaks'
          ' existing code please notify the AutoGraph team immediately.')
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  # TODO(mdan): Find a way better than forging this object.
  entity_info = transformer.EntityInfo(
      source_code=None,
      source_file=None,
      future_features=future_features,
      namespace=class_namespace)

  return output_nodes, class_name, entity_info