def transform_ast(self, node, ctx):
        """Verifies `node`, annotates it, runs the converter passes, returns it.

        Args:
          node: the AST to convert.
          ctx: conversion context; `ctx.user.options` decides whether the
            optional assert and list/slice passes run.

        Returns:
          The AST produced by the last converter pass.
        """
        # TODO(mdan): Insert list_comprehensions somewhere.
        unsupported_features_checker.verify(node)

        def feature_on(feature):
            # Queried at the point of use, like the inline checks it replaces.
            return ctx.user.options.uses(feature)

        def apply_passes(tree, *passes):
            # Thread the AST through each converter module, in argument order.
            for conversion_pass in passes:
                tree = conversion_pass.transform(tree, ctx)
            return tree

        # Initial analysis. The CFG is built from the AST as received, and
        # is then consumed by the reaching-definitions analysis.
        control_flow_graphs = cfg.build(node)
        node = qual_names.resolve(node)
        node = activity.resolve(node, ctx, None)
        node = reaching_definitions.resolve(node, ctx, control_flow_graphs)
        # Presumably copies DEFINITIONS annotations to ORIG_DEFINITIONS so the
        # pre-conversion definitions remain readable later (tests read
        # ORIG_DEFINITIONS) -- confirm against anno.dup.
        anno.dup(node, {anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS})

        node = apply_passes(node, functions, directives, break_statements)
        if feature_on(converter.Feature.ASSERT_STATEMENTS):
            node = apply_passes(node, asserts)
        # Note: sequencing continue canonicalization before for loop one avoids
        # dealing with the extra loop increment operation that the for
        # canonicalization creates.
        node = apply_passes(node, continue_statements, return_statements)
        if feature_on(converter.Feature.LISTS):
            node = apply_passes(node, lists, slices)
        return apply_passes(node, call_trees, control_flow,
                            conditional_expressions, logical_expressions)
Example #2
0
    def test_list_stack(self):
        """Checks that a list given an int32 element type is stacked to [1, 2, 3]."""
        def test_fn():
            l = [1, 2, 3]
            return tf.stack(l)

        node, ctx = self.prepare(test_fn, {})
        # Exactly one definition of `l` is expected; attach a
        # `set_element_type` directive (dtype tf.int32) to it before running
        # the lists converter.
        (def_,) = anno.getanno(node.body[0].targets[0],
                               anno.Static.ORIG_DEFINITIONS)
        def_.directives[directives.set_element_type] = {
            'dtype': parser.parse_expression('tf.int32')
        }
        node = lists.transform(node, ctx)

        with self.compiled(node, {}, array_ops.stack, dtypes.int32) as result:
            # Use self.evaluate like the sibling tests. The session context is
            # still entered; no handle to it is needed.
            with self.cached_session():
                self.assertAllEqual(self.evaluate(result.test_fn()), [1, 2, 3])
Example #3
0
  def test_list_stack(self):
    """Checks that a list given an int32 element type is stacked to [1, 2, 3]."""

    def test_fn():
      l = [1, 2, 3]
      return tf.stack(l)

    node, ctx = self.prepare(test_fn, {})
    # Exactly one definition of `l` is expected; attach a `set_element_type`
    # directive (dtype tf.int32) to it before running the lists converter.
    (def_,) = anno.getanno(node.body[0].targets[0],
                           anno.Static.ORIG_DEFINITIONS)
    def_.directives[directives.set_element_type] = {
        'dtype': parser.parse_expression('tf.int32')
    }
    node = lists.transform(node, ctx)

    with self.compiled(node, {}, array_ops.stack, dtypes.int32) as result:
      # The session is entered only for its context; self.evaluate does the
      # evaluation, so the former unused `as sess` binding is dropped.
      with self.cached_session():
        self.assertAllEqual(self.evaluate(result.test_fn()), [1, 2, 3])
Example #4
0
    def test_list_pop(self):
        """Checks pop() on a converted tensor list built from [1, 2, 3]."""
        def test_fn():
            l = special_functions.tensor_list([1, 2, 3])
            s = l.pop()
            return s, l

        ns = {'special_functions': special_functions}
        node, ctx = self.prepare(test_fn, ns)
        # Exactly one definition of `l` is expected; give it an element type
        # (dtype tf.int32, shape ()) before running the lists converter.
        (def_,) = anno.getanno(node.body[0].targets[0],
                               anno.Static.ORIG_DEFINITIONS)
        def_.directives[directives.set_element_type] = {
            'dtype': parser.parse_expression('tf.int32'),
            'shape': parser.parse_expression('()'),
        }
        node = lists.transform(node, ctx)

        with self.compiled(node, ns, dtypes.int32) as result:
            # Session entered for its context only; self.evaluate is used below.
            with self.cached_session():
                popped, remaining = result.test_fn()
                stacked = list_ops.tensor_list_stack(remaining, dtypes.int32)
                # The popped value is 3 and [1, 2] remains in the list.
                self.assertAllEqual(self.evaluate(stacked), [1, 2])
                self.assertAllEqual(self.evaluate(popped), 3)
Example #5
0
    def transform_ast(self, node, ctx):
        """Verifies and analyzes `node`, runs the converter passes, returns it.

        Args:
          node: the AST to convert.
          ctx: conversion context; `ctx.user.options` decides whether the
            optional assert and list/slice passes run.

        Returns:
          The AST produced by the last converter pass (`variables`).
        """
        unsupported_features_checker.verify(node)
        node = self.initial_analysis(node, ctx)

        def enabled(feature):
            # Queried at the point of use, like the inline checks it replaces.
            return ctx.user.options.uses(feature)

        def run(tree, *passes):
            # Apply each converter module's `transform` in argument order.
            for conversion_pass in passes:
                tree = conversion_pass.transform(tree, ctx)
            return tree

        node = run(node, functions, directives, break_statements)
        if enabled(converter.Feature.ASSERT_STATEMENTS):
            node = run(node, asserts)
        # Note: sequencing continue canonicalization before for loop one avoids
        # dealing with the extra loop increment operation that the for
        # canonicalization creates.
        node = run(node, continue_statements, return_statements)
        if enabled(converter.Feature.LISTS):
            node = run(node, lists, slices)
        return run(node, call_trees, control_flow, conditional_expressions,
                   logical_expressions, variables)
Example #6
0
  def test_list_pop(self):
    """Checks pop() on a converted tensor list built from [1, 2, 3]."""

    def test_fn():
      l = special_functions.tensor_list([1, 2, 3])
      s = l.pop()
      return s, l

    ns = {'special_functions': special_functions}
    node, ctx = self.prepare(test_fn, ns)
    # Exactly one definition of `l` is expected; give it an element type
    # (dtype tf.int32, shape ()) before running the lists converter.
    (def_,) = anno.getanno(node.body[0].targets[0],
                           anno.Static.ORIG_DEFINITIONS)
    def_.directives[directives.set_element_type] = {
        'dtype': parser.parse_expression('tf.int32'),
        'shape': parser.parse_expression('()'),
    }
    node = lists.transform(node, ctx)

    with self.compiled(node, ns, dtypes.int32) as result:
      # Session entered for its context only; self.evaluate is used below.
      with self.cached_session():
        popped, remaining = result.test_fn()
        stacked = list_ops.tensor_list_stack(remaining, dtypes.int32)
        # The popped value is 3 and [1, 2] remains in the list.
        self.assertAllEqual(self.evaluate(stacked), [1, 2])
        self.assertAllEqual(self.evaluate(popped), 3)