Exemplo n.º 1
0
    def transform_ast(self, node, ctx):
        """Run the static analyses and the conversion passes over `node`.

        Args:
            node: AST node to convert. It is first handed to
                `unsupported_features_checker.verify`.
            ctx: Context object forwarded to the analyses and to every
                conversion pass; `ctx.user.options` decides whether the
                optional passes (asserts, lists/slices) are applied.

        Returns:
            The node returned by the final conversion pass.
        """
        # TODO(mdan): Insert list_comprehensions somewhere.
        unsupported_features_checker.verify(node)

        # Initial analysis. The CFG is built before any resolver runs and is
        # later passed to the reaching-definitions resolver.
        graphs = cfg.build(node)
        node = qual_names.resolve(node)
        node = activity.resolve(node, ctx, None)
        node = reaching_definitions.resolve(node, ctx, graphs)
        # NOTE(review): presumably keeps a copy of the DEFINITIONS annotation
        # under ORIG_DEFINITIONS -- confirm against anno.dup.
        anno.dup(node, {anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS})

        # Conversion passes, applied strictly in this order.
        for stage in (functions, directives, break_statements):
            node = stage.transform(node, ctx)
        if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
            node = asserts.transform(node, ctx)

        # Note: sequencing continue canonicalization before for loop one avoids
        # dealing with the extra loop increment operation that the for
        # canonicalization creates.
        for stage in (continue_statements, return_statements):
            node = stage.transform(node, ctx)
        if ctx.user.options.uses(converter.Feature.LISTS):
            for stage in (lists, slices):
                node = stage.transform(node, ctx)

        for stage in (
            call_trees,
            control_flow,
            conditional_expressions,
            logical_expressions,
        ):
            node = stage.transform(node, ctx)
        return node
Exemplo n.º 2
0
    def test_index_access_multiple_definitions(self):
        """slices.transform raises ValueError when `l` has two element dtypes."""
        def test_fn(l):
            if l:
                l = []
            return l[1]

        node, ctx = self.prepare(test_fn, {})

        # Definition attached to the function parameter `l`: tag it as int32.
        param_node = node.args.args[0]
        (param_def,) = anno.getanno(param_node, anno.Static.DEFINITIONS)
        int_dtype = parser.parse_expression('tf.int32')
        param_def.directives[directives.set_element_type] = {'dtype': int_dtype}

        # Definition attached to the `l = []` target inside the `if`: tag it
        # with a different dtype (float32).
        if_stmt = node.body[0]
        assign_target = if_stmt.body[0].targets[0]
        (assign_def,) = anno.getanno(assign_target, anno.Static.DEFINITIONS)
        float_dtype = parser.parse_expression('tf.float32')
        assign_def.directives[directives.set_element_type] = {
            'dtype': float_dtype
        }

        with self.assertRaises(ValueError):
            slices.transform(node, ctx)
Exemplo n.º 3
0
  def test_index_access_multiple_definitions(self):
    """slices.transform raises AutographParseError on two element dtypes."""

    def test_fn(l):
      if l:
        l = []
      return l[1]

    node, ctx = self.prepare(test_fn, {})

    # Definition attached to the function parameter `l`: tag it as int32.
    param_node = node.args.args[0]
    (param_def,) = anno.getanno(param_node, anno.Static.DEFINITIONS)
    int_dtype = parser.parse_expression('tf.int32')
    param_def.directives[directives.set_element_type] = {'dtype': int_dtype}

    # Definition attached to the `l = []` target inside the `if`: tag it with
    # a different dtype (float32).
    if_stmt = node.body[0]
    assign_target = if_stmt.body[0].targets[0]
    (assign_def,) = anno.getanno(assign_target, anno.Static.DEFINITIONS)
    float_dtype = parser.parse_expression('tf.float32')
    assign_def.directives[directives.set_element_type] = {'dtype': float_dtype}

    with self.assertRaises(transformer.AutographParseError):
      slices.transform(node, ctx)
Exemplo n.º 4
0
  def test_index_access(self):
    """Indexing a tensor list converted by slices.transform yields element 1."""

    def test_fn(l):
      return l[1]

    node, ctx = self.prepare(test_fn, {})

    # Tag the definition of parameter `l` with an int32 element type before
    # running the slices conversion.
    param_node = node.args.args[0]
    (param_def,) = anno.getanno(param_node, anno.Static.DEFINITIONS)
    int_dtype = parser.parse_expression('tf.int32')
    param_def.directives[directives.set_element_type] = {'dtype': int_dtype}
    node = slices.transform(node, ctx)

    with self.compiled(node, {}, dtypes.int32) as result, \
        self.cached_session() as sess:
      element_shape = constant_op.constant([], dtype=dtypes.int32)
      tl = list_ops.tensor_list_from_tensor(
          [1, 2], element_shape=element_shape)
      y = result.test_fn(tl)
      fetched = sess.run(y)
      self.assertEqual(2, fetched)
Exemplo n.º 5
0
  def test_index_access(self):
    """Indexing a tensor list converted by slices.transform yields element 1."""

    def test_fn(l):
      return l[1]

    node, ctx = self.prepare(test_fn, {})

    # Tag the definition of parameter `l` with an int32 element type before
    # running the slices conversion.
    param_node = node.args.args[0]
    (param_def,) = anno.getanno(param_node, anno.Static.DEFINITIONS)
    int_dtype = parser.parse_expression('tf.int32')
    param_def.directives[directives.set_element_type] = {'dtype': int_dtype}
    node = slices.transform(node, ctx)

    # The session object itself is not referenced; the result is read through
    # self.evaluate while the cached session context is active.
    with self.compiled(node, {}, dtypes.int32) as result, \
        self.cached_session():
      element_shape = constant_op.constant([], dtype=dtypes.int32)
      tl = list_ops.tensor_list_from_tensor(
          [1, 2], element_shape=element_shape)
      y = result.test_fn(tl)
      fetched = self.evaluate(y)
      self.assertEqual(2, fetched)
Exemplo n.º 6
0
    def transform_ast(self, node, ctx):
        """Run the initial analysis and the conversion passes over `node`.

        Args:
            node: AST node to convert. It is first handed to
                `unsupported_features_checker.verify`.
            ctx: Context object forwarded to `self.initial_analysis` and to
                every conversion pass; `ctx.user.options` decides whether the
                optional passes (asserts, lists/slices) are applied.

        Returns:
            The node returned by the final conversion pass.
        """
        unsupported_features_checker.verify(node)
        node = self.initial_analysis(node, ctx)

        # Conversion passes, applied strictly in this order.
        for stage in (functions, directives, break_statements):
            node = stage.transform(node, ctx)
        if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
            node = asserts.transform(node, ctx)

        # Note: sequencing continue canonicalization before for loop one avoids
        # dealing with the extra loop increment operation that the for
        # canonicalization creates.
        for stage in (continue_statements, return_statements):
            node = stage.transform(node, ctx)
        if ctx.user.options.uses(converter.Feature.LISTS):
            for stage in (lists, slices):
                node = stage.transform(node, ctx)

        for stage in (
            call_trees,
            control_flow,
            conditional_expressions,
            logical_expressions,
            variables,
        ):
            node = stage.transform(node, ctx)
        return node