def parse_and_analyze(self,
                      test_fn,
                      namespace,
                      namer=None,
                      arg_types=None,
                      include_type_analysis=True,
                      owner_type=None,
                      recursive=True):
   """Parses `test_fn` and runs the static analysis passes over its AST.

   The passes run in this order: qual_names, activity, live_values and,
   when `include_type_analysis` is true, type_info followed by a second
   live_values pass. The EntityContext created for the run is stored on
   `self.ctx` after all passes have completed.

   Args:
     test_fn: The entity handed to `parser.parse_entity`.
     namespace: Forwarded to the EntityContext as `namespace`.
     namer: Optional namer; `FakeNamer()` is used when it is falsy.
     arg_types: Forwarded to the EntityContext as `arg_types`.
     include_type_analysis: Whether to run type_info (and the live_values
       pass that follows it).
     owner_type: Forwarded to the EntityContext as `owner_type`.
     recursive: Forwarded to the EntityContext as `recursive`.

   Returns:
     The node returned by the last pass that ran.
   """
   parsed, entity_source = parser.parse_entity(test_fn)
   entity_ctx = context.EntityContext(
       namer=namer or FakeNamer(),
       source_code=entity_source,
       source_file=None,
       namespace=namespace,
       arg_values=None,
       arg_types=arg_types,
       owner_type=owner_type,
       recursive=recursive)

   annotated = qual_names.resolve(parsed)
   annotated = activity.resolve(annotated, entity_ctx)
   annotated = live_values.resolve(annotated, entity_ctx, {})
   if include_type_analysis:
     annotated = type_info.resolve(annotated, entity_ctx)
     # live_values is run again after the type analysis.
     annotated = live_values.resolve(annotated, entity_ctx, {})

   # Expose the context to the caller only once every pass succeeded.
   self.ctx = entity_ctx
   return annotated
Exemple #2
0
 def parse_and_analyze(self,
                       test_fn,
                       namespace,
                       namer=None,
                       arg_types=None,
                       include_type_analysis=True,
                       owner_type=None,
                       recursive=True):
   """Parses `test_fn` and annotates its AST with static analysis results.

   Runs qual_names, activity and live_values; when `include_type_analysis`
   is true it additionally runs type_info and live_values once more. The
   EntityContext (created with `utils.set_element_type` as its
   `type_annotation_func`) is saved on `self.ctx` at the end.

   Args:
     test_fn: The entity handed to `parser.parse_entity`.
     namespace: Forwarded to the EntityContext as `namespace`.
     namer: Optional namer; `FakeNamer()` is used when it is falsy.
     arg_types: Forwarded to the EntityContext as `arg_types`.
     include_type_analysis: Whether to run the type_info pass.
     owner_type: Forwarded to the EntityContext as `owner_type`.
     recursive: Forwarded to the EntityContext as `recursive`.

   Returns:
     The node returned by the last pass that ran.
   """
   parsed, entity_source = parser.parse_entity(test_fn)
   ctx_kwargs = dict(
       namer=namer or FakeNamer(),
       source_code=entity_source,
       source_file=None,
       namespace=namespace,
       arg_values=None,
       arg_types=arg_types,
       owner_type=owner_type,
       recursive=recursive,
       type_annotation_func=utils.set_element_type)
   entity_ctx = context.EntityContext(**ctx_kwargs)

   annotated = qual_names.resolve(parsed)
   annotated = activity.resolve(annotated, entity_ctx)
   annotated = live_values.resolve(annotated, entity_ctx, {})
   if include_type_analysis:
     annotated = type_info.resolve(annotated, entity_ctx)
     # live_values is run again after the type analysis.
     annotated = live_values.resolve(annotated, entity_ctx, {})

   self.ctx = entity_ctx
   return annotated
Exemple #3
0
 def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
   """Parses `test_fn` and runs access, live_values and type_info on it.

   Args:
     test_fn: The entity handed to `parser.parse_entity`.
     namespace: Forwarded to the EntityContext as `namespace`.
     arg_types: Forwarded to the EntityContext as `arg_types`.

   Returns:
     The node returned by the final live_values pass.
   """
   parsed, entity_source = parser.parse_entity(test_fn)
   entity_ctx = context.EntityContext(
       namer=None,
       source_code=entity_source,
       source_file=None,
       namespace=namespace,
       arg_values=None,
       arg_types=arg_types,
       recursive=True)

   annotated = access.resolve(parsed, entity_ctx)
   annotated = live_values.resolve(annotated, entity_ctx, {})
   annotated = type_info.resolve(annotated, entity_ctx)
   # live_values is run a second time, after the type analysis.
   return live_values.resolve(annotated, entity_ctx, {})
 def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
     """Parses `test_fn` and runs the full analysis pipeline on its AST.

     The passes are qual_names, activity, live_values, type_info and a
     second live_values run, in that order.

     Args:
       test_fn: The entity handed to `parser.parse_entity`.
       namespace: Forwarded to the EntityContext as `namespace`.
       arg_types: Forwarded to the EntityContext as `arg_types`.

     Returns:
       The node returned by the final live_values pass.
     """
     parsed, entity_source = parser.parse_entity(test_fn)
     entity_ctx = context.EntityContext(
         namer=None,
         source_code=entity_source,
         source_file=None,
         namespace=namespace,
         arg_values=None,
         arg_types=arg_types,
         recursive=True)

     annotated = qual_names.resolve(parsed)
     annotated = activity.resolve(annotated, entity_ctx)
     annotated = live_values.resolve(annotated, entity_ctx, {})
     annotated = type_info.resolve(annotated, entity_ctx)
     return live_values.resolve(annotated, entity_ctx, {})
Exemple #5
0
def node_to_graph(node, namer, namespace, value_hints):
    """Convert Python code to equivalent TF graph mode code.

    Runs the static analysis passes (access, live_values, type_info) and
    then the conversion passes, each one consuming the previous result.

    Args:
      node: A Python AST node representing the code to convert.
      namer: A naming.Namer object.
      namespace: Dict mapping symbol names to their corresponding live
          objects.
      value_hints: A dict containing value hints for symbols like function
          parameters.

    Returns:
      The AST node produced by the last conversion pass.
    """
    # Static analysis.
    tree = access.resolve(node)
    tree = live_values.resolve(tree, namespace, config.PYTHON_LITERALS)
    tree = type_info.resolve(tree, value_hints)

    # TODO(mdan): Factor out common elements.
    # These include:
    #   * keeping track of symbols that have been created
    #   * marking nodes (e.g. py_func wrappers) to suppress further processing

    # Conversion passes; the order is significant.
    tree = print_functions.transform(tree)
    tree = call_trees.transform(tree, namer, config.DEFAULT_UNCOMPILED_MODULES)
    tree = control_flow.transform(tree, namer)
    tree = logical_expressions.transform(tree)
    return side_effect_guards.transform(tree, namer)
Exemple #6
0
def node_to_graph(node, namer, namespace, value_hints):
  """Convert Python code to equivalent TF graph mode code.

  Applies the gradients_function transform first, then the static analysis
  passes (access, live_values, type_info), then the conversion passes.

  Args:
    node: A Python AST node representing the code to convert.
    namer: A naming.Namer object.
    namespace: Dict mapping symbol names to their corresponding live objects.
    value_hints: A dict containing value hints for symbols like function
        parameters.

  Returns:
    The AST node produced by the last conversion pass.
  """
  # TODO(mdan): Get rid of this.
  tree = gradients_function.transform(node)

  # Static analysis.
  tree = access.resolve(tree)
  tree = live_values.resolve(tree, namespace, config.PYTHON_LITERALS)
  tree = type_info.resolve(tree, value_hints)

  # TODO(mdan): Factor out common elements.
  # These include:
  #   * keeping track of symbols that have been created
  #   * marking nodes (e.g. py_func wrappers) to suppress further processing

  # Conversion passes; the order is significant.
  tree = print_functions.transform(tree)
  tree = call_trees.transform(tree, namer, config.DEFAULT_UNCOMPILED_MODULES)
  tree = control_flow.transform(tree, namer)
  tree = logical_expressions.transform(tree)
  return side_effect_guards.transform(tree, namer)
Exemple #7
0
def node_to_graph(node, namer, namespace, value_hints):
    """Convert Python code to equivalent TF graph mode code.

    Runs static analysis, applies the for_canonicalization and
    builtin_functions transforms, re-runs static analysis, and finally
    applies the remaining conversion passes.

    Note: `namespace` is modified in place; the `len` and `print` builtins
    are added under the keys 'len' and 'print'.

    Args:
      node: A Python AST node representing the code to convert.
      namer: A naming.Namer object.
      namespace: Dict mapping symbol names to their corresponding live
          objects.
      value_hints: A dict containing value hints for symbols like function
          parameters.

    Returns:
      The AST node produced by the last conversion pass.
    """
    # First round of static analysis.
    tree = access.resolve(node)
    tree = live_values.resolve(tree, namespace, config.PYTHON_LITERALS)
    tree = type_info.resolve(tree, value_hints)

    # TODO(mdan): Factor out common elements.
    # These include:
    #   * keeping track of symbols that have been created
    #   * marking nodes (e.g. py_func wrappers) to suppress further processing

    tree = for_canonicalization.transform(tree, namer)
    tree = builtin_functions.transform(tree)

    # The transformation steps above insert new variables. Although less
    # efficient, it is most robust to re-run the analysis.
    # We also need to ensure the namespace contains any new references that may
    # have been created.
    namespace['len'] = len
    namespace['print'] = print

    # Second round of static analysis.
    tree = access.resolve(tree)
    tree = live_values.resolve(tree, namespace, config.PYTHON_LITERALS)
    tree = type_info.resolve(tree, value_hints)

    # Remaining conversion passes; the order is significant.
    tree = print_functions.transform(tree)
    tree = call_trees.transform(tree, namer, config.DEFAULT_UNCOMPILED_MODULES)
    tree = control_flow.transform(tree, namer)
    tree = logical_expressions.transform(tree)
    return side_effect_guards.transform(tree, namer)
Exemple #8
0
    def test_parameter_class_members(self):
        """Expects ValueError when a method is called on a bare parameter.

        `test_fn` calls `opt.minimize` on its parameter and
        `type_info.resolve` receives None as its second argument.
        """
        def test_fn(opt):
            opt.minimize(0)

        parsed = parser.parse_object(test_fn)
        accessed = access.resolve(parsed)
        annotated = live_values.resolve(accessed, {'training': training}, {})
        with self.assertRaises(ValueError):
            type_info.resolve(annotated, None)
Exemple #9
0
    def test_nested_members(self):
        """Expects ValueError for a call through a nested attribute chain.

        `test_fn` constructs an optimizer and then calls `foo.bar.baz()`.
        """
        def test_fn():
            foo = training.GradientDescentOptimizer(0.1)
            foo.bar.baz()

        parsed = parser.parse_object(test_fn)
        accessed = access.resolve(parsed)
        annotated = live_values.resolve(accessed, {'training': training}, {})
        with self.assertRaises(ValueError):
            type_info.resolve(annotated, None)
  def test_parameter_class_members(self):
    """Expects ValueError when a method is called on a bare parameter.

    `test_fn` calls `opt.minimize` on its parameter and `type_info.resolve`
    receives None as its second argument.
    """

    def test_fn(opt):
      opt.minimize(0)

    parsed = parser.parse_object(test_fn)
    accessed = access.resolve(parsed)
    annotated = live_values.resolve(accessed, {'training': training}, {})
    with self.assertRaises(ValueError):
      type_info.resolve(annotated, None)
    def test_literals(self):
        """Checks that a name listed in the literals map gets that live value.

        `Foo` is undefined in the namespace but mapped to 'bar' in the
        literals dict passed to `live_values.resolve`; the returned name node
        must carry 'bar' as its 'live_val' annotation.
        """
        def test_fn():
            return Foo  # pylint: disable=undefined-variable

        node = parser.parse_object(test_fn)
        node = access.resolve(node)
        node = live_values.resolve(node, {}, {'Foo': 'bar'})

        # The value expression of the `return` statement in test_fn.
        retval_node = node.body[0].body[0].value
        # assertEqual is used because the assertEquals alias is deprecated.
        self.assertEqual('bar', anno.getanno(retval_node, 'live_val'))
  def test_literals(self):
    """Checks that a name listed in the literals map gets that live value.

    `Foo` is undefined in the namespace but mapped to 'bar' in the literals
    dict passed to `live_values.resolve`; the returned name node must carry
    'bar' as its 'live_val' annotation.
    """

    def test_fn():
      return Foo  # pylint: disable=undefined-variable

    node = parser.parse_object(test_fn)
    node = access.resolve(node)
    node = live_values.resolve(node, {}, {'Foo': 'bar'})

    # The value expression of the `return` statement in test_fn.
    retval_node = node.body[0].body[0].value
    # assertEqual is used because the assertEquals alias is deprecated.
    self.assertEqual('bar', anno.getanno(retval_node, 'live_val'))
 def _parse_and_analyze(self,
                        test_fn,
                        namespace,
                        literals=None,
                        arg_types=None):
     """Parses `test_fn` and runs access, live_values and type_info on it.

     Args:
       test_fn: The entity handed to `parser.parse_entity`.
       namespace: Forwarded to the EntityContext as `namespace`.
       literals: Passed to both live_values passes; a falsy value is
         replaced by an empty dict.
       arg_types: Forwarded to the EntityContext; a falsy value is
         replaced by an empty dict.

     Returns:
       The node returned by the final live_values pass.
     """
     if not literals:
         literals = {}
     if not arg_types:
         arg_types = {}

     parsed, entity_source = parser.parse_entity(test_fn)
     entity_ctx = context.EntityContext(
         namer=None,
         source_code=entity_source,
         source_file=None,
         namespace=namespace,
         arg_values=None,
         arg_types=arg_types,
         recursive=True)

     annotated = access.resolve(parsed, entity_ctx)
     annotated = live_values.resolve(annotated, entity_ctx, literals)
     annotated = type_info.resolve(annotated, entity_ctx)
     return live_values.resolve(annotated, entity_ctx, literals)
  def test_nested_members(self):
    """Expects ValueError for a call through a nested attribute chain.

    `test_fn` constructs an optimizer and then calls `foo.bar.baz()`.
    """

    def test_fn():
      foo = training.GradientDescentOptimizer(0.1)
      foo.bar.baz()

    parsed = parser.parse_object(test_fn)
    accessed = access.resolve(parsed)
    annotated = live_values.resolve(accessed, {'training': training}, {})
    with self.assertRaises(ValueError):
      type_info.resolve(annotated, None)
 def _parse_and_analyze(self,
                        test_fn,
                        namespace,
                        arg_types=None):
   """Parses `test_fn` and runs the full analysis pipeline on its AST.

   The passes are qual_names, activity, live_values, type_info and a second
   live_values run. The context is created with
   `utils.set_element_type` as its `type_annotation_func`.

   Args:
     test_fn: The entity handed to `parser.parse_entity`.
     namespace: Forwarded to the EntityContext as `namespace`.
     arg_types: Forwarded to the EntityContext as `arg_types`.

   Returns:
     The node returned by the final live_values pass.
   """
   parsed, entity_source = parser.parse_entity(test_fn)
   ctx_kwargs = dict(
       namer=None,
       source_code=entity_source,
       source_file=None,
       namespace=namespace,
       arg_values=None,
       arg_types=arg_types,
       owner_type=None,
       recursive=True,
       type_annotation_func=utils.set_element_type)
   entity_ctx = context.EntityContext(**ctx_kwargs)

   annotated = qual_names.resolve(parsed)
   annotated = activity.resolve(annotated, entity_ctx)
   annotated = live_values.resolve(annotated, entity_ctx, {})
   annotated = type_info.resolve(annotated, entity_ctx)
   return live_values.resolve(annotated, entity_ctx, {})
 def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
     """Parses `test_fn` and runs access, live_values and type_info on it.

     Only the type_info pass receives the EntityContext; live_values is
     given `namespace` directly.

     Args:
       test_fn: The object handed to `parser.parse_object`.
       namespace: Used for the EntityContext and for live_values.
       arg_types: Forwarded to the EntityContext as `arg_types`.

     Returns:
       The node returned by the type_info pass.
     """
     # The context is created before parsing, as in the original flow.
     entity_ctx = context.EntityContext(
         namer=None,
         source_code=None,
         source_file=None,
         namespace=namespace,
         arg_values=None,
         arg_types=arg_types)

     parsed = parser.parse_object(test_fn)
     accessed = access.resolve(parsed)
     annotated = live_values.resolve(accessed, namespace, {})
     return type_info.resolve(annotated, entity_ctx)
  def test_attribute_names(self):
    """Checks the annotations placed on an attribute used as a call target.

    For `constant_op.constant(0)`, the function node must be annotated with
    the live `constant_op.constant` object and with the fully qualified name
    `(constant_op.__name__, 'constant')`.
    """

    def test_fn():
      return constant_op.constant(0)

    node = parser.parse_object(test_fn)
    node = access.resolve(node)
    node = live_values.resolve(node, {'constant_op': constant_op}, {})

    # The callee expression of the call returned by test_fn.
    func_node = node.body[0].body[0].value.func
    # assertEqual is used because the assertEquals alias is deprecated.
    self.assertEqual(constant_op.constant, anno.getanno(func_node, 'live_val'))
    self.assertEqual((constant_op.__name__, 'constant'),
                     anno.getanno(func_node, 'fqn'))
    def test_attribute_names(self):
        """Checks the annotations placed on an attribute used as a call target.

        For `constant_op.constant(0)`, the function node must be annotated
        with the live `constant_op.constant` object and with the fully
        qualified name `(constant_op.__name__, 'constant')`.
        """
        def test_fn():
            return constant_op.constant(0)

        node = parser.parse_object(test_fn)
        node = access.resolve(node)
        node = live_values.resolve(node, {'constant_op': constant_op}, {})

        # The callee expression of the call returned by test_fn.
        func_node = node.body[0].body[0].value.func
        # assertEqual is used because the assertEquals alias is deprecated.
        self.assertEqual(constant_op.constant,
                         anno.getanno(func_node, 'live_val'))
        self.assertEqual((constant_op.__name__, 'constant'),
                         anno.getanno(func_node, 'fqn'))
Exemple #19
0
    def test_constructor_deta_dependent(self):
        """Expects ValueError when the constructor call depends on a branch.

        `test_fn` assigns `opt` in both arms of an `if`/`else` and then
        calls `opt.minimize`.
        """
        def test_fn(x):
            if x > 0:
                opt = training.GradientDescentOptimizer(0.1)
            else:
                opt = training.GradientDescentOptimizer(0.01)
            opt.minimize(0)

        parsed = parser.parse_object(test_fn)
        accessed = access.resolve(parsed)
        annotated = live_values.resolve(accessed, {'training': training}, {})
        with self.assertRaises(ValueError):
            type_info.resolve(annotated, None)
Exemple #20
0
 def parse_and_analyze(self,
                       test_fn,
                       namespace,
                       namer=None,
                       arg_types=None,
                       include_type_analysis=True,
                       recursive=True):
     """Parses `test_fn` and runs the static analysis passes over its AST.

     Runs access and live_values; when `include_type_analysis` is true it
     additionally runs type_info and live_values once more. The
     EntityContext used for the run is saved on `self.ctx` at the end.

     Args:
       test_fn: The entity handed to `parser.parse_entity`.
       namespace: Forwarded to the EntityContext as `namespace`.
       namer: Forwarded to the EntityContext as `namer`.
       arg_types: Forwarded to the EntityContext as `arg_types`.
       include_type_analysis: Whether to run the type_info pass.
       recursive: Forwarded to the EntityContext as `recursive`.

     Returns:
       The node returned by the last pass that ran.
     """
     parsed, entity_source = parser.parse_entity(test_fn)
     entity_ctx = context.EntityContext(
         namer=namer,
         source_code=entity_source,
         source_file=None,
         namespace=namespace,
         arg_values=None,
         arg_types=arg_types,
         recursive=recursive)

     annotated = access.resolve(parsed, entity_ctx)
     annotated = live_values.resolve(annotated, entity_ctx, {})
     if include_type_analysis:
         annotated = type_info.resolve(annotated, entity_ctx)
         # live_values is run again after the type analysis.
         annotated = live_values.resolve(annotated, entity_ctx, {})

     self.ctx = entity_ctx
     return annotated
 def _parse_and_analyze(self,
                        test_fn,
                        namespace,
                        literals=None,
                        arg_types=None):
   """Parses `test_fn` and runs the full analysis pipeline on its AST.

   The passes are qual_names, activity, live_values, type_info and a second
   live_values run, in that order.

   Args:
     test_fn: The entity handed to `parser.parse_entity`.
     namespace: Forwarded to the EntityContext as `namespace`.
     literals: Passed to both live_values passes; a falsy value is replaced
       by an empty dict.
     arg_types: Forwarded to the EntityContext; a falsy value is replaced
       by an empty dict.

   Returns:
     The node returned by the final live_values pass.
   """
   if not literals:
     literals = {}
   if not arg_types:
     arg_types = {}

   parsed, entity_source = parser.parse_entity(test_fn)
   entity_ctx = context.EntityContext(
       namer=None,
       source_code=entity_source,
       source_file=None,
       namespace=namespace,
       arg_values=None,
       arg_types=arg_types,
       recursive=True)

   annotated = qual_names.resolve(parsed)
   annotated = activity.resolve(annotated, entity_ctx)
   annotated = live_values.resolve(annotated, entity_ctx, literals)
   annotated = type_info.resolve(annotated, entity_ctx)
   return live_values.resolve(annotated, entity_ctx, literals)
Exemple #22
0
    def test_class_members(self):
        """Checks the 'type_fqn' annotation on a method call of a local object.

        `opt` is assigned from a `GradientDescentOptimizer` constructor call;
        the `opt.minimize` callee must be annotated with
        `(training.__name__, 'GradientDescentOptimizer')`.
        """
        def test_fn():
            opt = training.GradientDescentOptimizer(0.1)
            opt.minimize(0)

        node = parser.parse_object(test_fn)
        node = access.resolve(node)
        node = live_values.resolve(node, {'training': training}, {})
        node = type_info.resolve(node, None)

        # The callee of the second statement in test_fn (`opt.minimize`).
        attr_call_node = node.body[0].body[1].value.func
        # assertEqual is used because the assertEquals alias is deprecated.
        self.assertEqual((training.__name__, 'GradientDescentOptimizer'),
                         anno.getanno(attr_call_node, 'type_fqn'))
Exemple #23
0
    def test_function_variables(self):
        """Expects ValueError when a function is called through a local alias.

        `test_fn` binds `bar` to the local name `foo` and calls `foo()`.
        """
        def bar():
            pass

        def test_fn():
            foo = bar
            foo()

        parsed = parser.parse_object(test_fn)
        accessed = access.resolve(parsed)
        annotated = live_values.resolve(accessed, {'bar': bar}, {})
        with self.assertRaises(ValueError):
            type_info.resolve(annotated, None)
    def test_namespace(self):
        """Checks that a name found in the namespace is annotated with it.

        The callee of `foo()` must carry the `foo` function as 'live_val' and
        `('foo',)` as 'fqn'.
        """
        def foo():
            return 'bar'

        def test_fn():
            return foo()

        node = parser.parse_object(test_fn)
        node = access.resolve(node)
        node = live_values.resolve(node, {'foo': foo}, {})

        # The callee expression of the call returned by test_fn.
        func_node = node.body[0].body[0].value.func
        # assertEqual is used because the assertEquals alias is deprecated.
        self.assertEqual(foo, anno.getanno(func_node, 'live_val'))
        self.assertEqual(('foo', ), anno.getanno(func_node, 'fqn'))
  def test_class_members(self):
    """Checks the 'type_fqn' annotation on a method call of a local object.

    `opt` is assigned from a `GradientDescentOptimizer` constructor call; the
    `opt.minimize` callee must be annotated with
    `(training.__name__, 'GradientDescentOptimizer')`.
    """

    def test_fn():
      opt = training.GradientDescentOptimizer(0.1)
      opt.minimize(0)

    node = parser.parse_object(test_fn)
    node = access.resolve(node)
    node = live_values.resolve(node, {'training': training}, {})
    node = type_info.resolve(node, None)

    # The callee of the second statement in test_fn (`opt.minimize`).
    attr_call_node = node.body[0].body[1].value.func
    # assertEqual is used because the assertEquals alias is deprecated.
    self.assertEqual((training.__name__, 'GradientDescentOptimizer'),
                     anno.getanno(attr_call_node, 'type_fqn'))
  def test_constructor_deta_dependent(self):
    """Expects ValueError when the constructor call depends on a branch.

    `test_fn` assigns `opt` in both arms of an `if`/`else` and then calls
    `opt.minimize`; `type_info.resolve` receives an empty dict as its second
    argument.
    """

    def test_fn(x):
      if x > 0:
        opt = training.GradientDescentOptimizer(0.1)
      else:
        opt = training.GradientDescentOptimizer(0.01)
      opt.minimize(0)

    parsed = parser.parse_object(test_fn)
    accessed = access.resolve(parsed)
    annotated = live_values.resolve(accessed, {'training': training}, {})
    with self.assertRaises(ValueError):
      type_info.resolve(annotated, {})
  def test_function_variables(self):
    """Expects ValueError when a function is called through a local alias.

    `test_fn` binds `bar` to the local name `foo` and calls `foo()`.
    """

    def bar():
      pass

    def test_fn():
      foo = bar
      foo()

    parsed = parser.parse_object(test_fn)
    accessed = access.resolve(parsed)
    annotated = live_values.resolve(accessed, {'bar': bar}, {})
    with self.assertRaises(ValueError):
      type_info.resolve(annotated, None)
Exemple #28
0
    def test_constructor_detection(self):
        """Checks the type annotations placed on a constructor call.

        The `training.GradientDescentOptimizer(0.1)` call node must carry the
        class as 'type' and `(training.__name__, 'GradientDescentOptimizer')`
        as 'type_fqn'.
        """
        def test_fn():
            opt = training.GradientDescentOptimizer(0.1)
            return opt

        node = parser.parse_object(test_fn)
        node = access.resolve(node)
        node = live_values.resolve(node, {'training': training}, {})
        node = type_info.resolve(node, None)

        # The right-hand side of the first assignment in test_fn.
        call_node = node.body[0].body[0].value
        # assertEqual is used because the assertEquals alias is deprecated.
        self.assertEqual(training.GradientDescentOptimizer,
                         anno.getanno(call_node, 'type'))
        self.assertEqual((training.__name__, 'GradientDescentOptimizer'),
                         anno.getanno(call_node, 'type_fqn'))
  def test_namespace(self):
    """Checks that a name found in the namespace is annotated with it.

    The callee of `foo()` must carry the `foo` function as 'live_val' and
    `('foo',)` as 'fqn'.
    """

    def foo():
      return 'bar'

    def test_fn():
      return foo()

    node = parser.parse_object(test_fn)
    node = access.resolve(node)
    node = live_values.resolve(node, {'foo': foo}, {})

    # The callee expression of the call returned by test_fn.
    func_node = node.body[0].body[0].value.func
    # assertEqual is used because the assertEquals alias is deprecated.
    self.assertEqual(foo, anno.getanno(func_node, 'live_val'))
    self.assertEqual(('foo',), anno.getanno(func_node, 'fqn'))
  def test_constructor_detection(self):
    """Checks the type annotations placed on a constructor call.

    The `training.GradientDescentOptimizer(0.1)` call node must carry the
    class as 'type' and `(training.__name__, 'GradientDescentOptimizer')` as
    'type_fqn'.
    """

    def test_fn():
      opt = training.GradientDescentOptimizer(0.1)
      return opt

    node = parser.parse_object(test_fn)
    node = access.resolve(node)
    node = live_values.resolve(node, {'training': training}, {})
    node = type_info.resolve(node, None)

    # The right-hand side of the first assignment in test_fn.
    call_node = node.body[0].body[0].value
    # assertEqual is used because the assertEquals alias is deprecated.
    self.assertEqual(training.GradientDescentOptimizer,
                     anno.getanno(call_node, 'type'))
    self.assertEqual((training.__name__, 'GradientDescentOptimizer'),
                     anno.getanno(call_node, 'type_fqn'))
Exemple #31
0
    def test_parameter_class_members_with_value_hints(self):
        """Checks 'type_fqn' on a parameter method call when a hint is given.

        A hint for `opt` (a qualified-name string paired with an optimizer
        instance) is passed to `type_info.resolve`; the `opt.minimize` callee
        must then be annotated with the module path components followed by
        'GradientDescentOptimizer'.
        """
        def test_fn(opt):
            opt.minimize(0)

        node = parser.parse_object(test_fn)
        node = access.resolve(node)
        node = live_values.resolve(node, {'training': training}, {})
        node = type_info.resolve(
            node, {
                'opt': (('%s.GradientDescentOptimizer' % training.__name__),
                        training.GradientDescentOptimizer(0.1))
            })

        # The callee of the only statement in test_fn (`opt.minimize`).
        attr_call_node = node.body[0].body[0].value.func
        # assertEqual is used because the assertEquals alias is deprecated.
        self.assertEqual(
            training.__name__.split('.') + ['GradientDescentOptimizer'],
            anno.getanno(attr_call_node, 'type_fqn'))
  def test_parameter_class_members_with_value_hints(self):
    """Checks 'type_fqn' on a parameter method call when a hint is given.

    A hint for `opt` (a qualified-name string paired with an optimizer
    instance) is passed to `type_info.resolve`; the `opt.minimize` callee must
    then be annotated with the module path components followed by
    'GradientDescentOptimizer'.
    """

    def test_fn(opt):
      opt.minimize(0)

    node = parser.parse_object(test_fn)
    node = access.resolve(node)
    node = live_values.resolve(node, {'training': training}, {})
    node = type_info.resolve(
        node, {
            'opt': (('%s.GradientDescentOptimizer' % training.__name__),
                    training.GradientDescentOptimizer(0.1))
        })

    # The callee of the only statement in test_fn (`opt.minimize`).
    attr_call_node = node.body[0].body[0].value.func
    # assertEqual is used because the assertEquals alias is deprecated.
    self.assertEqual(
        training.__name__.split('.') + ['GradientDescentOptimizer'],
        anno.getanno(attr_call_node, 'type_fqn'))
 def parse_and_analyze(self,
                       test_fn,
                       namespace,
                       arg_types=None,
                       include_type_analysis=True):
   """Parses `test_fn` and runs access, live_values and optionally type_info.

   Args:
     test_fn: The object handed to `parser.parse_object`.
     namespace: Used for the EntityContext and for live_values.
     arg_types: Forwarded to the EntityContext as `arg_types`.
     include_type_analysis: Whether to run the type_info pass.

   Returns:
     The node returned by the last pass that ran.
   """
   # The context is always created first, even if type analysis is skipped.
   entity_ctx = context.EntityContext(
       namer=None,
       source_code=None,
       source_file=None,
       namespace=namespace,
       arg_values=None,
       arg_types=arg_types)

   parsed = parser.parse_object(test_fn)
   accessed = access.resolve(parsed)
   annotated = live_values.resolve(accessed, namespace, {})
   if not include_type_analysis:
     return annotated
   return type_info.resolve(annotated, entity_ctx)
  def test_class_members_in_with_stmt(self):
    """Checks type annotations for an object bound by a `with` statement.

    The `session.Session()` context expression must carry the class as 'type'
    and `(session.__name__, 'Session')` as 'type_fqn'; the `sess.run` callee
    inside the block must carry the same 'type_fqn'.
    """

    def test_fn(x):
      with session.Session() as sess:
        sess.run(x)

    node = parser.parse_object(test_fn)
    node = access.resolve(node)
    node = live_values.resolve(node, {'session': session}, {})
    node = type_info.resolve(node, {})

    # assertEqual is used because the assertEquals alias is deprecated.
    constructor_call = node.body[0].body[0].items[0].context_expr
    self.assertEqual(session.Session, anno.getanno(constructor_call, 'type'))
    self.assertEqual((session.__name__, 'Session'),
                     anno.getanno(constructor_call, 'type_fqn'))

    # The callee of the statement inside the `with` body (`sess.run`).
    member_call = node.body[0].body[0].body[0].value.func
    self.assertEqual((session.__name__, 'Session'),
                     anno.getanno(member_call, 'type_fqn'))
Exemple #35
0
    def test_class_members_in_with_stmt(self):
        """Checks type annotations for an object bound by a `with` statement.

        The `session.Session()` context expression must carry the class as
        'type' and `(session.__name__, 'Session')` as 'type_fqn'; the
        `sess.run` callee inside the block must carry the same 'type_fqn'.
        """
        def test_fn(x):
            with session.Session() as sess:
                sess.run(x)

        node = parser.parse_object(test_fn)
        node = access.resolve(node)
        node = live_values.resolve(node, {'session': session}, {})
        node = type_info.resolve(node, None)

        # assertEqual is used because the assertEquals alias is deprecated.
        constructor_call = node.body[0].body[0].items[0].context_expr
        self.assertEqual(session.Session,
                         anno.getanno(constructor_call, 'type'))
        self.assertEqual((session.__name__, 'Session'),
                         anno.getanno(constructor_call, 'type_fqn'))

        # The callee of the statement inside the `with` body (`sess.run`).
        member_call = node.body[0].body[0].body[0].value.func
        self.assertEqual((session.__name__, 'Session'),
                         anno.getanno(member_call, 'type_fqn'))
Exemple #36
0
def _static_analysis_pass(node, source, f, namespace, value_hints):
  """Runs the access, live_values and type_info passes over `node`.

  Args:
    node: The AST node to analyze.
    source: Forwarded to `type_info.resolve`.
    f: Forwarded to `type_info.resolve`.
    namespace: Forwarded to `live_values.resolve`.
    value_hints: Forwarded to `type_info.resolve`.

  Returns:
    The node returned by the type_info pass.
  """
  accessed = access.resolve(node)
  with_live_values = live_values.resolve(
      accessed, namespace, config.PYTHON_LITERALS)
  return type_info.resolve(with_live_values, source, f, value_hints)
Exemple #37
0
def _static_analysis_pass(node, ctx):
  """Runs qual_names, activity, live_values and type_info over `node`.

  Args:
    node: The AST node to analyze.
    ctx: The context object forwarded to the activity, live_values and
      type_info passes.

  Returns:
    The node returned by the type_info pass.
  """
  qualified = qual_names.resolve(node)
  with_activity = activity.resolve(qualified, ctx, None)
  with_live_values = live_values.resolve(
      with_activity, ctx, config.PYTHON_LITERALS)
  return type_info.resolve(with_live_values, ctx)
def _static_analysis_pass(node, ctx):
  """Runs the access, live_values and type_info passes over `node`.

  Args:
    node: The AST node to analyze.
    ctx: Context object; its `namespace` attribute is given to live_values
      and the object itself to type_info.

  Returns:
    The node returned by the type_info pass.
  """
  accessed = access.resolve(node)
  with_live_values = live_values.resolve(
      accessed, ctx.namespace, config.PYTHON_LITERALS)
  return type_info.resolve(with_live_values, ctx)
 def _parse_and_analyze(self, test_fn, namespace):
   """Parses `test_fn` and runs access, live_values and type_info on it.

   Args:
     test_fn: The object handed to `parser.parse_object`.
     namespace: Forwarded to `live_values.resolve`.

   Returns:
     The node returned by the type_info pass, which is called with an empty
     dict as its second argument.
   """
   parsed = parser.parse_object(test_fn)
   accessed = access.resolve(parsed)
   annotated = live_values.resolve(accessed, namespace, {})
   return type_info.resolve(annotated, {})
Exemple #40
0
 def _parse_and_analyze(self, test_fn, namespace):
     """Parses `test_fn` and runs access, live_values and type_info on it.

     Args:
       test_fn: The object handed to `parser.parse_object`.
       namespace: Forwarded to `live_values.resolve`.

     Returns:
       The node returned by the type_info pass, which is called with None
       as its second argument.
     """
     parsed = parser.parse_object(test_fn)
     accessed = access.resolve(parsed)
     annotated = live_values.resolve(accessed, namespace, {})
     return type_info.resolve(annotated, None)