예제 #1
0
    def test_argument_target(self):
        """set_element_type on a parameter annotates that parameter's definition."""
        def test_fn(a):
            directives.set_element_type(a, 1, shape=2)

        parsed_node, parse_ctx = self.prepare(
            test_fn, dict(directives=directives))
        converted = directives_converter.transform(parsed_node, parse_ctx)

        # The lone parameter `a` must carry exactly one definition.
        param = converted.args.args[0]
        [definition] = anno.getanno(param, anno.Static.DEFINITIONS)
        element_type = definition.directives[directives.set_element_type]
        # NOTE(review): `.n` is the legacy Num-node field; switch to `.value`
        # only if the AST library in use moves to Constant nodes.
        self.assertEqual(element_type['dtype'].n, 1)
        self.assertEqual(element_type['shape'].n, 2)
예제 #2
0
  def test_argument_target(self):
    """set_element_type on a parameter annotates that parameter's definition."""

    def test_fn(a):
      directives.set_element_type(a, 1, shape=2)

    parsed_node, parse_ctx = self.prepare(test_fn,
                                          dict(directives=directives))
    converted = directives_converter.transform(parsed_node, parse_ctx)

    # The lone parameter `a` must carry exactly one definition.
    param = converted.args.args[0]
    [definition] = anno.getanno(param, anno.Static.DEFINITIONS)
    element_type = definition.directives[directives.set_element_type]
    # NOTE(review): `.n` is the legacy Num-node field; switch to `.value`
    # only if the AST library in use moves to Constant nodes.
    self.assertEqual(element_type['dtype'].n, 1)
    self.assertEqual(element_type['shape'].n, 2)
예제 #3
0
    def test_local_target(self):
        """set_element_type on a local list annotates its assignment target."""
        def test_fn():
            l = []
            string_var = 0
            directives.set_element_type(l, 'a', string_var)

        parsed_node, parse_ctx = self.prepare(
            test_fn, dict(directives=directives))
        converted = directives_converter.transform(parsed_node, parse_ctx)

        # body[0] is the `l = []` assignment; its target holds the definition.
        list_target = converted.body[0].targets[0]
        [definition] = anno.getanno(list_target, anno.Static.DEFINITIONS)
        element_type = definition.directives[directives.set_element_type]
        self.assertEqual(element_type['dtype'].s, 'a')
        # The third positional argument is recorded under 'shape' as the
        # expression naming `string_var`.
        self.assertEqual(element_type['shape'].id, 'string_var')
예제 #4
0
    def test_loop_target(self):
        """set_loop_options inside a while loop annotates the loop node."""
        def test_fn():
            a = True
            while True:
                directives.set_loop_options(parallel_iterations=10,
                                            back_prop=a)

        parsed_node, parse_ctx = self.prepare(
            test_fn, dict(directives=directives))
        converted = directives_converter.transform(parsed_node, parse_ctx)

        # body[0] is `a = True`; body[1] is the `while` loop itself.
        while_node = converted.body[1]
        loop_options = anno.getanno(while_node, AgAnno.DIRECTIVES)[
            directives.set_loop_options]
        self.assertEqual(loop_options['parallel_iterations'].n, 10)
        self.assertEqual(loop_options['back_prop'].id, 'a')
        # Options that were not passed must not appear in the record.
        self.assertNotIn('swap_memory', loop_options)
예제 #5
0
  def test_loop_target(self):
    """set_loop_options inside a while loop annotates the loop node."""

    def test_fn():
      a = True
      while True:
        directives.set_loop_options(parallel_iterations=10, back_prop=a)

    parsed_node, parse_ctx = self.prepare(test_fn,
                                          dict(directives=directives))
    converted = directives_converter.transform(parsed_node, parse_ctx)

    # body[0] is `a = True`; body[1] is the `while` loop itself.
    while_node = converted.body[1]
    loop_options = anno.getanno(while_node, AgAnno.DIRECTIVES)[
        directives.set_loop_options]
    self.assertEqual(loop_options['parallel_iterations'].n, 10)
    self.assertEqual(loop_options['back_prop'].id, 'a')
    # Options that were not passed must not appear in the record.
    self.assertNotIn('swap_memory', loop_options)
예제 #6
0
  def test_local_target(self):
    """set_element_type on a local list annotates its assignment target."""

    def test_fn():
      l = []
      string_var = 0
      directives.set_element_type(l, 'a', string_var)

    parsed_node, parse_ctx = self.prepare(test_fn,
                                          dict(directives=directives))
    converted = directives_converter.transform(parsed_node, parse_ctx)

    # body[0] is the `l = []` assignment; its target holds the definition.
    list_target = converted.body[0].targets[0]
    [definition] = anno.getanno(list_target, anno.Static.DEFINITIONS)
    element_type = definition.directives[directives.set_element_type]
    self.assertEqual(element_type['dtype'].s, 'a')
    # The third positional argument is recorded under 'shape' as the
    # expression naming `string_var`.
    self.assertEqual(element_type['shape'].id, 'string_var')