Example #1
    def CellGrad(theta, state0, inputs, extras, dstate1):
        """Default gradient function for cell_fn."""
        state1, extras = cell_fn(theta, state0, inputs)
        assert isinstance(state1, py_utils.NestedMap), ('%s' % state1)
        assert isinstance(extras, py_utils.NestedMap), ('%s' % extras)
        # NOTE: The default grad function recomputes the forward
        # function and does not take advantage of 'extras' returned by
        # the forward function.

        # Assert that if captured inputs were given, they match the actual
        # tensors passed to the function we are compiled into. Must come after
        # the call to cell_fn, which does the capture.
        _AssertSameTensors(function.get_extra_inputs(),
                           implicit_captures.Flatten())

        # Extract the internal captured tensor placeholders within the Defun
        # we are running in.
        (captured, ) = _Pack(function.get_extra_args(), [implicit_captures])
        ys = _Flatten([state1])
        xs = _Flatten([theta, state0, inputs, captured])
        grad_ys = _Flatten([dstate1])
        grads = tf.gradients(ys=ys, xs=xs, grad_ys=grad_ys)
        return _ConvertNoneGradientToZeros(
            [theta, state0, inputs, captured],
            _Pack(grads, [theta, state0, inputs, captured]))
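The default gradient above hinges on the `grad_ys` argument of `tf.gradients`, which seeds the backward pass with the incoming gradient `dstate1` rather than an implicit all-ones gradient. A minimal sketch of that mechanism in TF1-style graph mode (all names here are illustrative, not part of the Lingvo API):

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    x = tf.placeholder(tf.float32, [3])
    y = 2.0 * x
    dy = tf.fill([3], 10.0)  # plays the role of dstate1
    dx, = tf.gradients(ys=[y], xs=[x], grad_ys=[dy])
    # dx evaluates to [20., 20., 20.]: dy/dx == 2, scaled by grad_ys.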
Example #2
 def _FlatOutputProcessor(inputs):
     """Returns a flattened list of 'processor(inputs)'."""
     output, bucketing_key = processor(inputs)
     if isinstance(output, list):
         assert output
         assert all(isinstance(x, tf.Tensor)
                    for x in output), '{}'.format(output)
     else:
         assert isinstance(output, py_utils.NestedMap), '{}'.format(output)
         assert output
         assert all(isinstance(x, tf.Tensor)
                    for x in output.Flatten()), '{}'.format(
                        output.DebugString())
     bucketing_key = tf.cast(bucketing_key, tf.int32)
     tf.logging.debug('Processor outputs=%s bucketing_key=%s', output,
                      bucketing_key)
     output_tmpl.values = output
     flat_output_tmpl = output_tmpl.Flatten()
     tf.logging.debug('Processor flat outputs=%s', flat_output_tmpl)
     tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                      function.get_extra_inputs(),
                      function.get_extra_args(), function.get_extra_vars())
     assert not function.get_extra_args(), (
         'fn {} is not pure: extra_args={}'.format(
             processor, function.get_extra_args()))
     return flat_output_tmpl + [bucketing_key]
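For context, the `processor` being wrapped here must return a pair `(output, bucketing_key)`, where `output` is a non-empty list of `tf.Tensor` (or a `py_utils.NestedMap` of tensors) and `bucketing_key` is castable to int32. A hypothetical minimal processor satisfying those asserts (the decoding is purely illustrative):

    import tensorflow as tf

    def MyProcessor(record):
      """Illustrative only: decode a raw byte record into tensors."""
      ids = tf.io.decode_raw(record, tf.int32)
      # Bucket by sequence length; the wrapper casts this to int32.
      return [ids], tf.size(ids)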
Example #3
        def Bak(*args):
            """Backward step."""
            (theta, state0, inputs, extras, d_state1) = _Pack(args, bak_sig)
            (dtheta, dstate0, dinputs,
             dcaptures) = self._cell_grad(theta, state0, inputs, extras,
                                          d_state1)
            _AssertIsCompatible(dtheta, self._theta)
            _AssertIsCompatible(dstate0, self._state)
            _AssertIsCompatible(dinputs, self._inputs)
            if dcaptures is None:
                # NOTE: Custom gradient fns can return None if they do not support
                # captured tensors. The return value is reserved for the future when
                # that may be supported.
                dcaptures = _EmptyLike(self._implicit_captures)
            _AssertIsCompatible(dcaptures, self._implicit_captures)

            # Make sure this function didn't capture anything different than the
            # cell_fn when reflected on at the beginning. Must come after the call
            # to cell_grad() which adds to the captured list.
            _AssertSameTensors(function.get_extra_inputs(),
                               self._implicit_captures.Flatten())

            (captured, ) = _Pack(function.get_extra_args(),
                                 [self._implicit_captures])
            return _Flatten(
                _ConvertNoneGradientToZeros(
                    [theta, state0, inputs, captured],
                    [dtheta, dstate0, dinputs, dcaptures]))
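The `_ConvertNoneGradientToZeros` call matters because `tf.gradients` returns `None` for any entry of `xs` that `ys` does not depend on, while the backward loop needs a dense, all-tensor structure. A small TF1-style sketch of where those `None`s come from (names illustrative):

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    x = tf.placeholder(tf.float32, [2])
    z = tf.placeholder(tf.float32, [2])  # never used by y
    y = 3.0 * x
    gx, gz = tf.gradients(ys=[y], xs=[x, z])
    assert gz is None  # would be replaced by tf.zeros_like(z)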
Example #4
    def Backward(*args):
      """Backward pass for the recurrent net."""
      # theta, state0, inputs are Forward's inputs.
      # acc_state is the accumulated 1st output of Forward.
      # acc_extras is the accumulated 2nd output of Forward.
      # d_acc_state is the gradient for acc_state.
      # d_state1 is the gradient for the final state computed by Forward.
      (theta, state0, inputs, acc_state, acc_extras, d_acc_state,
       d_state1) = _Pack(args, backward_sig)

      # Accumulators for gradients.
      d_theta = _EmptyLike(theta)
      d_inputs = _EmptyLike(inputs)
      d_captured = _EmptyLike(self._implicit_captures)

      # The sequence length.
      pad_begin, pad_end = _SeqPaddingLength(inputs)
      start = _SeqLenDim(inputs) - pad_end - 1

      if py_utils.use_tpu():
        dev_t = tf.cast(start, tf.int32)
      else:
        dev_t = tf.cast(start, tf.int64)
      run = functional_ops.For(
          start=start,
          limit=pad_begin - 1,
          delta=-1,
          inputs=[dev_t] + _Flatten([
              theta,
              state0,
              inputs,
              acc_state,
              acc_extras,
              d_theta,
              d_state1,
              d_inputs,
              d_acc_state,
              d_captured,
          ]),
          body=BackwardLoopBody,
          rewrite_with_while=compiled)

      (theta, state0, inputs, acc_state, acc_extras, d_theta, d_state0,
       d_inputs, d_acc_state, d_captured) = _Pack(run[1:], bakloop_sig)

      # Make sure this function didn't capture anything different than the
      # cell_fn when reflected on at the beginning. Must come after the
      # call to BackwardLoopBody, which adds to the captured list.
      _AssertSameTensors(function.get_extra_inputs(),
                         self._implicit_captures.Flatten())

      if self._unused_acc_state:
        # Match the shape of gradient of the init_state.
        d_state0 = self._state.Transform(tf.zeros_like)
      return _Flatten([d_theta, d_state0, d_inputs, acc_extras, d_captured])
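`functional_ops.For` with `delta=-1` and `limit=pad_begin - 1` runs the body for `t = start, start - 1, ..., pad_begin` (the limit is exclusive), accumulating gradients across time in reverse. A plain-Python toy of the same bounds and accumulation pattern (the per-step values are made up):

    per_step_grad = [0.1, 0.2, 0.3, 0.4]  # hypothetical d_theta contributions
    pad_begin = 0
    start = len(per_step_grad) - 1        # last unpadded time step
    d_theta = 0.0
    for t in range(start, pad_begin - 1, -1):  # t = 3, 2, 1, 0
        d_theta += per_step_grad[t]
    assert abs(d_theta - 1.0) < 1e-9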
Example #5
 def _FlatOutputProcessor(inputs):
     """Returns a flattened list of 'processor(inputs)'."""
     outputs = processor(inputs)
     tf.logging.debug('Processor outputs=%s', outputs)
     assert len(outputs) > 1, outputs
     # Add 'outputs' as a list so that each element will be flattened.
     output_tmpl.values = list(outputs)
     flat_outputs = output_tmpl.Flatten()
     tf.logging.debug('Processor flat outputs=%s', flat_outputs)
     tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                      function.get_extra_inputs(),
                      function.get_extra_args(), function.get_extra_vars())
     assert not function.get_extra_args(), (
         'fn {} is not pure: extra_args={}'.format(
             processor, function.get_extra_args()))
     return flat_outputs
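Unlike Examples #2 and #6, this variant has no bucketing key: the processor returns a plain tuple of at least two values, which is flattened via `output_tmpl`. A hypothetical processor matching this older contract (the decoding is illustrative):

    import tensorflow as tf

    def MyProcessor(record):
      """Illustrative only: returns a tuple of at least two tensors."""
      ids = tf.io.decode_raw(record, tf.int32)
      return ids, tf.size(ids)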
Example #6
 def _FlatOutputProcessor(source_id, record):
     """Returns a flattened list of 'processor(inputs)'."""
     processor_spec = tf_inspect.getargspec(processor)
     tf.logging.debug('GenericInput.processor.argspec=%s', processor_spec)
     processor_args = set(processor_spec.args) - set(['self'])
     if len(processor_args) == 1:
         output, bucketing_key = processor(record)
     elif processor_args == set(['source_id', 'record']):
         output, bucketing_key = processor(source_id=source_id,
                                           record=record)
     else:
         raise ValueError(
             'GenericInput: processor should take either a single arg '
             'or two args named as "source_id" and "record". '
             'Actual: %s' % processor_args)
     if isinstance(output, list):
         assert output
         assert all(isinstance(x, tf.Tensor)
                    for x in output), '{}'.format(output)
     else:
         assert isinstance(output, py_utils.NestedMap), '{}'.format(output)
         assert output
         assert all(isinstance(x, tf.Tensor)
                    for x in output.Flatten()), '{}'.format(
                        output.DebugString())
     bucketing_key = tf.cast(bucketing_key, tf.int32)
     tf.logging.debug('Processor outputs=%s bucketing_key=%s', output,
                      bucketing_key)
     output_tmpl.out_values = output
     flat_output_tmpl = output_tmpl.Flatten()
     tf.logging.debug('Processor flat outputs=%s', flat_output_tmpl)
     tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                      function.get_extra_inputs(),
                      function.get_extra_args(), function.get_extra_vars())
     assert not function.get_extra_args(), (
         'fn {} is not pure: extra_args={}'.format(
             processor, function.get_extra_args()))
     return flat_output_tmpl + [bucketing_key]
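The argspec dispatch above accepts exactly two processor signatures: a single positional argument, or the two keyword-named arguments `source_id` and `record`. Hypothetical bodies for both (the decoding is illustrative):

    import tensorflow as tf

    def BySingleArg(record):
      ids = tf.io.decode_raw(record, tf.int32)
      return [ids], tf.size(ids)

    def BySourceId(source_id, record):
      ids = tf.io.decode_raw(record, tf.int32)
      return [ids, source_id], tf.size(ids)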
Example #7
        def BackwardLoopBody(t, limit, *args):
            """Backward loop body function."""
            (
                theta,
                orig_state0,
                inputs,
                acc_state,
                acc_extras,
                # End of forward params
                d_theta,
                d_state1,
                d_inputs,
                d_acc_state,
                d_captured) = _Pack(args, bakloop_sig)

            # The input recurrent state for time step t is the previous time
            # step's output, or the original state0 at time step 0.
            state_from_acc = _Index(acc_state,
                                    tf.maximum(tf.constant(0, t.dtype), t - 1))
            state0 = functional_ops.If(tf.equal(t, tf.constant(0, t.dtype)),
                                       _Flatten([state_from_acc, orig_state0]),
                                       ReturnOrigState0, ReturnAccState)
            state0 = orig_state0.Pack(state0)

            # The external inputs for time step t.
            inputs_t = _Index(inputs, t)
            # The extras for time step t.
            extras_t = _Index(acc_extras, t)

            d_state1 = _Add(_Index(d_acc_state, t), d_state1)
            (d_theta_t, d_state0, d_inputs_t, d_captured_t) = _Pack(
                Bak(*_Flatten([theta, state0, inputs_t, extras_t, d_state1])),
                [
                    self._theta, self._state, self._inputs,
                    self._implicit_captures
                ])

            if self._unused_acc_state:
                # XLA IF op requires the same shape for if and else branches.
                d_state0 = d_state0.Transform(tf.reduce_sum)
            d_theta = _Add(d_theta, d_theta_t)
            d_inputs = _Update(d_inputs, d_inputs_t, t)
            d_captured = _Add(d_captured, d_captured_t)

            # Make sure this function didn't capture anything different than the
            # cell_fn when reflected on at the beginning. Must come after the call
            # to Bak() which adds to the captured list.
            _AssertSameTensors(function.get_extra_inputs(),
                               self._implicit_captures.Flatten())

            return [tf.subtract(t, 1), limit] + _Flatten([
                theta,
                orig_state0,
                inputs,
                acc_state,
                acc_extras,
                # End of forward params
                d_theta,
                d_state0,
                d_inputs,
                d_acc_state,
                d_captured,
            ])
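The subtlest step above is reconstructing `state0` for step `t`: it is the accumulated output of step `t - 1`, except at `t == 0`, where it is the original initial state. An eager-mode analogue of that selection with `tf.cond` (function and argument names are illustrative):

    import tensorflow as tf

    def SelectState0(t, acc_state, orig_state0):
      """acc_state: [T, ...] stacked per-step states; orig_state0: initial state."""
      prev = tf.gather(acc_state, tf.maximum(0, t - 1))
      return tf.cond(tf.equal(t, 0), lambda: orig_state0, lambda: prev)

The `tf.maximum` clamp mirrors the snippet: even when the `t == 0` branch wins, the gather index must stay in range, because the functional `If` op materializes the inputs of both branches.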