def transform_ast(self, node, ctx):
    """Run the full conversion pipeline over `node` and return the result.

    The node is first checked for unsupported features, then annotated by the
    static analyses, and finally rewritten by each converter pass in a fixed
    order.

    Args:
        node: AST node to convert.
        ctx: conversion context; `ctx.user.options` decides whether the
            optional assert and list passes run.

    Returns:
        The transformed AST node.
    """
    # TODO(mdan): Insert list_comprehensions somewhere.
    unsupported_features_checker.verify(node)

    # Initial static analysis. The control flow graphs are built once, up
    # front, and handed to the reaching definitions analysis.
    cfgs = cfg.build(node)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx, None)
    node = reaching_definitions.resolve(node, ctx, cfgs)
    # Duplicate the DEFINITIONS annotation under the ORIG_DEFINITIONS key.
    anno.dup(node, {anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS})

    for pass_module in (functions, directives, break_statements):
        node = pass_module.transform(node, ctx)

    if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
        node = asserts.transform(node, ctx)

    # Note: sequencing continue canonicalization before for loop one avoids
    # dealing with the extra loop increment operation that the for
    # canonicalization creates.
    for pass_module in (continue_statements, return_statements):
        node = pass_module.transform(node, ctx)

    if ctx.user.options.uses(converter.Feature.LISTS):
        node = lists.transform(node, ctx)
        node = slices.transform(node, ctx)

    for pass_module in (
        call_trees,
        control_flow,
        conditional_expressions,
        logical_expressions,
    ):
        node = pass_module.transform(node, ctx)

    return node
def test_loop_target_with_no_loop(self):
    """`set_loop_options` used outside any loop must raise a ValueError."""

    def test_fn():
        directives.set_loop_options()

    node, ctx = self.prepare(test_fn, {'directives': directives})
    # `assertRaisesRegexp` is a deprecated alias that was removed in
    # Python 3.12; `assertRaisesRegex` is the supported spelling.
    with self.assertRaisesRegex(ValueError, 'must be used inside a statement'):
        directives_converter.transform(node, ctx)
def test_argument_target(self):
    """`set_element_type` on a function argument is recorded on its definition.

    After conversion, the single definition annotated on the first argument
    must carry the directive's `dtype` and `shape` values.
    """

    def test_fn(a):
        directives.set_element_type(a, 1, shape=2)

    node, ctx = self.prepare(test_fn, {'directives': directives})
    node = directives_converter.transform(node, ctx)

    # Exactly one definition is expected; the unpacking fails otherwise.
    def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
    d = def_.directives[directives.set_element_type]
    # Literal payloads are read through `.value`, the Constant-node accessor
    # also used by the other directive tests; the legacy `.n` alias is
    # deprecated and no longer available on current Python versions.
    self.assertEqual(d['dtype'].value, 1)
    self.assertEqual(d['shape'].value, 2)
def test_loop_target_not_first(self):
    """`set_loop_options` must be rejected when it is not first in the loop."""

    def test_fn():
        a = 1
        while True:
            a = 2
            directives.set_loop_options(parallel_iterations=10, back_prop=a)

    node, ctx = self.prepare(test_fn, {'directives': directives})
    # `assertRaisesRegexp` is a deprecated alias that was removed in
    # Python 3.12; `assertRaisesRegex` is the supported spelling.
    with self.assertRaisesRegex(ValueError, 'must be the first statement'):
        directives_converter.transform(node, ctx)
def test_local_target(self):
    """`set_element_type` on a local variable is recorded on its definition.

    After conversion, the single definition annotated on the assignment
    target of the first statement must carry the directive's `dtype` (a
    string literal) and `shape` (a reference to `string_var`).
    """

    def test_fn():
        l = []
        string_var = 0
        directives.set_element_type(l, 'a', string_var)

    node, ctx = self.prepare(test_fn, {'directives': directives})
    node = directives_converter.transform(node, ctx)

    # Exactly one definition is expected; the unpacking fails otherwise.
    def_, = anno.getanno(node.body[0].targets[0], anno.Static.DEFINITIONS)
    d = def_.directives[directives.set_element_type]
    # Literal payloads are read through `.value`, the Constant-node accessor
    # also used by the other directive tests; the legacy `.s` alias is
    # deprecated and no longer available on current Python versions.
    self.assertEqual(d['dtype'].value, 'a')
    self.assertEqual(d['shape'].id, 'string_var')
def test_value_verification_does_not_trigger_properties(self):
    """Converting code that reads `tc.b` must not evaluate the `b` property.

    The property raises when evaluated, so the test fails if the directives
    converter touches it while transforming `test_fn`.
    """

    class TestClass(object):

        @property
        def b(self):
            raise ValueError('This should never be evaluated')

    tc = TestClass()

    def test_fn():
        return tc.b + 1

    entity, ctx = self.prepare(test_fn, {'tc': tc})
    converted = directives_converter.transform(entity, ctx)

    self.assertIsNotNone(converted)
def test_loop_target(self):
    """`set_loop_options` is attached to the enclosing loop as a directive.

    Only the options passed explicitly may be recorded: `parallel_iterations`
    and `back_prop` must be present and `swap_memory` must be absent.
    """

    def test_fn():
        a = True
        while True:
            directives.set_loop_options(parallel_iterations=10, back_prop=a)

    node, ctx = self.prepare(test_fn, {'directives': directives})
    node = directives_converter.transform(node, ctx)

    # node.body[1] is the `while` statement (body[0] is the assignment).
    d = anno.getanno(node.body[1], AgAnno.DIRECTIVES)
    d = d[directives.set_loop_options]
    # Literal payloads are read through `.value`, the Constant-node accessor;
    # the legacy `.n` alias is deprecated and no longer available on current
    # Python versions.
    self.assertEqual(d['parallel_iterations'].value, 10)
    self.assertEqual(d['back_prop'].id, 'a')
    self.assertNotIn('swap_memory', d)
def test_loop_target(self):
    """Loop options set inside a `while` end up annotated on that loop.

    Checks that the two options passed to `set_loop_options` are recorded
    and that an option which was not passed (`swap_memory`) is not.
    """

    def test_fn():
        a = True
        while True:
            directives.set_loop_options(parallel_iterations=10, back_prop=a)

    fn_node, ctx = self.prepare(test_fn, {'directives': directives})
    fn_node = directives_converter.transform(fn_node, ctx)

    # The loop is the second statement of the function body.
    loop_node = fn_node.body[1]
    loop_directives = anno.getanno(loop_node, AgAnno.DIRECTIVES)
    loop_options = loop_directives[directives.set_loop_options]

    self.assertEqual(loop_options['parallel_iterations'].value, 10)
    self.assertEqual(loop_options['back_prop'].id, 'a')
    self.assertNotIn('swap_memory', loop_options)
def test_value_verification_does_not_trigger_getattr(self):
    """Converting code that reads `tc.b` must not invoke `__getattr__`.

    The helper class records any `__getattr__` call in a flag instead of
    raising, and the flag is checked after the conversion.
    """

    class TestClass(object):

        def __init__(self):
            self.getattr_called = False

        def __getattr__(self, _):
            # Note: seems that any exception raised here is absorbed by hasattr.
            # So we can't call test.fail or raise.
            self.getattr_called = True

    tc = TestClass()

    def test_fn():
        return tc.b + 1

    entity, ctx = self.prepare(test_fn, {'tc': tc})
    converted = directives_converter.transform(entity, ctx)

    self.assertIsNotNone(converted)
    self.assertFalse(tc.getattr_called)
def transform_ast(self, node, ctx):
    """Run the full conversion pipeline over `node` and return the result.

    The node is checked for unsupported features, passed through
    `self.initial_analysis`, and then rewritten by each converter pass in a
    fixed order.

    Args:
        node: AST node to convert.
        ctx: conversion context; `ctx.user.options` decides whether the
            optional assert and list passes run.

    Returns:
        The transformed AST node.
    """
    unsupported_features_checker.verify(node)
    node = self.initial_analysis(node, ctx)

    for pass_module in (functions, directives, break_statements):
        node = pass_module.transform(node, ctx)

    if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
        node = asserts.transform(node, ctx)

    # Note: sequencing continue canonicalization before for loop one avoids
    # dealing with the extra loop increment operation that the for
    # canonicalization creates.
    for pass_module in (continue_statements, return_statements):
        node = pass_module.transform(node, ctx)

    if ctx.user.options.uses(converter.Feature.LISTS):
        node = lists.transform(node, ctx)
        node = slices.transform(node, ctx)

    for pass_module in (
        call_trees,
        control_flow,
        conditional_expressions,
        logical_expressions,
        variables,
    ):
        node = pass_module.transform(node, ctx)

    return node