Code example #1
  def testWhileError(self):
    for use_gpu in (True, False):
      with ops.Graph().as_default() as g:

        @function.Defun(*[dtypes.float32] * 2)
        def Cond(n, unused_x):
          return n > 0

        @function.Defun(*[dtypes.float32] * 2)
        def CondReturnsTooManyArgs(n, x):
          return n > 0, x

        @function.Defun(*[dtypes.float32] * 2)
        def Body(n, x):
          return n - 1, x + n

        @function.Defun(*[dtypes.float32] * 2)
        def BodyReturnsTooManyArgs(n, x):
          return n - 1, x + n, x

        with self.session(graph=g, use_gpu=use_gpu):
          with self.assertRaisesRegexp(
              errors.InvalidArgumentError,
              "Expected a single scalar.*got 2 tensors."):
            functional_ops.While([5., 0.], CondReturnsTooManyArgs,
                                 Body)[0].eval()
          with self.assertRaisesRegexp(
              errors.InvalidArgumentError,
              "While loop body returned 3 arguments. Expected: 2"):
            functional_ops.While([5., 0.], Cond,
                                 BodyReturnsTooManyArgs)[0].eval()
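
For contrast with the error cases tested above, here is a minimal hedged sketch (not part of the original test) of a well-formed call: the condition returns a single scalar and the body returns exactly as many tensors as there are loop variables. A TF 1.x graph-mode session is assumed.

import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import functional_ops

with ops.Graph().as_default():

  @function.Defun(*[dtypes.float32] * 2)
  def Cond(n, unused_x):
    return n > 0          # a single scalar, as the error message above requires

  @function.Defun(*[dtypes.float32] * 2)
  def Body(n, x):
    return n - 1, x + n   # two outputs for the two loop variables

  outputs = functional_ops.While([5., 0.], Cond, Body)
  with tf.Session() as sess:
    print(sess.run(outputs))  # [0.0, 15.0]: 5 + 4 + 3 + 2 + 1
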
Code example #2
    def testWhileError(self):
        @function.Defun(*[dtypes.float32] * 2)
        def Cond(n, unused_x):
            return n > 0

        @function.Defun(*[dtypes.float32] * 2)
        def CondReturnsTooManyArgs(n, x):
            return n > 0, x

        @function.Defun(*[dtypes.float32] * 2)
        def Body(n, x):
            return n - 1, x + n

        @function.Defun(*[dtypes.float32] * 2)
        def BodyReturnsTooManyArgs(n, x):
            return n - 1, x + n, x

        # TODO(b/65752372): Set `use_gpu=False` because
        # `functional_ops.While()` does not reliably work on GPU (apparently
        # because the result of evaluating the condition may be in device
        # memory, but it is read on the host).
        with self.test_session(use_gpu=False):
            with self.assertRaisesRegexp(
                    errors.InvalidArgumentError,
                    "Expected a single scalar.*got 2 tensors."):
                functional_ops.While([5., 0.], CondReturnsTooManyArgs,
                                     Body)[0].eval()
            with self.assertRaisesRegexp(
                    errors.InvalidArgumentError,
                    "While loop body returned 3 arguments. Expected: 2"):
                functional_ops.While([5., 0.], Cond,
                                     BodyReturnsTooManyArgs)[0].eval()
Code example #3
    def Run(n, fetch_by_name):
      for use_gpu in (True, False):
        with ops.Graph().as_default() as g:

          @function.Defun(*[dtypes.float32] * 2)
          def Cond(n, unused_x):
            return n > 0

          @function.Defun(*[dtypes.float32] * 2)
          def Body(n, x):
            return n - 1, x + n

          # outputs: [0, n*(n+1)/2]
          outputs = functional_ops.While([n, 0.], Cond, Body, name="my_while")

          # `outputs` is the list of output tensors of the While op. We
          # arbitrarily choose the 0th tensor to get the While op and set the
          # lowering attribute on it.
          outputs[0].op._set_attr("_lower_using_switch_merge",
                                  attr_value_pb2.AttrValue(b=True))
          if not fetch_by_name:
            fetch = outputs[1]
          else:
            fetch = "my_while:1"
        with self.session(graph=g, use_gpu=use_gpu) as sess:
          return sess.run(fetch)
Code example #4
    def testWhileInMultipleSubgraphs(self):

        for use_gpu in (True, False):
            with ops.Graph().as_default() as g:

                @function.Defun(*[dtypes.float32] * 2)
                def Cond(n, x):  # pylint: disable=unused-argument
                    return n > 0

                @function.Defun(*[dtypes.float32] * 2)
                def Body(n, x):
                    return n - 1, x + n

                with self.test_session(graph=g, use_gpu=use_gpu) as sess:
                    n = array_ops.placeholder(dtypes.float32)
                    _, result = functional_ops.While([n, 0.], Cond, Body)
                    c = constant_op.constant(37.)

                    self.assertAllEqual(210.,
                                        sess.run(result, feed_dict={n: 20.}))
                    self.assertAllEqual(5050.,
                                        sess.run(result, feed_dict={n: 100.}))
                    # Test that the result is the same when we run a different subgraph.
                    self.assertAllEqual(
                        5050.,
                        sess.run([result, c], feed_dict={n: 100.})[0])
Code example #5
    def testWhileCapturedInputs(self):
        for use_gpu in (True, False):
            with ops.Graph().as_default() as g:
                v = variables.Variable(1.0)

                def TestCond(n, *args):
                    del args
                    return n < 10

                @function.Defun(*[dtypes.float32] * 2)
                def TestUnary(n, x):
                    return math_ops.add(n, 1), x + n + v

                @function.Defun(*[dtypes.float32] * 3)
                def TestBinary(n, x, x2):
                    return math_ops.add(n, 1), x + n + v, x2 + v

                with self.session(graph=g, use_gpu=use_gpu) as sess:
                    result_unary = functional_ops.While(
                        [1.0, 0.],
                        function.Defun(*[dtypes.float32] * 2)(TestCond),
                        TestUnary)
                    result_binary = functional_ops.While(
                        [1.0, 0., 0.],
                        function.Defun(*[dtypes.float32] * 3)(TestCond),
                        TestBinary)
                    self.evaluate(variables.global_variables_initializer())
                    assert len(result_unary) == 2
                    self.assertEqual([10.0, 54.0], self.evaluate(result_unary))
                    assert len(result_binary) == 3
                    self.assertEqual([10.0, 54.0, 9.0],
                                     self.evaluate(result_binary))

                    def TestCondCapture(n, *args):
                        del args
                        return math_ops.to_float(n) + v < 10

                    with self.assertRaises(ValueError):
                        _ = functional_ops.While(
                            [1],
                            function.Defun(dtypes.int32)(TestCondCapture),
                            function.Defun(dtypes.int32,
                                           dtypes.float32)(TestUnary))
Code example #6
File: recurrent.py  Project: zhoudaqing/lingvo
        def Backward(start, *args):
            """Backward pass for the recurrent net."""
            # theta, state0, inputs are Forward's inputs.
            # acc_state is the accumulated 1st output of Forward.
            # acc_extras is the accumulated 2nd output of Forward.
            # d_acc_state is the gradient for acc_state.
            # d_state1 is the gradient for the final state computed by Forward.
            (theta, state0, inputs, acc_state, acc_extras, d_acc_state,
             d_state1) = _Pack(args, backward_sig)

            # Accumulators for gradients.
            d_theta = _EmptyLike(theta)
            d_inputs = _EmptyLike(inputs)
            d_captured = _EmptyLike(self._implicit_captures)

            # The sequence length.
            pad_begin, _ = _SeqPaddingLength(inputs)
            limit = pad_begin

            if compiled:
                limit = tf.to_int32(limit)
            else:
                limit = tf.to_int64(limit)
            run = functional_ops.While([start - 1, limit] + _Flatten([
                theta,
                state0,
                inputs,
                acc_state,
                acc_extras,
                d_theta,
                d_state1,
                d_inputs,
                d_acc_state,
                d_captured,
            ]),
                                       cond=BackwardLoopCond,
                                       body=BackwardLoopBody)

            (theta, state0, inputs, acc_state, acc_extras, d_theta, d_state0,
             d_inputs, d_acc_state, d_captured) = _Pack(run[2:], bakloop_sig)

            # Make sure this function didn't capture anything different than the
            # cell_fn when reflected on at the beginning. Must come after the
            # call to BackwardLoopBody, which adds to the captured list.
            _AssertSameTensors(function.get_extra_inputs(),
                               self._implicit_captures.Flatten())

            if self._unused_acc_state:
                # Match the shape of gradient of the init_state.
                d_state0 = self._state.Transform(tf.zeros_like)
            return _Flatten(
                [d_theta, d_state0, d_inputs, acc_extras, d_captured])
Code example #7
    def testWhileInMultipleSubgraphs(self):
        @function.Defun(*[dtypes.float32] * 2)
        def Cond(n, x):  # pylint: disable=unused-argument
            return n > 0

        @function.Defun(*[dtypes.float32] * 2)
        def Body(n, x):
            return n - 1, x + n

        # TODO(b/65752372): Set `use_gpu=False` because
        # `functional_ops.While()` does not reliably work on GPU (apparently
        # because the result of evaluating the condition may be in device
        # memory, but it is read on the host).
        with self.test_session(use_gpu=False) as sess:
            n = array_ops.placeholder(dtypes.float32)
            _, result = functional_ops.While([n, 0.], Cond, Body)
            c = constant_op.constant(37.)

            self.assertAllEqual(210., sess.run(result, feed_dict={n: 20.}))
            self.assertAllEqual(5050., sess.run(result, feed_dict={n: 100.}))
            # Test that the result is the same when we run a different subgraph.
            self.assertAllEqual(5050.,
                                sess.run([result, c], feed_dict={n: 100.})[0])
Code example #8
File: recurrent.py  Project: zhoudaqing/lingvo
        def Forward(*args):
            """Forward pass of the recurrent net."""
            theta, state0, inputs, extras = _Pack(args, forward_sig)

            # The sequence length.
            pad_begin, pad_end = _SeqPaddingLength(inputs)
            slen_dim = _SeqLenDim(inputs)
            limit = slen_dim - pad_end

            # Creates accumulators for state0 and extras.
            if self._unused_acc_state:
                acc_state = _EmptyWithFixShape([slen_dim], state0)
            else:
                acc_state = _EmptyAcc(slen_dim, state0)
            acc_extras = _EmptyAcc(slen_dim, extras)

            if compiled:
                t = tf.to_int32(pad_begin)
                limit = tf.to_int32(limit)
            else:
                t = tf.to_int64(pad_begin)
                limit = tf.to_int64(limit)

            run = functional_ops.While(
                [t, limit] +
                _Flatten([theta, state0, inputs, acc_state, acc_extras]),
                cond=ForwardLoopCond,
                body=ForwardLoopBody)
            t = run[0]
            _, state1, _, acc_state, acc_extras = _Pack(
                run[2:], [
                    self._theta, self._state, self._inputs, self._state,
                    self._extras
                ])

            return [t] + _Flatten([acc_state, state1, acc_extras])
Code example #9
File: ctc_ops.py  Project: flavz27/master_PA
def _scan(fn,
          elems,
          initial,
          reverse=False,
          inclusive=False,
          final_only=False):
    """Repeatedly applies callable `fn` to a sequence of elements.

  Implemented by functional_ops.While, tpu friendly, no gradient.

  This is similar to functional_ops.scan but significantly faster on tpu/gpu
  for the forward backward use case.

  Examples:
    scan(lambda a, e: a + e, [1.0, 2.0, 3.0], 1.0) => [2.0, 4.0, 7.0]

    Multiple accumulators:
      scan(lambda a, e: (a[0] + e, a[1] * e), [1.0, 2.0, 3.0], (0.0, 1.0))

    Multiple inputs:
      scan(lambda a, e: a + (e[0] * e[1]), (elems1, elems2), 0.0)

  Args:
    fn: callable, fn(accumulators, element) return new accumulator values. The
      (possibly nested) sequence of accumulators is the same as `initial` and
      the return value must have the same structure.
    elems: A (possibly nested) tensor which will be unpacked along the first
      dimension. The resulting slices will be the second argument to fn. The
      first dimension of all nested input tensors must be the same.
    initial: A tensor or (possibly nested) sequence of tensors with initial
      values for the accumulators.
    reverse: (optional) True enables scan and output elems in reverse order.
    inclusive: (optional) True includes the initial accumulator values in the
      output. Length of output will be len(elem sequence) + 1. Not meaningful if
      final_only is True.
    final_only: (optional) When True, return only the final accumulated values,
      not the concatenation of accumulated values for each input.

  Returns:
    A (possibly nested) sequence of tensors with the results of applying fn
    to tensors unpacked from elems and previous accumulator values.
  """

    flat_elems = [ops.convert_to_tensor(x) for x in nest.flatten(elems)]
    num_elems = array_ops.shape(flat_elems[0])[0]
    pack_elems = lambda x: nest.pack_sequence_as(structure=elems,
                                                 flat_sequence=x)
    flat_initial = [ops.convert_to_tensor(x) for x in nest.flatten(initial)]
    pack = lambda x: nest.pack_sequence_as(structure=initial, flat_sequence=x)
    accum_dtypes = [x.dtype for x in flat_initial]
    num_accums = len(flat_initial)

    # Types for counter, [outputs], [accumulators] loop arguments.
    if final_only:
        loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes
    else:
        loop_dtypes = [dtypes.int32, dtypes.int32
                       ] + accum_dtypes + accum_dtypes

    # TODO(tombagby): Update to tfe.defun
    def cond(i, num_elems, *args):
        del args
        return i >= 0 if reverse else i < num_elems

    # The loop *args are [output tensors] + [accumulator tensors] which must
    # be paired. Each output corresponds to one accumulator.
    def body(i, num_elems, *args):
        """Loop body."""
        i.set_shape([])
        if final_only:
            accum = args
        else:
            out, accum = args[:num_accums], args[num_accums:]
        slices = [array_ops.gather(e, i) for e in flat_elems]
        accum = fn(pack(accum), pack_elems(slices))
        flat_accum = nest.flatten(accum)
        if final_only:
            new_out = []
        else:
            update_i = i + 1 if inclusive and not reverse else i
            new_out = [
                inplace_ops.alias_inplace_update(x, update_i, y)
                for x, y in zip(out, flat_accum)
            ]
        i = i - 1 if reverse else i + 1
        return [i, num_elems] + new_out + flat_accum

    init_i = (array_ops.shape(flat_elems[0])[0] -
              1 if reverse else constant_op.constant(0, dtype=dtypes.int32))
    outputs = []
    if not final_only:
        num_outputs = array_ops.shape(
            flat_elems[0])[0] + (1 if inclusive else 0)
        for initial_accum in flat_initial:
            out_shape = array_ops.concat(
                [[num_outputs], array_ops.shape(initial_accum)], 0)
            out = inplace_ops.empty(out_shape,
                                    dtype=initial_accum.dtype,
                                    init=True)
            if inclusive:
                out = inplace_ops.alias_inplace_add(
                    out, init_i + (1 if reverse else 0), initial_accum)
            outputs.append(out)
    loop_in = [init_i, num_elems] + outputs + flat_initial
    hostmem = [
        i for i, x in enumerate(loop_in)
        if x.dtype.base_dtype in (dtypes.int32, dtypes.int64)
    ]

    if context.executing_eagerly():
        loop_results = loop_in
        while cond(*loop_results):
            loop_results = body(*loop_results)
    else:
        # TODO(tombagby): Update to while_v2.
        cond = function.Defun(*loop_dtypes)(cond)
        body = function.Defun(*loop_dtypes)(body)
        loop_results = functional_ops.While(loop_in,
                                            cond,
                                            body,
                                            hostmem=hostmem)
    out = loop_results[2:num_accums + 2]
    return pack(out)
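
As a rough usage sketch for the _scan helper above, reproducing the first docstring example in a TF 1.x graph session. Hedged: _scan is a module-private helper, so the import path below is an assumption for illustration; adjust it to wherever the helper is actually defined.

import tensorflow as tf
from tensorflow.python.ops.ctc_ops import _scan  # assumed module path

with tf.Graph().as_default():
    out = _scan(lambda a, e: a + e, tf.constant([1.0, 2.0, 3.0]),
                tf.constant(1.0))
    with tf.Session() as sess:
        print(sess.run(out))  # per the docstring: [2. 4. 7.]
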
Code example #10
    def Run(sess, n):
      return sess.run(functional_ops.While([n, 0.], Cond, Body))[1]
Code example #11
import time

import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import functional_ops


def time_this(f):
    """Runs `f` once and prints wall-clock timing."""
    start = time.time()
    print('Start: %ss' % start)
    f()
    end = time.time()
    print('End: %ss' % end)
    print('Elapsed: %ss' % (end - start))


with ops.Graph().as_default() as g:
    x_init = tf.constant(0.2, shape=[64, 512])
    w = tf.constant(0.2, shape=[512, 512])
    b = tf.constant(0.2, shape=[512])

    @function.Defun(dtypes.int32, *[dtypes.float32] * 3)
    def Cond(i, x, w, b):
        return i < 1000

    @function.Defun(dtypes.int32, *[dtypes.float32] * 3)
    def Body(i, x, w, b):
        return i + 1, tf.matmul(x, w) + b, w, b

    y = functional_ops.While([0, x_init, w, b], Cond, Body)

    def main():
        with tf.Session(graph=g) as sess:
            sess.run(y)

    time_this(main)

    tf.train.write_graph(g, '.', 'python_graph.pbtxt')
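
A short hedged follow-up to the snippet above: tf.train.write_graph defaults to text format, so the exported file (same name as in the call above) can be parsed back into a GraphDef for inspection.

from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2

graph_def = graph_pb2.GraphDef()
with open('python_graph.pbtxt') as f:
    text_format.Parse(f.read(), graph_def)
print('Nodes in the exported graph: %d' % len(graph_def.node))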