def _torch_to_tf(self, batch_size, max_seq_len):
  """Benchmark pytorch converted to tensorflow graph implementation.

  The torch RNN helpers are converted with the `pytorch` overloads, wired
  into a fresh TF graph, and the RNN output op is timed under a session.
  """

  def _to_tf(torch_fn):
    # Each conversion receives its own fresh list of transformation passes.
    return conversion.convert(torch_fn, pytorch,
                              [functions, variables, control_flow])

  tf_dynamic_rnn = _to_tf(torch_dynamic_rnn)
  tf_create_rnn_cell = _to_tf(_create_torch_rnn_cell)
  tf_get_init_data = _to_tf(_get_torch_inputs)

  with tf.Graph().as_default():
    raw_inputs, raw_lengths = self._generate_fake_rnn_inputs(
        batch_size=batch_size, max_seq_len=max_seq_len)
    cell, init_state = tf_create_rnn_cell(batch_size)
    inputs, lengths = tf_get_init_data(raw_inputs, raw_lengths)
    rnn_output = tf_dynamic_rnn(cell, inputs, init_state, lengths)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())

      def target():
        sess.run(rnn_output)

      run_extras = {'max_seq_len': max_seq_len, 'batch_size': batch_size}
      self.time_execution(('Torch2Flow', batch_size, max_seq_len),
                          target,
                          iter_volume=batch_size,
                          iter_unit='examples',
                          extras=run_extras)
def _numpy_to_tf(self, batch_size, max_seq_len, input_size, hidden_size):
  """Benchmark the minimal NumPy dynamic RNN converted to a TF graph."""
  with tf.Graph().as_default(), tf.Session() as sess:
    inputs, seq_len, w, b, init_state = dynamic_rnn_minimal.random_inputs_tf(
        batch_size, max_seq_len, input_size, hidden_size)
    numpy_rnn_as_tf = conversion.convert(
        dynamic_rnn_minimal.numpy, numpy_to_tf,
        [variables, functions, control_flow])
    rnn_ops = numpy_rnn_as_tf(inputs, seq_len, w, b, init_state)

    def target():
      sess.run(rnn_ops)

    benchmark_name = ('NumPy_TF', batch_size, max_seq_len, input_size,
                      hidden_size)
    self.time_execution(
        benchmark_name,
        target,
        extras=dict(max_seq_len=max_seq_len,
                    batch_size=batch_size,
                    input_size=input_size,
                    hidden_size=hidden_size))
def test_overload_not_staged(self):
  """Calls on the overload module are not themselves overloaded."""
  # After conversion this function will contain function calls for variable
  # virtualization, we do not want these to be affected by function call
  # virtualization.
  def test_fn(y):
    x = 1 + y
    return x

  # Inside MyOverloads, `self` is the overloads instance, which has no `fail`
  # method. Keep a handle on the test case so that a stray virtualized call
  # reports a proper test failure instead of an AttributeError.
  test_case = self

  # Custom overloads that blow up on call virtualization.
  class MyOverloads(object):

    # There is no call in the test_fn that should be virtualized; this ensures
    # that an exception is thrown if call virtualization is attempted.
    def call(self, func, args, keywords):  # pylint: disable=unused-argument
      test_case.fail('No call should be virtualized in this test.')

    def init(self, name):
      return py_defaults.init(name)

    def assign(self, lhs, rhs):
      return py_defaults.assign(lhs, rhs)

    def read(self, var):
      return py_defaults.read(var)

  overloads = MyOverloads()
  converted_func = conversion.convert(test_fn, overloads,
                                      [variables, functions])
  self.assertEqual(converted_func(5), 6)
def test_sequential_ifs(self, x):
  """Two back-to-back if/else statements convert to matching TF ops."""

  def test_fn(x):
    a = 0
    b = 0
    if x > 0:
      a = 1
    else:
      b = 1
    if x < 0:
      a = 2
    else:
      b = 2
    return a, b

  converted_fn = conversion.convert(test_fn, tf_, [variables, control_flow])
  with tf.Graph().as_default(), tf.Session() as sess:
    graph_result = converted_fn(tf.constant(x))
    expected = test_fn(x)
    self.assertEqual(sess.run(graph_result), expected)
def test_noop_dictionary_variables(self):
  """Converting with no transformers preserves check_cond's results."""
  # Sanity-check the unconverted function first.
  self.assertListEqual(check_cond(1), [1])
  self.assertListEqual(check_cond(5), [2])
  converted_check_cond = conversion.convert(check_cond, dictionary_variables,
                                            [])
  # Both inputs must be checked against the *converted* function; previously
  # the second assertion re-ran check_cond(5) and never exercised it.
  self.assertListEqual(converted_check_cond(1), [1])
  self.assertListEqual(converted_check_cond(5), [2])
def test_noop(self):
  """Converting with py_defaults and no transformers preserves results."""
  # Sanity-check the unconverted function first.
  self.assertListEqual(check_cond(1), [1])
  self.assertListEqual(check_cond(5), [2])
  converted_check_cond = conversion.convert(check_cond, py_defaults, [])
  # Both inputs must be checked against the *converted* function; previously
  # the second assertion re-ran check_cond(5) and never exercised it.
  self.assertListEqual(converted_check_cond(1), [1])
  self.assertListEqual(converted_check_cond(5), [2])
def test_nested_if(self, x, y):
  """A two-level if/else nest converts to TF ops with the same results."""

  def test_fn(x, y):
    a = 0
    b = 0
    c = 0
    d = 0
    if x > 0:
      if y > 0:
        a = x
      else:
        b = y
    else:
      if y > 0:
        c = x
      else:
        d = y
    return a, b, c, d

  converted_fn = conversion.convert(test_fn, tf_, [variables, control_flow])
  with tf.Graph().as_default(), tf.Session() as sess:
    graph_result = converted_fn(tf.constant(x), tf.constant(y))
    expected = test_fn(x, y)
    self.assertEqual(sess.run(graph_result), expected)
def test_basic_control_flow(self, x):
  """A while loop containing an if/else, converted via the pytorch overloads."""

  def test_fn(n):
    i = 0
    s = 0
    while i < n:
      if i > s:
        u = i - s
        j = 0
      else:
        u = i + s
        j = 1
      s = s + u + j
      i = i + 1
    return s, i

  converted_fn = conversion.convert(test_fn, pytorch,
                                    [variables, control_flow])
  with tf.Graph().as_default(), tf.Session() as sess:
    torch_input = torch.tensor(x, dtype=torch.int32)
    graph_result = converted_fn(torch_input)
    expected = test_fn(x)
    self.assertEqual(sess.run(graph_result), expected)
def test_eight_queens(self):
  """Solve eight queens with z3; diagonals come from converted Python code."""
  # See https://ericpony.github.io/z3py-tutorial/guide-examples.htm
  def test_fn(queens):
    diagonals = []
    for i in range(8):
      for j in range(i):
        result = None
        if i == j:
          result = True
        else:
          # Two queens share a diagonal iff their row offset equals +/- their
          # column offset. The second term was `j - 1` (typo), not `j - i`.
          result = (queens[i] - queens[j] != i - j and
                    queens[i] - queens[j] != j - i)
        diagonals.append(result)
    return diagonals

  queens = [z3.Int('queens_%i' % (i + 1)) for i in range(8)]
  ranks = [z3.And(1 <= queens[i], queens[i] <= 8) for i in range(8)]
  files = [z3.Distinct(queens)]
  converted_fn = conversion.convert(
      test_fn, z3py, [logical_ops, variables, control_flow])
  diagonals = converted_fn(queens)
  self.assertTrue(can_solve(ranks + files + diagonals))
def test_nested_cond(self, x):
  """Converted nested ifs agree with a hand-written, jitted lax.cond chain."""

  def fun(x):
    res = 0
    if x < 2:
      res = lax.mul(2, x)
    else:
      if x < 5:
        res = lax.mul(3, x)
      else:
        res = lax.mul(4, x)
    return res

  @jax.api.jit
  def cfun(x):
    """Hand-written lax.cond equivalent of `fun`."""

    def inner_cond(v):
      # Operand 4 is multiplied by the captured v in the false branch.
      return lax.cond(lax.lt(v, 5), v, lambda t: lax.mul(3, t), 4,
                      lambda k: lax.mul(k, v))

    return lax.cond(lax.lt(x, 2), x, lambda t: lax.mul(2, t), x, inner_cond)

  converted_fn = conversion.convert(fun, jax_, [variables, control_flow])
  self.assertEqual(cfun(x), converted_fn(x))
def test_for_multiple_targets(self, iter_):
  """A for loop unpacking two targets survives variable virtualization."""

  def test_fn(iter_):
    res = []
    for x, y in iter_:
      res.append((x, y))
    return res

  converted_fn = conversion.convert(test_fn, py_defaults, [variables])
  actual = converted_fn(iter_)
  self.assertEqual(actual, test_fn(iter_))
def test_for_loop(self, x):
  """A range-based accumulating for loop matches after conversion."""

  def test_fn(n):
    sum_ = 0
    for i in range(n):
      sum_ = sum_ + i
    return sum_

  converted_fn = conversion.convert(test_fn, py_defaults, [variables])
  actual = converted_fn(x)
  expected = test_fn(x)
  self.assertEqual(actual, expected)
def test_known_function_swapped(self):
  """Calls to a function call_swapping overloads are redirected."""
  # add is overloaded in call_swapping to return x + y + 1 for ints
  def test_fn(x, y):
    return call_swapping.add(x, y)

  swapped = conversion.convert(test_fn, call_swapping, [functions])
  self.assertEqual(swapped(1, 2), test_fn(1, 2) + 1)
  greeting = ('hello ', 'world')
  self.assertEqual(swapped(*greeting), test_fn(*greeting))
def test_for_loop_return_target(self, x):
  """The loop target stays readable after the loop when the loop ran."""

  def test_fn(n):
    sum_ = 0
    for i in range(n):
      sum_ = sum_ + i
    return sum_ + i  # pylint: disable=undefined-loop-variable

  converted_fn = conversion.convert(test_fn, py_defaults, [variables])
  actual = converted_fn(x)
  expected = test_fn(x)
  self.assertEqual(actual, expected)
def test_for_loop_return_target_le_0(self, x):
  """Reading a never-bound loop target raises the pyct unbound error."""

  def test_fn(n):
    sum_ = 0
    for i in range(n):
      sum_ = sum_ + i
    return sum_ + i  # pylint: disable=undefined-loop-variable

  converted_fn = conversion.convert(test_fn, py_defaults, [variables])
  # range(n) is empty for n <= 0, so `i` is never bound before the return.
  expected_error = py_defaults.PyctUnboundLocalError
  with self.assertRaises(expected_error):
    converted_fn(x)
def test_if(self, p, q, r):
  """A converted if/else is provably equivalent to z3.If."""

  def test_fn(a, b, c):
    result = None
    if a:
      result = b
    else:
      result = c
    return result

  converted_fn = conversion.convert(test_fn, z3py, [variables, control_flow])
  reference = z3.If(p, q, r)
  symbolic = converted_fn(p, q, r)
  self.assertTrue(prove(reference == symbolic))
def test_if_no_else(self):
  """An if without an else converts correctly for both branch outcomes."""

  def test_fn(n):
    v = []
    if n > 0:
      v.append(n)
    return v

  converted_fn = conversion.convert(test_fn, py_defaults, [control_flow])
  for n in (0, 1):
    actual = converted_fn(n)
    self.assertEqual(actual, test_fn(n))
def test_for_parameterized_noop(self, iter_):
  """Tuple-unpacking for loops are unchanged by py_defaults conversion."""

  def test_fn(iter_):
    res = []
    for x, y in iter_:
      res.append((x, y))
    return res

  converted_fn = conversion.convert(test_fn, py_defaults,
                                    [variables, control_flow])
  actual = converted_fn(iter_)
  self.assertEqual(actual, test_fn(iter_))
def test_for_noop(self, x):
  """A simple appending for loop is unchanged by py_defaults conversion."""

  def test_fn(n):
    res = []
    for i in range(n):
      res.append(i)
    return res

  converted_fn = conversion.convert(test_fn, py_defaults,
                                    [variables, control_flow])
  actual = converted_fn(x)
  self.assertEqual(actual, test_fn(x))
def test_unknown_function_not_swapped(self):
  """Calls to functions without an overload are left untouched."""
  # no overload for adder exists in call_swapping

  def adder(x, y):
    return x + y

  def test_fn(x, y):
    return adder(x, y)

  unswapped = conversion.convert(test_fn, call_swapping, [functions])
  for lhs, rhs in ((1, 2), ('hello ', 'world')):
    self.assertEqual(unswapped(lhs, rhs), test_fn(lhs, rhs))
def test_noop_while(self):
  """A counting while loop is unchanged by py_defaults conversion."""

  def test_fn(n):
    res = []
    i = 0
    while i < n:
      res.append(i)
      i = i + 1
    return res

  converted_fn = conversion.convert(test_fn, py_defaults,
                                    [variables, control_flow])
  n_items = 5
  self.assertEqual(converted_fn(n_items), test_fn(n_items))
def test_while_basic(self, i):
  """A while loop closing over `limit` agrees eagerly and under jax jit."""
  limit = 10

  def test_fn(init):
    count = 0
    while init < limit:
      init = init + 1
      count = count + 1
    return count

  converted_fn = conversion.convert(test_fn, jax_, [variables, control_flow])
  jitted_fn = jax.api.jit(converted_fn)
  expected = test_fn(i)
  self.assertEqual(expected, limit - i)
  self.assertEqual(expected, converted_fn(i))
  self.assertEqual(expected, jitted_fn(i))
def test_jit_after_conversion_cond(self):
  """A converted conditional can be jax-jitted and keeps its value."""

  def f(x):
    res = 0
    if x < 3:
      res = 3. * x**2
    else:
      res = -4. * x
    return res

  converted_fn = conversion.convert(f, jax_, [variables, control_flow])
  probe = 2
  self.assertEqual(f(probe), converted_fn(probe))
  self.assertEqual(f(probe), jax.api.jit(converted_fn)(probe))
def test_very_nested_if(self, x):
  """Three-level if/else nesting converts to TF ops with the same results."""

  def test_fn(x):
    a = 0
    b = 0
    c = 0
    d = 0
    e = 0
    f = 0
    g = 0
    h = 0
    then_branch = 0
    else_branch = 0
    if x > 0:
      if x > 2:
        if x > 4:
          a = 1
        else:
          b = 1
      else:
        if x > 1:
          c = 1
        else:
          d = 1
      then_branch = 1
    else:
      if x < -2:
        if x < -4:
          e = 1
        else:
          f = 1
      else:
        if x < -1:
          g = 1
        else:
          h = 1
      else_branch = 1
    return a, b, c, d, e, f, g, h, then_branch, else_branch

  converted_fn = conversion.convert(test_fn, tf_, [variables, control_flow])
  with tf.Graph().as_default(), tf.Session() as sess:
    graph_result = converted_fn(tf.constant(x))
    expected = test_fn(x)
    self.assertEqual(sess.run(graph_result), expected)
def test_if_basic(self, x):
  """A simple if/else converts to jax and matches under jax.jit."""

  def test_fn(n):
    a = 0
    b = 0
    if n > 0:
      a = n
    else:
      b = n
    return a, b

  converted_fn = conversion.convert(test_fn, jax_, [variables, control_flow])
  jax_result = jax.jit(converted_fn)(x)
  self.assertEqual(jax_result, test_fn(x))
def test_if_tuple(self, p, q, r):
  """A converted if/else assigning two variables yields z3 expressions."""

  def test_fn(a, b, c):
    result = None
    test_result = None
    if a:
      test_result = c
      result = b
    else:
      result = c
      test_result = b
    return result, test_result

  converted_fn = conversion.convert(test_fn, z3py, [variables, control_flow])
  result, test_result = converted_fn(p, q, r)
  # Either branch assigns {q, r} to the two outputs, so their conjunction is
  # And(q, r) regardless of p.
  # NOTE(review): this check is symmetric in the two outputs and would not
  # catch a conversion that swaps them.
  both_branches = z3.If(p, z3.And(q, r), z3.And(q, r))
  self.assertTrue(prove(both_branches == z3.And(result, test_result)))
def test_very_nested_if(self, x):
  """Three-level if/else nesting converts to jax and matches under jit."""

  def test_fn(x):
    a = 0
    b = 0
    c = 0
    d = 0
    e = 0
    f = 0
    g = 0
    h = 0
    then_branch = 0
    else_branch = 0
    if x > 0:
      if x > 2:
        if x > 4:
          a = 1
        else:
          b = 1
      else:
        if x > 1:
          c = 1
        else:
          d = 1
      then_branch = 1
    else:
      if x < -2:
        if x < -4:
          e = 1
        else:
          f = 1
      else:
        if x < -1:
          g = 1
        else:
          h = 1
      else_branch = 1
    return a, b, c, d, e, f, g, h, then_branch, else_branch

  converted_fn = conversion.convert(test_fn, jax_, [variables, control_flow])
  jax_result = jax.jit(converted_fn)(x)
  self.assertEqual(jax_result, test_fn(x))
def test_eight_queens_optimized(self):
  """Eight queens built entirely inside a converted function is solvable."""

  def test_fn():
    queens = [z3.Int('queens_%i' % (i + 1)) for i in range(8)]
    ranks = [1 <= queens[i] and queens[i] <= 8 for i in range(8)]
    files = [z3.Distinct(queens)]
    diagonals = []
    for i in range(8):
      for j in range(i):
        if i != j:
          diagonals.append(abs(queens[i] - queens[j]) != abs(i - j))
    return ranks, files, diagonals

  converted_fn = conversion.convert(test_fn, z3py, [logical_ops, functions])
  ranks, files, diagonals = converted_fn()
  constraints = ranks + files + diagonals
  self.assertTrue(can_solve(constraints))
def test_while_basic(self, x):
  """A summing while loop converts to TF ops with the same results."""

  def test_fn(n):
    i = 0
    s = 0
    while i < n:
      s = s + i
      i = i + 1
    return s, i, n

  converted_fn = conversion.convert(test_fn, tf_, [variables, control_flow])
  with tf.Graph().as_default(), tf.Session() as sess:
    graph_result = converted_fn(tf.constant(x))
    expected = test_fn(x)
    self.assertEqual(sess.run(graph_result), expected)
def test_if_basic(self, n):
  """A simple if/else converts to TF ops with the same results."""

  def test_fn(n):
    a = 0
    b = 0
    if n > 0:
      a = n
    else:
      b = n
    return a, b

  converted_fn = conversion.convert(test_fn, tf_, [variables, control_flow])
  with tf.Graph().as_default(), tf.Session() as sess:
    graph_result = converted_fn(tf.constant(n))
    expected = test_fn(n)
    self.assertEqual(sess.run(graph_result), expected)