Exemplo n.º 1
0
    def testCountingLoopHandrolledC64(self):
        """Checks an XLA while loop carrying an int32 counter and a complex64 sum.

        Each iteration adds 1 to the counter and 1.5 + 2j to the running sum
        while the counter is below 10. Both loop outputs are verified for a
        run starting at 0 (ten iterations) and for a run starting at 10 (zero
        iterations, so the fed values must come back unchanged).
        """
        # Define a function for the loop body
        @function.Defun(dtypes.int32, dtypes.complex64)
        def loop_body(step, rsum):
            step_out = step + constant_op.constant(1, dtype=dtypes.int32)
            sum_out = rsum + constant_op.constant(1.5 + 2j,
                                                  dtype=dtypes.complex64)
            return step_out, sum_out

        # Define a function for the loop condition
        @function.Defun(dtypes.int32, dtypes.complex64)
        def loop_cond(step, rsum):
            del rsum  # Termination depends only on the counter.
            return step < 10

        with self.cached_session() as sess:
            init_index = array_ops.placeholder(dtypes.int32, [])
            init_sum = array_ops.placeholder(dtypes.complex64, [])
            with self.test_scope():
                loop_outputs = xla.while_loop([init_index, init_sum],
                                              loop_cond, loop_body)

            # Starting from 0 the body runs 10 times:
            # counter -> 10, sum -> 10 * (1.5 + 2j) = 15 + 20j.
            result = sess.run(loop_outputs, {init_index: 0, init_sum: 0.0})
            self.assertEqual(result[0], 10)
            self.assertAllClose(result[1], np.complex64(15 + 20j), rtol=1e-3)

            # Starting from 10 the condition is false immediately, so both
            # outputs must equal the fed inputs.
            no_iters_result = sess.run(loop_outputs, {
                init_index: 10,
                init_sum: 0.0
            })
            self.assertEqual(no_iters_result[0], 10)
            self.assertAllClose(no_iters_result[1], np.complex64(0), rtol=1e-3)
Exemplo n.º 2
0
  def testCountingLoopHandrolledC64(self):
    """Checks an XLA while loop carrying an int32 counter and a complex64 sum.

    Each iteration adds 1 to the counter and 1.5 + 2j to the running sum
    while the counter is below 10. Both loop outputs are verified for a run
    starting at 0 (ten iterations) and for a run starting at 10 (zero
    iterations, so the fed values must come back unchanged).
    """
    # Define a function for the loop body
    @function.Defun(dtypes.int32, dtypes.complex64)
    def loop_body(step, rsum):
      step_out = step + constant_op.constant(1, dtype=dtypes.int32)
      sum_out = rsum + constant_op.constant(1.5 + 2j, dtype=dtypes.complex64)
      return step_out, sum_out

    # Define a function for the loop condition
    @function.Defun(dtypes.int32, dtypes.complex64)
    def loop_cond(step, rsum):
      del rsum  # Termination depends only on the counter.
      return step < 10

    with self.cached_session() as sess:
      init_index = array_ops.placeholder(dtypes.int32, [])
      init_sum = array_ops.placeholder(dtypes.complex64, [])
      with self.test_scope():
        loop_outputs = xla.while_loop([init_index, init_sum], loop_cond,
                                      loop_body)

      # Starting from 0 the body runs 10 times:
      # counter -> 10, sum -> 10 * (1.5 + 2j) = 15 + 20j.
      result = sess.run(loop_outputs, {init_index: 0, init_sum: 0.0})
      self.assertEqual(result[0], 10)
      self.assertAllClose(result[1], np.complex64(15 + 20j), rtol=1e-3)

      # Starting from 10 the condition is false immediately, so both outputs
      # must equal the fed inputs.
      no_iters_result = sess.run(loop_outputs, {init_index: 10, init_sum: 0.0})
      self.assertEqual(no_iters_result[0], 10)
      self.assertAllClose(no_iters_result[1], np.complex64(0), rtol=1e-3)
Exemplo n.º 3
0
  def testSingletonLoopHandrolled(self):
    """Runs an XLA while loop over a single int32 counter starting at 0.

    The body increments the counter and the condition stops at 10, so the
    single loop output is expected to be [10].
    """

    @function.Defun(dtypes.int32)
    def loop_body(step):
      # Add one to the only loop-carried value.
      increment = constant_op.constant(1, dtype=dtypes.int32)
      return step + increment

    @function.Defun(dtypes.int32)
    def loop_cond(step):
      # Keep iterating while the counter is strictly below 10.
      return step < 10

    with self.session() as test_sess:
      start = array_ops.placeholder(dtypes.int32, [])
      with self.test_scope():
        outputs = xla.while_loop([start], loop_cond, loop_body)

      final_values = test_sess.run(outputs, {start: 0})
      self.assertAllClose(final_values, [10], rtol=1e-3)
Exemplo n.º 4
0
  def testSingletonLoopHandrolled(self):
    """Runs an XLA while loop over a single int32 counter starting at 0.

    The body increments the counter and the condition stops at 10, so the
    single loop output is expected to be [10].
    """

    @function.Defun(dtypes.int32)
    def loop_body(step):
      # Add one to the only loop-carried value.
      increment = constant_op.constant(1, dtype=dtypes.int32)
      return step + increment

    @function.Defun(dtypes.int32)
    def loop_cond(step):
      # Keep iterating while the counter is strictly below 10.
      return step < 10

    with self.cached_session() as test_sess:
      start = array_ops.placeholder(dtypes.int32, [])
      with self.test_scope():
        outputs = xla.while_loop([start], loop_cond, loop_body)

      final_values = test_sess.run(outputs, {start: 0})
      self.assertAllClose(final_values, [10], rtol=1e-3)
Exemplo n.º 5
0
  def testLoopWithConstantOutput(self):
    """Runs an XLA while loop whose second carried value becomes a constant.

    The loop starts from (counter, 42). Each iteration increments the counter
    and replaces the second value with 7, so starting the counter at 0 the
    outputs are expected to be [10, 7].
    """

    @function.Defun(dtypes.int32, dtypes.int32)
    def loop_body(step, x):
      del x  # The incoming second value is discarded by the body.
      incremented = step + constant_op.constant(1, dtype=dtypes.int32)
      return (incremented, 7)

    @function.Defun(dtypes.int32, dtypes.int32)
    def loop_cond(step, x):
      del x  # Only the counter decides termination.
      return step < 10

    with self.session() as test_sess:
      counter_start = array_ops.placeholder(dtypes.int32, [])
      with self.test_scope():
        outputs = xla.while_loop([counter_start, 42], loop_cond, loop_body)

      final_values = test_sess.run(outputs, {counter_start: 0})
      self.assertAllClose(final_values, [10, 7], rtol=1e-3)
Exemplo n.º 6
0
  def testLoopWithConstantOutput(self):
    """Runs an XLA while loop whose second carried value becomes a constant.

    The loop starts from (counter, 42). Each iteration increments the counter
    and replaces the second value with 7, so starting the counter at 0 the
    outputs are expected to be [10, 7].
    """

    @function.Defun(dtypes.int32, dtypes.int32)
    def loop_body(step, x):
      del x  # The incoming second value is discarded by the body.
      incremented = step + constant_op.constant(1, dtype=dtypes.int32)
      return (incremented, 7)

    @function.Defun(dtypes.int32, dtypes.int32)
    def loop_cond(step, x):
      del x  # Only the counter decides termination.
      return step < 10

    with self.cached_session() as test_sess:
      counter_start = array_ops.placeholder(dtypes.int32, [])
      with self.test_scope():
        outputs = xla.while_loop([counter_start, 42], loop_cond, loop_body)

      final_values = test_sess.run(outputs, {counter_start: 0})
      self.assertAllClose(final_values, [10, 7], rtol=1e-3)