def testForError(self):
  """Checks the run-time errors raised by malformed `functional_ops.For` loops.

  Three cases are exercised, each expected to raise
  `errors.InvalidArgumentError` when the first loop output is evaluated:
    * `start` passed as a list (`[0]`) instead of a scalar,
    * a `delta` of -1 while counting from 0 towards 10,
    * a loop body that returns two values while only one value is carried.
  """

  @function.Defun(dtypes.int32, dtypes.float32)
  def Foo(i, v):
    return math_ops.cast(i, dtypes.float32) + v

  @function.Defun(dtypes.int32, dtypes.float32)
  def ReturnsTooManyArgs(unused_i, v):
    return v, v

  def _EvalFirstOutput(start, delta, body):
    # Every case shares the limit (10) and the single carried float input;
    # only start, delta and the body differ between them.
    return functional_ops.For(start, 10, delta, [0.0], body)[0].eval()

  invalid_argument = errors.InvalidArgumentError
  with self.test_session(use_gpu=True):
    # `start` is not a scalar.
    with self.assertRaisesRegexp(invalid_argument, "must be a scalar"):
      _EvalFirstOutput([0], 1, Foo)

    # Negative step while the limit lies above the start.
    with self.assertRaisesRegexp(invalid_argument,
                                 "Invalid start/limit/delta"):
      _EvalFirstOutput(0, -1, Foo)

    # The body returns more values than the loop carries.
    with self.assertRaisesRegexp(
        invalid_argument,
        "For loop body returned 2 arguments. Expected: 1"):
      _EvalFirstOutput(0, 1, ReturnsTooManyArgs)
def testForCapturedInputs(self):
  """Runs `functional_ops.For` with loop bodies that capture a variable.

  Bodies carrying zero, one and two loop values are each built twice: once
  with `rewrite_with_while=True` and once with `rewrite_with_while=False`.
  Every body reads the captured variable `v`, which is initialized to 1.0.

  The loops run from 1 up to a limit of 10 with step 1. The expected values
  below correspond to nine iterations (n = 1..9):
    * unary:  sum(n + v) = 45 + 9 = 54
    * binary: (54, sum(v) = 9)
  """
  v = variables.Variable(1.0)

  @function.Defun(dtypes.int32)
  def TestNullary(n):
    v + math_ops.cast(n, dtypes.float32)  # pylint: disable=expression-not-assigned

  @function.Defun(dtypes.int32, dtypes.float32)
  def TestUnary(n, x):
    return x + math_ops.cast(n, dtypes.float32) + v

  @function.Defun(dtypes.int32, dtypes.float32, dtypes.float32)
  def TestBinary(n, x, x2):
    return x + math_ops.cast(n, dtypes.float32) + v, x2 + v

  for rewrite_with_while in (True, False):
    # The while-loop rewrite is only exercised without GPU.
    use_gpu = not rewrite_with_while
    with self.test_session(use_gpu=use_gpu) as sess:
      result_nullary = functional_ops.For(
          1, 10, 1, [], TestNullary,
          rewrite_with_while=rewrite_with_while)
      result_unary = functional_ops.For(
          1, 10, 1, [0.], TestUnary,
          rewrite_with_while=rewrite_with_while)
      result_binary = functional_ops.For(
          1, 10, 1, [0., 0.], TestBinary,
          rewrite_with_while=rewrite_with_while)
      sess.run(variables.global_variables_initializer())

      # unittest assertions are used instead of bare `assert` statements so
      # the checks still run under `python -O` and report the offending
      # values on failure.
      self.assertFalse(result_nullary)
      # The nullary variant doesn't return anything so we can't easily run it.
      # As a total hack, fetch the operation by name and run it.
      sess.run(ops.get_default_graph().get_operation_by_name(
          "While" if rewrite_with_while else "For"))

      self.assertEqual(1, len(result_unary))
      self.assertEqual([54.0], sess.run(result_unary))

      self.assertEqual(2, len(result_binary))
      self.assertEqual([54.0, 9.0], sess.run(result_binary))
def testForWithWhileNaming(self):
  """Checks the function names registered by the while-loop rewrite of `For`.

  A body explicitly named "TestBody" is looped over with
  `rewrite_with_while=True` in a fresh graph. The graph's function library
  must then contain "TestBody" itself plus "TestBody_Cond" and
  "TestBody_Body".
  """
  g = ops.Graph()
  with g.as_default():

    @function.Defun(dtypes.int32, dtypes.float32, func_name="TestBody")
    def TestBody(n, x):
      return x + math_ops.cast(n, dtypes.float32)

    # Only graph construction matters here; the loop output is never run.
    _ = functional_ops.For(
        1, 21, 1, [0.], TestBody, rewrite_with_while=True)[0]

  names = [func.signature.name for func in g.as_graph_def().library.function]
  # assertIn prints the contents of `names` on failure, unlike
  # assertTrue(x in names), which only reports "False is not true".
  self.assertIn("TestBody", names)
  self.assertIn("TestBody_Cond", names)
  self.assertIn("TestBody_Body", names)
def _tfMLP(self, xval, wsval, bsval, rewrite_with_while):
  """Evaluates a tanh MLP by iterating `functional_ops.For` over `wsval`.

  For each index i in [0, wsval.shape[0]) the loop body replaces the carried
  activation `a` with tanh(matmul(a, ws[i, :]) + bs[i, :]); `ws` and `bs` are
  passed through unchanged.

  Args:
    xval: Initial value of the carried activation.
    wsval: Weights, indexed along their first axis by the loop counter.
    bsval: Biases, indexed along their first axis by the loop counter.
    rewrite_with_while: Forwarded to `functional_ops.For`.

  Returns:
    The evaluated first loop output, i.e. the final activation.
  """
  # On GPU, don't rewrite using a while loop.
  with self.test_session(use_gpu=not rewrite_with_while):

    @function.Defun(dtypes.int32, *[dtypes.float64] * 3)
    def MLP(i, a, ws, bs):
      pre_activation = math_ops.matmul(a, ws[i, :]) + bs[i, :]
      return math_ops.tanh(pre_activation), ws, bs

    loop_outputs = functional_ops.For(
        0, wsval.shape[0], 1, [xval, wsval, bsval], MLP,
        rewrite_with_while=rewrite_with_while)
    final_activation = loop_outputs[0]
    return final_activation.eval()
def Backward(*args):
  """Backward pass for the recurrent net.

  `args` is unpacked according to `backward_sig` into, in order:
    theta, state0, inputs: Forward's inputs.
    max_input_length: Used to derive the loop's start and limit.
    acc_state: The accumulated 1st output of Forward.
    acc_extras: The accumulated 2nd output of Forward.
    d_acc_state: The gradient for acc_state.
    d_state1: The gradient for the final state computed by Forward.

  Returns:
    The flattened [d_theta, d_state0, d_inputs, d_max_input_length,
    acc_extras], where d_max_input_length is a constant 0 of
    max_input_length's dtype.
  """
  (theta, state0, inputs, max_input_length, acc_state, acc_extras,
   d_acc_state, d_state1) = _Pack(args, backward_sig)

  # Accumulators for gradients.
  d_theta = _EmptyLike(theta)
  d_inputs = _EmptyLike(inputs)

  slen_dim = _SeqLenDim(inputs)

  # Loop backwards. Note the loop's limit is open-ended, so goes through
  # t=0.
  if self._aligned_end:
    t = slen_dim - 1
  else:
    t = max_input_length - 1

  # The loop-carried copy of t is int32 on TPU and int64 otherwise.
  dev_t = math_ops.cast(t, dtypes.int32 if use_tpu else dtypes.int64)

  if self._aligned_end:
    limit = slen_dim - max_input_length - 1
  else:
    limit = -1

  loop_inputs = [dev_t] + _Flatten([
      theta, state0, inputs, acc_state, acc_extras, d_theta, d_state1,
      d_inputs, d_acc_state
  ])
  run = functional_ops.For(
      start=t,
      limit=limit,
      delta=-1,
      inputs=loop_inputs,
      body=BackwardLoopBody,
      rewrite_with_while=compiled)

  # run[0] is the carried dev_t; the remainder follows bakloop_sig. Only the
  # gradients and acc_extras are needed from here on.
  (_, _, _, _, acc_extras, d_theta, d_state0, d_inputs, _) = _Pack(
      run[1:], bakloop_sig)

  d_max_input_length = array_ops.constant(0, dtype=max_input_length.dtype)
  return _Flatten(
      [d_theta, d_state0, d_inputs, d_max_input_length, acc_extras])
def Forward(*args):
  """Forward pass of the recurrent net.

  `args` is unpacked according to `forward_sig` into theta, state0, inputs,
  max_input_length and extras. `ForwardLoopBody` is then iterated from 0 up
  to max_input_length with step 1.

  Returns:
    The flattened [acc_state, state1, acc_extras] taken from the loop's
    outputs.
  """
  theta, state0, inputs, max_input_length, extras = _Pack(args, forward_sig)

  slen_dim = _SeqLenDim(inputs)

  # Creates accumulators for state0 and extras.
  acc_state = _EmptyAcc(slen_dim, state0)
  acc_extras = _EmptyAcc(slen_dim, extras)

  # Loop-carried counter, starting at 0.
  dev_t = array_ops.constant(0, dtype=dev_t_type)
  loop_inputs = [dev_t] + _Flatten(
      [theta, state0, inputs, acc_state, acc_extras])
  loop_outputs = functional_ops.For(
      start=0,
      limit=max_input_length,
      delta=1,
      inputs=loop_inputs,
      body=ForwardLoopBody,
      rewrite_with_while=compiled)

  # loop_outputs[0] is the carried dev_t; the remainder mirrors the order of
  # the values flattened into the loop inputs above.
  output_sig = [
      self._theta, self._state, self._inputs, self._state, self._extras
  ]
  _, state1, _, acc_state, acc_extras = _Pack(loop_outputs[1:], output_sig)

  return _Flatten([acc_state, state1, acc_extras])