TF1-style snippets exercising functional_ops.For. The methods below belong to
a test.TestCase subclass and assume the usual tensorflow.python imports,
roughly:

    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import errors
    from tensorflow.python.framework import function
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import functional_ops
    from tensorflow.python.ops import math_ops
    from tensorflow.python.ops import variables

def testForError(self):

    @function.Defun(dtypes.int32, dtypes.float32)
    def Foo(i, v):
      return math_ops.cast(i, dtypes.float32) + v

    @function.Defun(dtypes.int32, dtypes.float32)
    def ReturnsTooManyArgs(unused_i, v):
      return v, v

    with self.test_session(use_gpu=True):
      with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                   "must be a scalar"):
        functional_ops.For([0], 10, 1, [0.0], Foo)[0].eval()
      with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                   "Invalid start/limit/delta"):
        functional_ops.For(0, 10, -1, [0.0], Foo)[0].eval()
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          "For loop body returned 2 arguments. Expected: 1"):
        functional_ops.For(0, 10, 1, [0.0], ReturnsTooManyArgs)[0].eval()
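
A minimal sketch of a valid call, for contrast with the error cases above
(the method name and values are illustrative, not from the original suite):
start, limit and delta must be scalars with delta stepping from start toward
limit, and the body must return exactly one tensor per loop-var.

def testForSketch(self):

    @function.Defun(dtypes.int32, dtypes.float32)
    def AddIndex(i, v):
      # The first argument is the loop index; the rest mirror the loop-vars.
      return v + math_ops.cast(i, dtypes.float32)

    with self.test_session():
      result = functional_ops.For(0, 10, 1, [0.0], AddIndex)[0]
      self.assertEqual(45.0, result.eval())  # 0 + 1 + ... + 9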
Example 2
  def testForCapturedInputs(self):
    v = variables.Variable(1.0)

    @function.Defun(dtypes.int32)
    def TestNullary(n):
      v + math_ops.cast(n, dtypes.float32)  # pylint: disable=expression-not-assigned

    @function.Defun(dtypes.int32, dtypes.float32)
    def TestUnary(n, x):
      return x + math_ops.cast(n, dtypes.float32) + v

    @function.Defun(dtypes.int32, dtypes.float32, dtypes.float32)
    def TestBinary(n, x, x2):
      return x + math_ops.cast(n, dtypes.float32) + v, x2 + v

    for rewrite_with_while in (True, False):
      # On GPU, don't rewrite using a while loop; the rewrite runs on CPU.
      use_gpu = not rewrite_with_while
      with self.test_session(use_gpu=use_gpu) as sess:
        result_nullary = functional_ops.For(
            1, 10, 1, [], TestNullary,
            rewrite_with_while=rewrite_with_while)
        result_unary = functional_ops.For(
            1, 10, 1, [0.], TestUnary,
            rewrite_with_while=rewrite_with_while)
        result_binary = functional_ops.For(
            1, 10, 1, [0., 0.], TestBinary,
            rewrite_with_while=rewrite_with_while)
        sess.run(variables.global_variables_initializer())
        assert not result_nullary
        # The nullary variant doesn't return anything so we can't easily run it.
        # As a total hack, fetch the operation by name and run it.
        sess.run(ops.get_default_graph().get_operation_by_name(
            "While" if rewrite_with_while else "For"))
        assert len(result_unary) == 1
        # 9 iterations (n = 1..9): x accumulates n + v with v == 1.0,
        # so the result is 45 + 9 = 54.
        self.assertEqual([54.0], sess.run(result_unary))
        assert len(result_binary) == 2
        # The second loop-var accumulates just v each iteration: 9 * 1.0 = 9.0.
        self.assertEqual([54.0, 9.0], sess.run(result_binary))
Example 3
  def testForWithWhileNaming(self):
    g = ops.Graph()
    with g.as_default():

      @function.Defun(dtypes.int32, dtypes.float32, func_name="TestBody")
      def TestBody(n, x):
        return x + math_ops.cast(n, dtypes.float32)

      _ = functional_ops.For(
          1, 21, 1, [0.], TestBody, rewrite_with_while=True)[0]

    names = [func.signature.name for func in g.as_graph_def().library.function]
    self.assertIn("TestBody", names)
    self.assertIn("TestBody_Cond", names)
    self.assertIn("TestBody_Body", names)
Example 4
    def _tfMLP(self, xval, wsval, bsval, rewrite_with_while):
        # On GPU, don't rewrite using a while loop.
        use_gpu = not rewrite_with_while
        with self.test_session(use_gpu=use_gpu):

            @function.Defun(dtypes.int32, *[dtypes.float64] * 3)
            def MLP(i, a, ws, bs):
                # Layer i: ws and bs hold stacked per-layer weights and biases.
                a = math_ops.tanh(math_ops.matmul(a, ws[i, :]) + bs[i, :])
                return a, ws, bs

            ret = functional_ops.For(0,
                                     wsval.shape[0],
                                     1, [xval, wsval, bsval],
                                     MLP,
                                     rewrite_with_while=rewrite_with_while)[0]

            return ret.eval()
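
The Backward and Forward functions below appear to be fragments of a
recurrent-net implementation built on functional_ops.For. Both are nested
functions: they rely on helpers such as _Pack, _Flatten, _EmptyLike, _EmptyAcc
and _SeqLenDim, and on names like use_tpu, compiled, backward_sig and
BackwardLoopBody from their enclosing scope, so they are not runnable
standalone.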
        def Backward(*args):
            """Backward pass for the recurrent net."""
            # theta, state0, inputs are Forward's inputs.
            # acc_state is the accumulated 1st output of Forward.
            # acc_extras is the accumulated 2nd output of Forward.
            # d_acc_state is the gradient for acc_state.
            # d_state1 is the gradient for the final state computed by Forward.
            (theta, state0, inputs, max_input_length, acc_state, acc_extras,
             d_acc_state, d_state1) = _Pack(args, backward_sig)

            # Accumulators for gradients.
            d_theta = _EmptyLike(theta)
            d_inputs = _EmptyLike(inputs)

            slen_dim = _SeqLenDim(inputs)

            # Loop backwards. The limit is exclusive, so the loop body still
            # runs for t = 0.
            t = slen_dim - 1 if self._aligned_end else max_input_length - 1
            dev_t = math_ops.cast(t, dtypes.int32 if use_tpu else dtypes.int64)
            limit = slen_dim - max_input_length - 1 if self._aligned_end else -1
            run = functional_ops.For(
                start=t,
                limit=limit,
                delta=-1,
                inputs=[dev_t] + _Flatten([
                    theta, state0, inputs, acc_state, acc_extras, d_theta,
                    d_state1, d_inputs, d_acc_state
                ]),
                body=BackwardLoopBody,
                rewrite_with_while=compiled)

            (theta, state0, inputs, acc_state, acc_extras, d_theta, d_state0,
             d_inputs, d_acc_state) = _Pack(run[1:], bakloop_sig)

            d_max_input_length = array_ops.constant(
                0, dtype=max_input_length.dtype)
            return _Flatten(
                [d_theta, d_state0, d_inputs, d_max_input_length, acc_extras])
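
Companion fragment: the forward pass whose accumulated outputs (acc_state,
acc_extras) the Backward function above consumes. It likewise assumes
dev_t_type, forward_sig and ForwardLoopBody from the enclosing scope.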
    def Forward(*args):
      """Forward pass of the recurrent net."""
      theta, state0, inputs, max_input_length, extras = _Pack(args, forward_sig)

      slen_dim = _SeqLenDim(inputs)

      # Creates accumulators for state0 and extras.
      acc_state = _EmptyAcc(slen_dim, state0)
      acc_extras = _EmptyAcc(slen_dim, extras)

      dev_t = array_ops.constant(0, dtype=dev_t_type)
      run = functional_ops.For(
          start=0,
          limit=max_input_length,
          delta=1,
          inputs=[dev_t] + _Flatten(
              [theta, state0, inputs, acc_state, acc_extras]),
          body=ForwardLoopBody,
          rewrite_with_while=compiled)
      _, state1, _, acc_state, acc_extras = _Pack(
          run[1:],
          [self._theta, self._state, self._inputs, self._state, self._extras])

      return _Flatten([acc_state, state1, acc_extras])