Exemplo n.º 1
0
    def _torch_to_tf(self, batch_size, max_seq_len):
        """Time a PyTorch-style RNN converted to run as a TensorFlow graph.

        Three PyTorch helpers (the dynamic RNN, the cell factory and the input
        preparation) are converted with the ``pytorch`` overloads. They are
        wired together in a fresh graph on fake inputs, and the RNN output is
        fetched repeatedly under ``time_execution``.

        Args:
          batch_size: examples per batch; also reported as the per-iteration
            volume in 'examples'.
          max_seq_len: maximum sequence length of the generated fake inputs.
        """
        # Convert in the same order as the helpers are listed; each call gets
        # its own fresh transformer list.
        tf_dynamic_rnn, tf_create_rnn_cell, tf_get_init_data = [
            conversion.convert(torch_fn, pytorch,
                               [functions, variables, control_flow])
            for torch_fn in (torch_dynamic_rnn, _create_torch_rnn_cell,
                             _get_torch_inputs)
        ]

        with tf.Graph().as_default():
            fake_inputs, fake_lengths = self._generate_fake_rnn_inputs(
                batch_size=batch_size, max_seq_len=max_seq_len)
            cell, init_state = tf_create_rnn_cell(batch_size)
            inputs, lengths = tf_get_init_data(fake_inputs, fake_lengths)
            rnn_output = tf_dynamic_rnn(cell, inputs, init_state, lengths)

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())

                def fetch_output():
                    # Result intentionally discarded; only the run is timed.
                    sess.run(rnn_output)

                run_extras = {
                    'max_seq_len': max_seq_len,
                    'batch_size': batch_size,
                }
                self.time_execution(('Torch2Flow', batch_size, max_seq_len),
                                    fetch_output,
                                    iter_volume=batch_size,
                                    iter_unit='examples',
                                    extras=run_extras)
Exemplo n.º 2
0
    def _numpy_to_tf(self, batch_size, max_seq_len, input_size, hidden_size):
        """Time a NumPy dynamic RNN converted to TensorFlow ops.

        Random TF inputs are created in a fresh graph. The NumPy RNN
        implementation is converted with the ``numpy_to_tf`` overloads, and
        the resulting ops are run repeatedly under ``time_execution``.

        Args:
          batch_size: examples per batch.
          max_seq_len: maximum sequence length of the random inputs.
          input_size: feature size of each input step.
          hidden_size: size of the RNN hidden state.
        """
        with tf.Graph().as_default(), tf.Session() as sess:
            inputs, seq_len, w, b, init_state = (
                dynamic_rnn_minimal.random_inputs_tf(
                    batch_size, max_seq_len, input_size, hidden_size))

            tf_from_np = conversion.convert(
                dynamic_rnn_minimal.numpy, numpy_to_tf,
                [variables, functions, control_flow])
            rnn_ops = tf_from_np(inputs, seq_len, w, b, init_state)

            def run_once():
                # Only the execution is timed; the fetched values are unused.
                sess.run(rnn_ops)

            sizes = dict(max_seq_len=max_seq_len,
                         batch_size=batch_size,
                         input_size=input_size,
                         hidden_size=hidden_size)
            run_key = ('NumPy_TF', batch_size, max_seq_len, input_size,
                       hidden_size)
            self.time_execution(run_key, run_once, extras=sizes)
Exemplo n.º 3
0
    def test_overload_not_staged(self):
        """Calls on the overload module are not themselves overloaded.

        After conversion, ``test_fn`` contains calls into the overload module
        for variable virtualization (``init``/``assign``/``read``). Those calls
        must not be routed through call virtualization, so the custom
        overloads below fail the test if their ``call`` hook is ever invoked.
        """
        # Captured so the nested overload class can report a proper test
        # failure: inside MyOverloads, ``self`` is the overloads instance, which
        # has no ``fail`` method (calling it would raise AttributeError).
        test_case = self

        # After conversion this function will contain function calls for variable
        # virtualization, we do not want these to be affected by function call
        # virtualization.
        def test_fn(y):
            x = 1 + y
            return x

        # Custom overloads that blow up on call virtualization.
        class MyOverloads(object):

            # There is no call in test_fn that should be virtualized; this
            # makes the test fail if call virtualization is attempted.
            def call(self, func, args, keywords):  # pylint: disable=unused-argument
                test_case.fail('No call should be virtualized in this test.')

            def init(self, name):
                return py_defaults.init(name)

            def assign(self, lhs, rhs):
                return py_defaults.assign(lhs, rhs)

            def read(self, var):
                return py_defaults.read(var)

        overloads = MyOverloads()

        converted_func = conversion.convert(test_fn, overloads,
                                            [variables, functions])
        self.assertEqual(converted_func(5), 6)
Exemplo n.º 4
0
    def test_sequential_ifs(self, x):
        """Two back-to-back if/else statements convert correctly with tf_.

        ``test_fn`` is converted with the ``tf_`` overloads and the
        ``variables``/``control_flow`` transformers. It is evaluated on
        ``tf.constant(x)`` in a fresh graph and compared with the plain Python
        result for ``x``.

        Args:
          x: scalar input; presumably supplied by a parameterization decorator
            outside this view.
        """
        def test_fn(x):
            a = 0
            b = 0

            if x > 0:
                a = 1
            else:
                b = 1

            if x < 0:
                a = 2
            else:
                b = 2
            return a, b

        converted_fn = conversion.convert(test_fn, tf_,
                                          [variables, control_flow])

        # Build the converted ops in an isolated graph, then evaluate them.
        with tf.Graph().as_default():
            with tf.Session() as sess:
                converted_result = converted_fn(tf.constant(x))
                unconverted_result = test_fn(x)
                self.assertEqual(sess.run(converted_result),
                                 unconverted_result)
Exemplo n.º 5
0
 def test_noop_dictionary_variables(self):
   """Conversion with dictionary_variables and no transformers is a no-op.

   First sanity-checks the unconverted ``check_cond`` (1 -> [1], 5 -> [2]).
   Then checks that the function converted with ``dictionary_variables``
   overloads and an empty transformer list gives the same results.
   """
   self.assertListEqual(check_cond(1), [1])
   self.assertListEqual(check_cond(5), [2])
   converted_check_cond = conversion.convert(check_cond, dictionary_variables,
                                             [])
   # Exercise the converted function on both inputs, not only on 1.
   self.assertListEqual(converted_check_cond(1), [1])
   self.assertListEqual(converted_check_cond(5), [2])
Exemplo n.º 6
0
 def test_noop(self):
   """Conversion with py_defaults and no transformers is a no-op.

   First sanity-checks the unconverted ``check_cond`` (1 -> [1], 5 -> [2]).
   Then checks that the function converted with ``py_defaults`` overloads and
   an empty transformer list gives the same results.
   """
   self.assertListEqual(check_cond(1), [1])
   self.assertListEqual(check_cond(5), [2])
   converted_check_cond = conversion.convert(check_cond, py_defaults,
                                             [])
   # Exercise the converted function on both inputs, not only on 1.
   self.assertListEqual(converted_check_cond(1), [1])
   self.assertListEqual(converted_check_cond(5), [2])
Exemplo n.º 7
0
    def test_nested_if(self, x, y):
        """If/else nested inside if/else converts correctly with tf_.

        Each of the four leaf branches writes a different variable. The
        converted function, run on TF constants in a fresh graph, must return
        the same 4-tuple as the plain Python function.

        Args:
          x: scalar selecting the outer branch; presumably parameterized
            outside this view.
          y: scalar selecting the inner branch; presumably parameterized
            outside this view.
        """
        def test_fn(x, y):
            a = 0
            b = 0
            c = 0
            d = 0
            if x > 0:
                if y > 0:
                    a = x
                else:
                    b = y
            else:
                if y > 0:
                    c = x
                else:
                    d = y
            return a, b, c, d

        converted_fn = conversion.convert(test_fn, tf_,
                                          [variables, control_flow])

        # Build the converted ops in an isolated graph, then evaluate them.
        with tf.Graph().as_default():
            with tf.Session() as sess:
                converted_result = converted_fn(tf.constant(x), tf.constant(y))
                unconverted_result = test_fn(x, y)
                self.assertEqual(sess.run(converted_result),
                                 unconverted_result)
Exemplo n.º 8
0
    def test_basic_control_flow(self, x):
        """A while loop containing an if/else converts with pytorch overloads.

        The converted function is called on an int32 ``torch.tensor`` built
        from ``x``. Its result is fetched with ``sess.run`` inside a fresh
        TF graph and compared with the plain Python result ``(s, i)``.

        NOTE(review): fetching with ``sess.run`` assumes the ``pytorch``
        overloads emit values a TF session can fetch (e.g. TF ops) — confirm.

        Args:
          x: loop bound; presumably parameterized outside this view.
        """
        def test_fn(n):
            i = 0
            s = 0
            while i < n:
                if i > s:
                    u = i - s
                    j = 0
                else:
                    u = i + s
                    j = 1
                s = s + u + j
                i = i + 1
            return s, i

        converted_fn = conversion.convert(test_fn, pytorch,
                                          [variables, control_flow])

        with tf.Graph().as_default():
            with tf.Session() as sess:
                converted_result = converted_fn(
                    torch.tensor(x, dtype=torch.int32))
                unconverted_result = test_fn(x)
                self.assertEqual(sess.run(converted_result),
                                 unconverted_result)
Exemplo n.º 9
0
    def test_eight_queens(self):
        """Solve 8-queens with diagonal constraints built by a converted function.

        ``queens[i]`` is the 1-based row of the queen in column ``i``. The
        rank and file constraints are built directly with z3. The diagonal
        constraints come from ``test_fn`` converted with the ``z3py``
        overloads (``logical_ops``, ``variables``, ``control_flow``). The
        combined system must be satisfiable.
        """
        # See https://ericpony.github.io/z3py-tutorial/guide-examples.htm

        def test_fn(queens):
            diagonals = []
            for i in range(8):
                for j in range(i):
                    result = None
                    if i == j:
                        result = True
                    else:
                        # Two queens share a diagonal iff their row difference
                        # equals +/- their column difference.
                        result = (queens[i] - queens[j] != i - j and
                                  queens[i] - queens[j] != j - i)

                    diagonals.append(result)

            return diagonals

        queens = [z3.Int('queens_%i' % (i + 1)) for i in range(8)]
        ranks = [z3.And(1 <= queens[i], queens[i] <= 8) for i in range(8)]
        files = [z3.Distinct(queens)]
        converted_fn = conversion.convert(
            test_fn, z3py, [logical_ops, variables, control_flow])
        diagonals = converted_fn(queens)
        self.assertTrue(can_solve(ranks + files + diagonals))
Exemplo n.º 10
0
  def test_nested_cond(self, x):
    """Nested if/else converted with jax_ matches a hand-written lax.cond.

    ``fun`` computes 2*x for x < 2, 3*x for 2 <= x < 5, and 4*x otherwise.
    ``cfun`` expresses the same piecewise function with nested ``lax.cond``
    calls under ``jit``. The converted ``fun`` must agree with it.

    Args:
      x: scalar input; presumably parameterized outside this view.
    """

    def fun(x):
      res = 0
      if x < 2:
        res = lax.mul(2, x)
      else:
        if x < 5:
          res = lax.mul(3, x)
        else:
          res = lax.mul(4, x)
      return res

    # NOTE(review): ``jax.api.jit`` is an older JAX import path; newer JAX
    # exposes it only as ``jax.jit``.
    @jax.api.jit
    def cfun(x):
      # Five-argument lax.cond form: (pred, true_operand, true_fun,
      # false_operand, false_fun).
      def inner_cond(x):
        return lax.cond(
            lax.lt(x, 5),
            x,
            lambda x: lax.mul(3, x),
            4,
            # Receives the false operand 4 and multiplies by the closed-over x.
            lambda y: lax.mul(y, x),
        )

      return lax.cond(lax.lt(x, 2), x, lambda x: lax.mul(2, x), x, inner_cond)

    converted_fn = conversion.convert(fun, jax_, [variables, control_flow])
    self.assertEqual(cfun(x), converted_fn(x))
Exemplo n.º 11
0
    def test_for_multiple_targets(self, iter_):
        """A for loop that unpacks two targets per item survives conversion.

        ``test_fn`` is converted with ``py_defaults`` overloads and only the
        ``variables`` transformer. Its output must equal the unconverted
        output.

        Args:
          iter_: iterable of 2-item sequences; presumably parameterized
            outside this view.
        """
        def test_fn(iter_):
            res = []

            for x, y in iter_:
                res.append((x, y))

            return res

        converted_fn = conversion.convert(test_fn, py_defaults, [variables])
        self.assertEqual(converted_fn(iter_), test_fn(iter_))
Exemplo n.º 12
0
    def test_for_loop(self, x):
        """A summing for loop gives the same result after conversion.

        ``test_fn`` is converted with ``py_defaults`` overloads and the
        ``variables`` transformer, then compared with the unconverted function.

        Args:
          x: loop bound passed to ``range``; presumably parameterized outside
            this view.
        """
        def test_fn(n):
            sum_ = 0

            for i in range(n):
                sum_ = sum_ + i

            return sum_

        converted_fn = conversion.convert(test_fn, py_defaults, [variables])
        self.assertEqual(converted_fn(x), test_fn(x))
Exemplo n.º 13
0
    def test_known_function_swapped(self):
        """A call to a function call_swapping overloads is swapped when converted.

        With the ``functions`` transformer, the converted call for ints
        returns one more than the unconverted call. For strings the result is
        unchanged.
        """

        # add is overloaded in call_swapping to return x + y + 1 for ints
        def test_fn(x, y):
            return call_swapping.add(x, y)

        converted_func = conversion.convert(test_fn, call_swapping,
                                            [functions])
        self.assertEqual(converted_func(1, 2), test_fn(1, 2) + 1)
        self.assertEqual(converted_func('hello ', 'world'),
                         test_fn('hello ', 'world'))
Exemplo n.º 14
0
    def test_for_loop_return_target(self, x):
        """The loop variable can still be read after a converted for loop.

        ``test_fn`` returns ``sum_ + i`` using ``i`` after the loop finishes.
        The converted result must match the unconverted one.

        Args:
          x: loop bound; presumably parameterized with positive values so the
            loop runs and ``i`` is bound (the ``_le_0`` variant covers the
            unbound case).
        """
        def test_fn(n):
            sum_ = 0

            for i in range(n):
                sum_ = sum_ + i

            return sum_ + i  # pylint: disable=undefined-loop-variable

        converted_fn = conversion.convert(test_fn, py_defaults, [variables])
        self.assertEqual(converted_fn(x), test_fn(x))
Exemplo n.º 15
0
    def test_for_loop_return_target_le_0(self, x):
        """Reading a never-bound loop variable raises PyctUnboundLocalError.

        When the loop body never runs, ``i`` is never assigned. The converted
        function must raise ``py_defaults.PyctUnboundLocalError`` when it is
        read.

        Args:
          x: loop bound; presumably parameterized with values <= 0 (per the
            test name) so ``range(x)`` is empty.
        """
        def test_fn(n):
            sum_ = 0

            for i in range(n):
                sum_ = sum_ + i

            return sum_ + i  # pylint: disable=undefined-loop-variable

        converted_fn = conversion.convert(test_fn, py_defaults, [variables])
        with self.assertRaises(py_defaults.PyctUnboundLocalError):
            converted_fn(x)
Exemplo n.º 16
0
    def test_if(self, p, q, r):
        """An if/else converted with z3py overloads is equivalent to z3.If.

        Args:
          p: z3 condition; presumably parameterized outside this view.
          q: value for the true branch.
          r: value for the false branch.
        """
        def test_fn(a, b, c):
            result = None
            if a:
                result = b
            else:
                result = c
            return result

        converted_fn = conversion.convert(test_fn, z3py,
                                          [variables, control_flow])
        # ``prove`` is defined outside this view; presumably it returns a truthy
        # value when z3 proves the claim valid.
        self.assertTrue(prove(z3.If(p, q, r) == converted_fn(p, q, r)))
Exemplo n.º 17
0
    def test_if_no_else(self):
        """An if without an else converts correctly with py_defaults.

        Only the ``control_flow`` transformer is applied. The converted and
        unconverted functions are compared for n = 0 (branch skipped) and
        n = 1 (branch taken).
        """
        def test_fn(n):
            v = []

            if n > 0:
                v.append(n)

            return v

        converted_fn = conversion.convert(test_fn, py_defaults, [control_flow])
        for i in [0, 1]:
            self.assertEqual(converted_fn(i), test_fn(i))
Exemplo n.º 18
0
    def test_for_parameterized_noop(self, iter_):
        """A tuple-unpacking for loop is unchanged by py_defaults conversion.

        Same fixture as the multiple-targets test, but with both the
        ``variables`` and ``control_flow`` transformers applied.

        Args:
          iter_: iterable of 2-item sequences; presumably parameterized
            outside this view.
        """
        def test_fn(iter_):
            res = []

            for x, y in iter_:
                res.append((x, y))

            return res

        converted_fn = conversion.convert(test_fn, py_defaults,
                                          [variables, control_flow])
        self.assertEqual(converted_fn(iter_), test_fn(iter_))
Exemplo n.º 19
0
    def test_for_noop(self, x):
        """A for loop over range() is unchanged by py_defaults conversion.

        Args:
          x: loop bound passed to ``range``; presumably parameterized outside
            this view.
        """
        def test_fn(n):
            res = []

            for i in range(n):
                res.append(i)

            return res

        converted_fn = conversion.convert(test_fn, py_defaults,
                                          [variables, control_flow])
        self.assertEqual(converted_fn(x), test_fn(x))
Exemplo n.º 20
0
    def test_unknown_function_not_swapped(self):
        """Calls to functions without an overload are left untouched.

        ``adder`` is local and has no counterpart in ``call_swapping``. After
        conversion with the ``functions`` transformer, results for ints and
        strings must match the unconverted function.
        """

        # no overload for adder exists in call_swapping
        def adder(x, y):
            return x + y

        def test_fn(x, y):
            return adder(x, y)

        converted_func = conversion.convert(test_fn, call_swapping,
                                            [functions])
        self.assertEqual(converted_func(1, 2), test_fn(1, 2))
        self.assertEqual(converted_func('hello ', 'world'),
                         test_fn('hello ', 'world'))
Exemplo n.º 21
0
    def test_noop_while(self):
        """A counting while loop is unchanged by py_defaults conversion (n=5)."""
        def test_fn(n):
            res = []
            i = 0

            while i < n:
                res.append(i)
                i = i + 1

            return res

        converted_fn = conversion.convert(test_fn, py_defaults,
                                          [variables, control_flow])
        self.assertEqual(converted_fn(5), test_fn(5))
Exemplo n.º 22
0
  def test_while_basic(self, i):
    """A counting while loop converts with jax_ and still works under jit.

    ``test_fn`` counts the steps needed to raise ``init`` to ``limit``. The
    plain, converted and jitted-converted versions must all agree.

    Args:
      i: starting value; presumably parameterized with values <= ``limit``,
        since the first assertion expects exactly ``limit - i`` steps.
    """
    limit = 10

    def test_fn(init):
      count = 0
      while init < limit:
        init = init + 1
        count = count + 1
      return count

    converted_fn = conversion.convert(test_fn, jax_, [variables, control_flow])
    # NOTE(review): ``jax.api.jit`` is an older JAX import path (``jax.jit``).
    jitted_fn = jax.api.jit(converted_fn)
    self.assertEqual(test_fn(i), limit - i)
    self.assertEqual(test_fn(i), converted_fn(i))
    self.assertEqual(test_fn(i), jitted_fn(i))
Exemplo n.º 23
0
  def test_jit_after_conversion_cond(self):
    """A jax_-converted if/else gives the plain result, both eagerly and jitted.

    Checked only at x = 2, which takes the ``x < 3`` branch.
    """

    def f(x):
      res = 0
      if x < 3:
        res = 3. * x**2
      else:
        res = -4. * x
      return res

    converted_fn = conversion.convert(f, jax_, [variables, control_flow])
    self.assertEqual(f(2), converted_fn(2))

    # NOTE(review): ``jax.api.jit`` is an older JAX import path (``jax.jit``).
    jitted_fn = jax.api.jit(converted_fn)
    self.assertEqual(f(2), jitted_fn(2))
Exemplo n.º 24
0
    def test_very_nested_if(self, x):
        """Three levels of nested if/else convert correctly with tf_.

        Each leaf of the nesting sets its own flag (``a``..``h``). Each
        top-level branch also sets ``then_branch``/``else_branch``. The
        converted result, evaluated in a fresh TF graph, must match the plain
        Python tuple.

        Args:
          x: scalar input; presumably parameterized outside this view to reach
            the different leaves.
        """
        def test_fn(x):
            a = 0
            b = 0
            c = 0
            d = 0
            e = 0
            f = 0
            g = 0
            h = 0
            then_branch = 0
            else_branch = 0

            if x > 0:
                if x > 2:
                    if x > 4:
                        a = 1
                    else:
                        b = 1
                else:
                    if x > 1:
                        c = 1
                    else:
                        d = 1
                then_branch = 1
            else:
                if x < -2:
                    if x < -4:
                        e = 1
                    else:
                        f = 1
                else:
                    if x < -1:
                        g = 1
                    else:
                        h = 1
                else_branch = 1
            return a, b, c, d, e, f, g, h, then_branch, else_branch

        converted_fn = conversion.convert(test_fn, tf_,
                                          [variables, control_flow])

        # Build the converted ops in an isolated graph, then evaluate them.
        with tf.Graph().as_default():
            with tf.Session() as sess:
                converted_result = converted_fn(tf.constant(x))
                unconverted_result = test_fn(x)
                self.assertEqual(sess.run(converted_result),
                                 unconverted_result)
Exemplo n.º 25
0
  def test_if_basic(self, x):
    """A simple if/else converted with jax_ matches plain Python under jit.

    Args:
      x: scalar input; presumably parameterized outside this view.
    """

    def test_fn(n):
      a = 0
      b = 0
      if n > 0:
        a = n
      else:
        b = n
      return a, b

    converted_fn = conversion.convert(test_fn, jax_, [variables, control_flow])
    jax_converted = jax.jit(converted_fn)
    converted_result = jax_converted(x)
    unconverted_result = test_fn(x)
    self.assertEqual(converted_result, unconverted_result)
Exemplo n.º 26
0
    def test_if_tuple(self, p, q, r):
        """An if/else assigning two variables converts to per-variable z3.If.

        When ``a`` holds, ``test_fn`` returns ``(b, c)``; otherwise it returns
        ``(c, b)``. After conversion with the ``z3py`` overloads, the first
        output must be equivalent to ``If(p, q, r)`` and the second to
        ``If(p, r, q)``.

        Args:
          p: z3 Bool condition; presumably parameterized outside this view.
          q: z3 value bound to ``b``.
          r: z3 value bound to ``c``.
        """
        def test_fn(a, b, c):
            result = None
            test_result = None
            if a:
                test_result = c
                result = b
            else:
                result = c
                test_result = b
            return result, test_result

        converted_fn = conversion.convert(test_fn, z3py,
                                          [variables, control_flow])
        a, b = converted_fn(p, q, r)
        # Check each output on its own. Comparing And(a, b) with the
        # branch-independent If(p, And(q, r), And(q, r)) would also pass if the
        # branches were swapped or the condition were ignored.
        self.assertTrue(prove(a == z3.If(p, q, r)))
        self.assertTrue(prove(b == z3.If(p, r, q)))
Exemplo n.º 27
0
  def test_very_nested_if(self, x):
    """Three levels of nested if/else convert with jax_ and run under jit.

    Each leaf sets its own flag (``a``..``h``). Each top-level branch also
    sets ``then_branch``/``else_branch``. The jitted converted result must
    equal the plain Python tuple.

    Args:
      x: scalar input; presumably parameterized outside this view.
    """

    def test_fn(x):
      a = 0
      b = 0
      c = 0
      d = 0
      e = 0
      f = 0
      g = 0
      h = 0
      then_branch = 0
      else_branch = 0

      if x > 0:
        if x > 2:
          if x > 4:
            a = 1
          else:
            b = 1
        else:
          if x > 1:
            c = 1
          else:
            d = 1
        then_branch = 1
      else:
        if x < -2:
          if x < -4:
            e = 1
          else:
            f = 1
        else:
          if x < -1:
            g = 1
          else:
            h = 1
        else_branch = 1
      return a, b, c, d, e, f, g, h, then_branch, else_branch

    converted_fn = conversion.convert(test_fn, jax_, [variables, control_flow])
    jax_converted = jax.jit(converted_fn)
    converted_result = jax_converted(x)
    unconverted_result = test_fn(x)
    self.assertEqual(converted_result, unconverted_result)
Exemplo n.º 28
0
    def test_eight_queens_optimized(self):
        """Build the whole 8-queens model inside one converted function.

        Unlike ``test_eight_queens``, the rank constraints use a Python
        ``and``. That ``and`` only yields a z3 constraint if the
        ``logical_ops`` transformer rewrites it (presumably to a z3py
        conjunction) — confirm. Diagonals are expressed as
        ``|row_i - row_j| != |i - j|``. The resulting system must be
        satisfiable.
        """
        def test_fn():
            queens = [z3.Int('queens_%i' % (i + 1)) for i in range(8)]
            ranks = [1 <= queens[i] and queens[i] <= 8 for i in range(8)]
            files = [z3.Distinct(queens)]
            diagonals = []
            for i in range(8):
                # j < i, so the ``i != j`` guard below is always true.
                for j in range(i):
                    if i != j:
                        diagonals.append(
                            abs(queens[i] - queens[j]) != abs(i - j))

            return ranks, files, diagonals

        converted_fn = conversion.convert(test_fn, z3py,
                                          [logical_ops, functions])
        ranks, files, diagonals = converted_fn()
        self.assertTrue(can_solve(ranks + files + diagonals))
Exemplo n.º 29
0
    def test_while_basic(self, x):
        """A summing while loop converts correctly with tf_.

        The converted function returns ``(s, i, n)``. Evaluated in a fresh TF
        graph, it must match the plain Python result.

        Args:
          x: loop bound; presumably parameterized outside this view.
        """
        def test_fn(n):
            i = 0
            s = 0
            while i < n:
                s = s + i
                i = i + 1
            return s, i, n

        converted_fn = conversion.convert(test_fn, tf_,
                                          [variables, control_flow])

        # Build the converted ops in an isolated graph, then evaluate them.
        with tf.Graph().as_default():
            with tf.Session() as sess:
                converted_result = converted_fn(tf.constant(x))
                unconverted_result = test_fn(x)
                self.assertEqual(sess.run(converted_result),
                                 unconverted_result)
Exemplo n.º 30
0
    def test_if_basic(self, n):
        """A simple if/else converts correctly with tf_.

        The converted function, run on ``tf.constant(n)`` in a fresh graph,
        must return the same ``(a, b)`` pair as the plain Python function.

        Args:
          n: scalar input; presumably parameterized outside this view.
        """
        def test_fn(n):
            a = 0
            b = 0
            if n > 0:
                a = n
            else:
                b = n
            return a, b

        converted_fn = conversion.convert(test_fn, tf_,
                                          [variables, control_flow])

        # Build the converted ops in an isolated graph, then evaluate them.
        with tf.Graph().as_default():
            with tf.Session() as sess:
                converted_result = converted_fn(tf.constant(n))
                unconverted_result = test_fn(n)
                self.assertEqual(sess.run(converted_result),
                                 unconverted_result)