def test_argument_target(self):
    """set_element_type on a function argument annotates that argument.

    The directive's positional dtype (1) and keyword shape (2) should be
    recorded on the single static definition of argument `a`.
    """

    def test_fn(a):
        directives.set_element_type(a, 1, shape=2)

    node, ctx = self.prepare(test_fn, {'directives': directives})
    node = directives_converter.transform(node, ctx)

    # Exactly one definition is expected to reach the argument node.
    (arg_definition,) = anno.getanno(node.args.args[0],
                                     anno.Static.DEFINITIONS)
    element_type_args = arg_definition.directives[directives.set_element_type]

    self.assertEqual(element_type_args['dtype'].n, 1)
    self.assertEqual(element_type_args['shape'].n, 2)
def test_local_target(self):
    """set_element_type on a local variable annotates its assignment target.

    The dtype is passed as the string literal 'a' and the shape as a name
    reference (`string_var`); both should be recorded on the definition of
    the local `l`.
    """

    def test_fn():
        l = []
        string_var = 0
        directives.set_element_type(l, 'a', string_var)

    node, ctx = self.prepare(test_fn, {'directives': directives})
    node = directives_converter.transform(node, ctx)

    # `l = []` is the first statement of test_fn, so its target lives at
    # body[0].targets[0]; the statement order in test_fn must not change.
    local_target = node.body[0].targets[0]
    (local_definition,) = anno.getanno(local_target, anno.Static.DEFINITIONS)
    element_type_args = local_definition.directives[
        directives.set_element_type]

    self.assertEqual(element_type_args['dtype'].s, 'a')
    self.assertEqual(element_type_args['shape'].id, 'string_var')
def test_loop_target(self):
    """set_loop_options inside a loop annotates the enclosing loop node.

    Only the options actually passed (parallel_iterations, back_prop) should
    appear in the recorded directive arguments; swap_memory must be absent.
    """

    def test_fn():
        a = True
        while True:
            directives.set_loop_options(parallel_iterations=10, back_prop=a)

    node, ctx = self.prepare(test_fn, {'directives': directives})
    node = directives_converter.transform(node, ctx)

    # The `while` loop is the second statement of test_fn (after `a = True`),
    # hence body[1]; the statement order in test_fn must not change.
    loop_node = node.body[1]
    loop_options = anno.getanno(loop_node, AgAnno.DIRECTIVES)[
        directives.set_loop_options]

    self.assertEqual(loop_options['parallel_iterations'].n, 10)
    self.assertEqual(loop_options['back_prop'].id, 'a')
    self.assertNotIn('swap_memory', loop_options)