def testCountingLoopHandrolledC64(self):
  """Runs a hand-built XLA while loop that accumulates a complex64 sum.

  Each iteration adds 1.5 + 2j to the running sum and increments the int32
  step counter; the loop continues while the counter is below 10. Starting
  from step 0 the loop runs 10 times (sum 15 + 20j); starting from step 10
  it does not run at all, so the sum stays at its initial value.
  """

  # Define a function for the loop body
  @function.Defun(dtypes.int32, dtypes.complex64)
  def loop_body(step, rsum):
    step_out = step + constant_op.constant(1, dtype=dtypes.int32)
    sum_out = rsum + constant_op.constant(1.5 + 2j, dtype=dtypes.complex64)
    return step_out, sum_out

  # Define a function for the loop condition
  @function.Defun(dtypes.int32, dtypes.complex64)
  def loop_cond(step, rsum):
    del rsum  # Termination depends only on the step counter.
    return step < 10

  with self.cached_session() as sess:
    init_index = array_ops.placeholder(dtypes.int32, [])
    init_sum = array_ops.placeholder(dtypes.complex64, [])
    with self.test_scope():
      loop_outputs = xla.while_loop([init_index, init_sum], loop_cond,
                                    loop_body)

    result = sess.run(loop_outputs, {init_index: 0, init_sum: 0.0})
    # Check the counter as well as the accumulator, so a loop that exits
    # at the wrong step cannot pass on the sum alone.
    self.assertEqual(result[0], 10)
    self.assertAllClose(result[1], np.complex64(15 + 20j), rtol=1e-3)

    # Starting at the bound: the body must not execute even once.
    no_iters_result = sess.run(loop_outputs, {init_index: 10, init_sum: 0.0})
    self.assertEqual(no_iters_result[0], 10)
    self.assertAllClose(no_iters_result[1], np.complex64(0), rtol=1e-3)
def testCountingLoopHandrolledC64(self):
  """Runs a hand-built XLA while loop that accumulates a complex64 sum.

  Each iteration adds 1.5 + 2j to the running sum and increments the int32
  step counter; the loop continues while the counter is below 10. Starting
  from step 0 the loop runs 10 times (sum 15 + 20j); starting from step 10
  it does not run at all, so the sum stays at its initial value.
  """

  # Define a function for the loop body
  @function.Defun(dtypes.int32, dtypes.complex64)
  def loop_body(step, rsum):
    step_out = step + constant_op.constant(1, dtype=dtypes.int32)
    sum_out = rsum + constant_op.constant(1.5 + 2j, dtype=dtypes.complex64)
    return step_out, sum_out

  # Define a function for the loop condition
  @function.Defun(dtypes.int32, dtypes.complex64)
  def loop_cond(step, rsum):
    del rsum  # Termination depends only on the step counter.
    return step < 10

  with self.cached_session() as sess:
    init_index = array_ops.placeholder(dtypes.int32, [])
    init_sum = array_ops.placeholder(dtypes.complex64, [])
    with self.test_scope():
      loop_outputs = xla.while_loop([init_index, init_sum], loop_cond,
                                    loop_body)

    result = sess.run(loop_outputs, {init_index: 0, init_sum: 0.0})
    # Check the counter as well as the accumulator, so a loop that exits
    # at the wrong step cannot pass on the sum alone.
    self.assertEqual(result[0], 10)
    self.assertAllClose(result[1], np.complex64(15 + 20j), rtol=1e-3)

    # Starting at the bound: the body must not execute even once.
    no_iters_result = sess.run(loop_outputs, {init_index: 10, init_sum: 0.0})
    self.assertEqual(no_iters_result[0], 10)
    self.assertAllClose(no_iters_result[1], np.complex64(0), rtol=1e-3)
def testSingletonLoopHandrolled(self):
  """Checks an XLA while loop whose only loop variable is an int32 counter.

  Starting at 0 and incrementing by one while below 10, the loop should
  finish with the counter equal to 10.
  """

  @function.Defun(dtypes.int32)
  def loop_body(step):
    # Loop body: advance the counter by one.
    return step + constant_op.constant(1, dtype=dtypes.int32)

  @function.Defun(dtypes.int32)
  def loop_cond(step):
    # Loop condition: keep iterating while the counter is below 10.
    return step < 10

  with self.session() as sess:
    start = array_ops.placeholder(dtypes.int32, [])
    with self.test_scope():
      outputs = xla.while_loop([start], loop_cond, loop_body)
    final = sess.run(outputs, {start: 0})
    self.assertAllClose(final, [10], rtol=1e-3)
def testSingletonLoopHandrolled(self):
  """Checks an XLA while loop whose only loop variable is an int32 counter.

  Starting at 0 and incrementing by one while below 10, the loop should
  finish with the counter equal to 10.
  """

  @function.Defun(dtypes.int32)
  def loop_body(step):
    # Loop body: advance the counter by one.
    return step + constant_op.constant(1, dtype=dtypes.int32)

  @function.Defun(dtypes.int32)
  def loop_cond(step):
    # Loop condition: keep iterating while the counter is below 10.
    return step < 10

  with self.cached_session() as sess:
    start = array_ops.placeholder(dtypes.int32, [])
    with self.test_scope():
      outputs = xla.while_loop([start], loop_cond, loop_body)
    final = sess.run(outputs, {start: 0})
    self.assertAllClose(final, [10], rtol=1e-3)
def testLoopWithConstantOutput(self):
  """Checks a while loop whose body emits a constant for one loop variable.

  The second loop variable starts at 42, but the body always replaces it
  with 7. Starting the counter at 0, the loop runs ten times, so the final
  outputs should be [10, 7].
  """

  @function.Defun(dtypes.int32, dtypes.int32)
  def loop_body(step, x):
    del x  # The incoming value is discarded; the body always emits 7.
    return step + constant_op.constant(1, dtype=dtypes.int32), 7

  @function.Defun(dtypes.int32, dtypes.int32)
  def loop_cond(step, x):
    del x  # Only the counter controls termination.
    return step < 10

  with self.session() as sess:
    start = array_ops.placeholder(dtypes.int32, [])
    with self.test_scope():
      outputs = xla.while_loop([start, 42], loop_cond, loop_body)
    final = sess.run(outputs, {start: 0})
    self.assertAllClose(final, [10, 7], rtol=1e-3)
def testLoopWithConstantOutput(self):
  """Checks a while loop whose body emits a constant for one loop variable.

  The second loop variable starts at 42, but the body always replaces it
  with 7. Starting the counter at 0, the loop runs ten times, so the final
  outputs should be [10, 7].
  """

  @function.Defun(dtypes.int32, dtypes.int32)
  def loop_body(step, x):
    del x  # The incoming value is discarded; the body always emits 7.
    return step + constant_op.constant(1, dtype=dtypes.int32), 7

  @function.Defun(dtypes.int32, dtypes.int32)
  def loop_cond(step, x):
    del x  # Only the counter controls termination.
    return step < 10

  with self.cached_session() as sess:
    start = array_ops.placeholder(dtypes.int32, [])
    with self.test_scope():
      outputs = xla.while_loop([start, 42], loop_cond, loop_body)
    final = sess.run(outputs, {start: 0})
    self.assertAllClose(final, [10, 7], rtol=1e-3)