Example #1
0
 def Fwd(*args):
     """Forward step of the cell.

     Packs `args` according to `fwd_sig`, runs `self._cell_fn`, asserts that
     `function.get_extra_args()` is empty, checks the results against the
     `self._state` / `self._extras` templates and returns them flattened.
     """
     theta, state0, inputs = _Pack(args, fwd_sig)
     next_state, fwd_extras = self._cell_fn(theta, state0, inputs)
     # cell_fn must not leave extra function args behind (reported as impure).
     assert not function.get_extra_args(), (
         'cell_fn is not pure with extra args: %s.' %
         (function.get_extra_args()))
     for value, template in ((next_state, self._state),
                             (fwd_extras, self._extras)):
         _AssertIsCompatible(value, template)
     return _Flatten([next_state, fwd_extras])
Example #2
0
 def Fwd(*args):
   """Forward step of the cell.

   Packs `args` per `fwd_sig`, runs `self._cell_fn`, asserts that
   `function.get_extra_args()` is empty, checks the results against the
   `self._state` / `self._extras` templates and returns them flattened.
   """
   packed = _Pack(args, fwd_sig)
   theta, state0, inputs = packed
   state1, extras = self._cell_fn(theta, state0, inputs)
   # cell_fn must not leave extra function args behind (reported as impure).
   assert not function.get_extra_args(), (
       'cell_fn is not pure with extra args: %s.' %
       (function.get_extra_args()))
   _AssertIsCompatible(state1, self._state)
   _AssertIsCompatible(extras, self._extras)
   outputs = [state1, extras]
   return _Flatten(outputs)
 def Bak(*args):
   """Backward step.

   Packs `args` per `bak_sig`, runs `self._cell_grad`, asserts that
   `function.get_extra_args()` is empty, checks each gradient against its
   template, and returns the flattened gradients after passing them through
   `_ConvertNoneGradientToZeros`.
   """
   theta, state0, inputs, extras, d_state1 = _Pack(args, bak_sig)
   grads = self._cell_grad(theta, state0, inputs, extras, d_state1)
   dtheta, dstate0, dinputs = grads
   # cell_grad must not leave extra function args behind (reported as impure).
   assert not function.get_extra_args(), (
       'cell_grad is not pure with extra args: %s.' %
       (function.get_extra_args()))
   _AssertIsCompatible(dtheta, self._theta)
   _AssertIsCompatible(dstate0, self._state)
   _AssertIsCompatible(dinputs, self._inputs)
   primals = [theta, state0, inputs]
   cotangents = [dtheta, dstate0, dinputs]
   return _Flatten(_ConvertNoneGradientToZeros(primals, cotangents))
Example #4
0
    def CellGrad(theta, state0, inputs, extras, dstate1):
        """Default gradient function for cell_fn.

        Recomputes the forward pass with cell_fn and differentiates state1
        with respect to theta, state0, inputs and the implicitly captured
        tensors.

        Args:
          theta: forward-pass theta, passed to cell_fn.
          state0: forward-pass state, passed to cell_fn.
          inputs: forward-pass inputs, passed to cell_fn.
          extras: forward-pass extras; unused here (see NOTE below).
          dstate1: gradient flowing into state1, used as `grad_ys`.

        Returns:
          The gradients packed like [theta, state0, inputs, captured], passed
          through _ConvertNoneGradientToZeros.
        """
        # The recomputed extras get their own name so the `extras` argument
        # is not silently shadowed.
        state1, fwd_extras = cell_fn(theta, state0, inputs)
        # Wrap the value in a 1-tuple: with a bare `'%s' % x`, a tuple-valued
        # x raises TypeError instead of producing the assertion message.
        assert isinstance(state1, py_utils.NestedMap), ('%s' % (state1,))
        assert isinstance(fwd_extras, py_utils.NestedMap), (
            '%s' % (fwd_extras,))
        # NOTE: The default grad function recomputes the forward
        # function and does not take advantage of 'extras' returned by
        # the forward function.

        # Assert that if captured inputs were given, they match the actual
        # tensors passed to the function we are compiled into. Must come after
        # the call to cell_fn, which does the capture.
        _AssertSameTensors(function.get_extra_inputs(),
                           implicit_captures.Flatten())

        # Extract the internal captured tensor placeholders within the Defun
        # we are running in.
        (captured, ) = _Pack(function.get_extra_args(), [implicit_captures])
        primals = [theta, state0, inputs, captured]
        ys = _Flatten([state1])
        xs = _Flatten(primals)
        grad_ys = _Flatten([dstate1])
        grads = tf.gradients(ys=ys, xs=xs, grad_ys=grad_ys)
        return _ConvertNoneGradientToZeros(primals, _Pack(grads, primals))
Example #5
0
        def Bak(*args):
            """Backward step.

            Runs self._cell_grad on the packed args, validates the structure
            of every gradient (including those for implicitly captured
            tensors), and returns the flattened gradients after passing them
            through _ConvertNoneGradientToZeros.
            """
            theta, state0, inputs, extras, d_state1 = _Pack(args, bak_sig)
            dtheta, dstate0, dinputs, dcaptures = self._cell_grad(
                theta, state0, inputs, extras, d_state1)
            for grad, template in ((dtheta, self._theta),
                                   (dstate0, self._state),
                                   (dinputs, self._inputs)):
                _AssertIsCompatible(grad, template)
            # NOTE: Custom gradient fns can return None if they do not support
            # captured tensors. The return value is reserved for the future
            # when that may be supported.
            dcaptures = (_EmptyLike(self._implicit_captures)
                         if dcaptures is None else dcaptures)
            _AssertIsCompatible(dcaptures, self._implicit_captures)

            # cell_grad() adds to the captured list, so this check has to run
            # after it: the captures must be the same ones the cell_fn had
            # when reflected on at the beginning.
            _AssertSameTensors(function.get_extra_inputs(),
                               self._implicit_captures.Flatten())

            captured, = _Pack(function.get_extra_args(),
                              [self._implicit_captures])
            primals = [theta, state0, inputs, captured]
            cotangents = [dtheta, dstate0, dinputs, dcaptures]
            return _Flatten(_ConvertNoneGradientToZeros(primals, cotangents))
Example #6
0
 def _FlatOutputProcessor(inputs):
     """Returns a flattened list of 'processor(inputs)'.

     Asserts that processor returned more than one output and that
     `function.get_extra_args()` is empty. The outputs are stored on
     `output_tmpl.values` and flattened with `output_tmpl.Flatten()`.
     """
     result = processor(inputs)
     tf.logging.debug('Processor outputs=%s', result)
     assert len(result) > 1, result
     # Assign a list (not the raw result) so each element is flattened.
     output_tmpl.values = list(result)
     flattened = output_tmpl.Flatten()
     tf.logging.debug('Processor flat outputs=%s', flattened)
     tf.logging.debug(
         'extra_inputs=%s extra_args=%s extra_vars=%s',
         function.get_extra_inputs(), function.get_extra_args(),
         function.get_extra_vars())
     # The processor must not leave extra function args behind.
     assert not function.get_extra_args(), (
         'fns {} is not pure: extra_args={}'.format(
             processor, function.get_extra_args()))
     return flattened
Example #7
0
 def Grad(x, y0):
   """Builds a loss from x and y0 and differentiates it w.r.t. the extra args.

   Returns (loss, dw, db). The gradients are taken w.r.t.
   `function.get_extra_args()` (presumably the captured w and b — confirm),
   and `function.get_extra_vars()` is appended to `cvars`.
   """
   y = Model(x) if use_forward_func else _Model(x)
   loss = tf.reduce_mean(tf.reduce_sum(y0 * tf.log(y), 1), 0)
   captured_args = function.get_extra_args()
   dw, db = tf.gradients(loss, captured_args)
   cvars.extend(function.get_extra_vars())
   return loss, dw, db
Example #8
0
 def WhileBody(i, n, start, delta, *args):
     """A While wrapper for forbody that handles loop-carried captured inputs.

     Calls forbody with `start + i * delta`, normalizes its result to a tuple,
     and returns `(i + 1, n, start, delta)` followed by the body outputs and
     `function.get_extra_args()` (read after the body call).
     """
     body_out = forbody(start + i * delta, *args)
     if isinstance(body_out, ops.Operation):
         # Nullary functions return an Operation. Normal functions can't do
         # this because their return values are converted to Tensors.
         body_out = ()
     elif isinstance(body_out, ops.Tensor):
         # Unary functions return a single Tensor value.
         body_out = (body_out, )
     captured = tuple(function.get_extra_args())
     loop_state = (i + 1, n, start, delta)
     return loop_state + tuple(body_out) + captured
Example #9
0
 def Grad(x, y0):
   """Builds a loss from x and y0 and differentiates it w.r.t. the extra args.

   Asserts the two extra args have shapes [64, 64] and [64], returns
   (loss, dw, db), and appends `function.get_extra_vars()` to `cvars`.
   """
   model_fn = Model if use_forward_func else _Model
   y = model_fn(x)
   loss = tf.reduce_mean(tf.reduce_sum(y0 * tf.log(y), 1), 0)
   arg_w, arg_b = function.get_extra_args()
   for arg, expected_dims in ((arg_w, [64, 64]), (arg_b, [64])):
     self.assertEqual(arg.get_shape(), tf.TensorShape(expected_dims))
   dw, db = tf.gradients(loss, [arg_w, arg_b])
   cvars.extend(function.get_extra_vars())
   return loss, dw, db
Example #10
0
 def WhileBody(i, n, start, delta, *args):
   """A While wrapper for forbody that handles loop-carried captured inputs.

   Evaluates forbody at `start + i * delta` and returns the advanced loop
   counters, the body outputs as a tuple, and `function.get_extra_args()`.
   """
   index_value = start + i * delta
   result = forbody(index_value, *args)
   if isinstance(result, ops.Operation):
     # Nullary functions return an Operation. Normal functions can't do this
     # because their return values are converted to Tensors.
     result = ()
   elif isinstance(result, ops.Tensor):
     # Unary functions return a single Tensor value.
     result = (result,)
   extra_args = tuple(function.get_extra_args())
   counters = (i + 1, n, start, delta)
   return counters + tuple(result) + extra_args
Example #11
0
 def Grad(x, y0):
     """Builds a loss from x and y0 and differentiates it w.r.t. the extra args.

     Checks the two extra args have shapes [64, 64] and [64], returns
     (loss, dw, db), and appends `function.get_extra_vars()` to `cvars`.
     """
     y = Model(x) if use_forward_func else _Model(x)
     per_example = tf.reduce_sum(y0 * tf.log(y), 1)
     loss = tf.reduce_mean(per_example, 0)
     arg_w, arg_b = function.get_extra_args()
     self.assertEqual(arg_w.get_shape(), tf.TensorShape([64, 64]))
     self.assertEqual(arg_b.get_shape(), tf.TensorShape([64]))
     grads = tf.gradients(loss, [arg_w, arg_b])
     dw, db = grads
     cvars.extend(function.get_extra_vars())
     return loss, dw, db
Example #12
0
 def _FlatOutputProcessor(source_id, record):
     """Returns a flattened list of 'processor(inputs)'.

     processor is called either as `processor(record)` (if it takes a single
     non-self arg) or as `processor(source_id=..., record=...)`. Its output
     must be a non-empty list, or a non-empty NestedMap, of tf.Tensor values.
     The result is the flattened output followed by the bucketing key cast to
     tf.int32.

     Raises:
       ValueError: if processor's arguments match neither accepted form.
     """
     spec = tf_inspect.getargspec(processor)
     tf.logging.debug('GenericInput.processor.argspec=%s', spec)
     arg_names = set(spec.args) - {'self'}
     if len(arg_names) == 1:
         result, bucketing_key = processor(record)
     elif arg_names == {'source_id', 'record'}:
         result, bucketing_key = processor(source_id=source_id, record=record)
     else:
         raise ValueError(
             'GenericInput: processor should take either a single arg '
             'or two args named as "source_id" and "record". '
             'Actual: %s' % arg_names)

     if isinstance(result, list):
         assert result
         assert all(isinstance(t, tf.Tensor)
                    for t in result), '{}'.format(result)
     else:
         assert isinstance(result, py_utils.NestedMap), '{}'.format(result)
         assert result
         tensors = result.Flatten()
         assert all(isinstance(t, tf.Tensor)
                    for t in tensors), '{}'.format(result.DebugString())

     bucketing_key = tf.cast(bucketing_key, tf.int32)
     tf.logging.debug('Processor outputs=%s bucketing_key=%s', result,
                      bucketing_key)
     output_tmpl.out_values = result
     flat_outputs = output_tmpl.Flatten()
     tf.logging.debug('Processor flat outputs=%s', flat_outputs)
     tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                      function.get_extra_inputs(),
                      function.get_extra_args(), function.get_extra_vars())
     # The processor must not leave extra function args behind.
     assert not function.get_extra_args(), (
         'fns {} is not pure: extra_args={}'.format(
             processor, function.get_extra_args()))
     return flat_outputs + [bucketing_key]
Example #13
0
 def BodyWrapper(*args):
     """A wrapper for body that handles loop-carried captured inputs.

     Returns body's outputs as a tuple with `function.get_extra_args()`
     (read after the body call) appended.
     """
     result = body(*args)
     captured = tuple(function.get_extra_args())
     if isinstance(result, ops.Operation):
         # Nullary functions return an Operation. Normal functions can't do
         # this because their return values are converted to Tensors.
         outputs = ()
     elif isinstance(result, tuple):
         # N-ary functions return a tuple of Tensors.
         outputs = result
     else:
         # Unary functions return a single Tensor value.
         outputs = (result, )
     return outputs + captured
Example #14
0
 def Wrapper(*args):
   """A wrapper that handles loop-carried captured inputs.

   Calls func and returns its outputs as a tuple followed by the tuple of
   `function.get_extra_args()` read after the call.
   """
   outcome = func(*args)
   captured = tuple(function.get_extra_args())
   if isinstance(outcome, ops.Operation):
     # Nullary functions return an Operation. Normal functions can't do this
     # because their return values are converted to Tensors.
     return captured
   if isinstance(outcome, tuple):
     # N-ary functions return a tuple of Tensors.
     return outcome + captured
   # Unary functions return a single Tensor value.
   return (outcome,) + captured