def CellGrad(theta, state0, inputs, extras, dstate1):
  """Default gradient function for cell_fn."""
  state1, extras = cell_fn(theta, state0, inputs)
  assert isinstance(state1, py_utils.NestedMap), ('%s' % state1)
  assert isinstance(extras, py_utils.NestedMap), ('%s' % extras)
  # NOTE: The default grad function recomputes the forward
  # function and does not take advantage of 'extras' returned by
  # the forward function.

  # Assert that if captured inputs were given, they match the actual
  # tensors passed to the function we are compiled into. Must come after
  # the call to cell_fn, which does the capture.
  _AssertSameTensors(function.get_extra_inputs(), implicit_captures.Flatten())

  # Extract the internal captured tensor placeholders within the Defun
  # we are running in.
  (captured,) = _Pack(function.get_extra_args(), [implicit_captures])
  ys = _Flatten([state1])
  xs = _Flatten([theta, state0, inputs, captured])
  grad_ys = _Flatten([dstate1])
  grads = tf.gradients(ys=ys, xs=xs, grad_ys=grad_ys)
  return _ConvertNoneGradientToZeros(
      [theta, state0, inputs, captured],
      _Pack(grads, [theta, state0, inputs, captured]))

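# A minimal standalone sketch of the default-gradient idea above, outside the
# Defun/capture machinery: recompute a simple forward step, let tf.gradients
# backprop the incoming d_state1 through it, and zero-fill None gradients the
# way _ConvertNoneGradientToZeros does. The affine cell here is hypothetical
# and for illustration only; assumes TF1-style graph mode via tf.compat.v1.
def _SimpleCellGradSketch():
  import tensorflow.compat.v1 as tf1
  tf1.disable_eager_execution()  # Sketch assumes graph mode.
  w = tf1.get_variable('w', shape=[4, 4])
  state0 = tf1.placeholder(tf1.float32, [None, 4])
  x = tf1.placeholder(tf1.float32, [None, 4])
  d_state1 = tf1.placeholder(tf1.float32, [None, 4])
  # Recomputed forward step (mirrors calling cell_fn again).
  state1 = tf1.tanh(tf1.matmul(state0, w) + x)
  grads = tf1.gradients(ys=[state1], xs=[w, state0, x], grad_ys=[d_state1])
  # tf.gradients yields None for inputs the output does not depend on.
  return [g if g is not None else tf1.zeros_like(v)
          for g, v in zip(grads, [w, state0, x])]
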
def _FlatOutputProcessor(inputs):
  """Returns a flattened list of 'processor(inputs)'."""
  output, bucketing_key = processor(inputs)
  if isinstance(output, list):
    assert output
    assert all(isinstance(x, tf.Tensor) for x in output), '{}'.format(output)
  else:
    assert isinstance(output, py_utils.NestedMap), '{}'.format(output)
    assert output
    assert all(isinstance(x, tf.Tensor)
               for x in output.Flatten()), '{}'.format(output.DebugString())
  bucketing_key = tf.to_int32(bucketing_key)
  tf.logging.debug('Processor outputs=%s bucketing_key=%s', output,
                   bucketing_key)
  output_tmpl.values = output
  flat_output_tmpl = output_tmpl.Flatten()
  tf.logging.debug('Processor flat outputs=%s', flat_output_tmpl)
  tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                   function.get_extra_inputs(), function.get_extra_args(),
                   function.get_extra_vars())
  assert not function.get_extra_args(), (
      'fns {} is not pure: extra_args={}'.format(processor,
                                                 function.get_extra_args()))
  return flat_output_tmpl + [bucketing_key]

def Bak(*args):
  """Backward step."""
  (theta, state0, inputs, extras, d_state1) = _Pack(args, bak_sig)
  (dtheta, dstate0, dinputs, dcaptures) = self._cell_grad(
      theta, state0, inputs, extras, d_state1)
  _AssertIsCompatible(dtheta, self._theta)
  _AssertIsCompatible(dstate0, self._state)
  _AssertIsCompatible(dinputs, self._inputs)
  if dcaptures is None:
    # NOTE: Custom gradient fns can return None if they do not support
    # captured tensors. The return value is reserved for the future when
    # that may be supported.
    dcaptures = _EmptyLike(self._implicit_captures)
  _AssertIsCompatible(dcaptures, self._implicit_captures)

  # Make sure this function didn't capture anything different than the
  # cell_fn when reflected on at the beginning. Must come after the call
  # to cell_grad() which adds to the captured list.
  _AssertSameTensors(function.get_extra_inputs(),
                     self._implicit_captures.Flatten())

  (captured,) = _Pack(function.get_extra_args(), [self._implicit_captures])
  return _Flatten(
      _ConvertNoneGradientToZeros([theta, state0, inputs, captured],
                                  [dtheta, dstate0, dinputs, dcaptures]))

def Backward(*args):
  """Backward pass for the recurrent net."""
  # theta, state0, inputs are Forward's inputs.
  # acc_state is the accumulated 1st output of Forward.
  # acc_extras is the accumulated 2nd output of Forward.
  # d_acc_state is the gradient for acc_state.
  # d_state1 is the gradient for the final state computed by Forward.
  (theta, state0, inputs, acc_state, acc_extras, d_acc_state,
   d_state1) = _Pack(args, backward_sig)

  # Accumulators for gradients.
  d_theta = _EmptyLike(theta)
  d_inputs = _EmptyLike(inputs)
  d_captured = _EmptyLike(self._implicit_captures)

  # The sequence length.
  pad_begin, pad_end = _SeqPaddingLength(inputs)
  start = _SeqLenDim(inputs) - pad_end - 1

  if py_utils.use_tpu():
    dev_t = tf.to_int32(start)
  else:
    dev_t = tf.to_int64(start)
  run = functional_ops.For(
      start=start,
      limit=pad_begin - 1,
      delta=-1,
      inputs=[dev_t] + _Flatten([
          theta,
          state0,
          inputs,
          acc_state,
          acc_extras,
          d_theta,
          d_state1,
          d_inputs,
          d_acc_state,
          d_captured,
      ]),
      body=BackwardLoopBody,
      rewrite_with_while=compiled)
  (theta, state0, inputs, acc_state, acc_extras, d_theta, d_state0, d_inputs,
   d_acc_state, d_captured) = _Pack(run[1:], bakloop_sig)

  # Make sure this function didn't capture anything different than the
  # cell_fn when reflected on at the beginning. Must come after the
  # call to BackwardLoopBody, which adds to the captured list.
  _AssertSameTensors(function.get_extra_inputs(),
                     self._implicit_captures.Flatten())

  if self._unused_acc_state:
    # Match the shape of gradient of the init_state.
    d_state0 = self._state.Transform(tf.zeros_like)
  return _Flatten([d_theta, d_state0, d_inputs, acc_extras, d_captured])

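# For intuition, the reversed functional_ops.For above is equivalent to the
# plain Python loop below over an unrolled sequence. step_grad_fn and seq_len
# are hypothetical stand-ins for Bak and the padded sequence length; the real
# code fuses the loop into a single op so XLA can compile it.
def _BackwardSketch(step_grad_fn, theta, state0, inputs, acc_state,
                    d_acc_state, d_state1, seq_len):
  d_theta = [0.0] * len(theta)
  d_inputs = [None] * seq_len
  for t in reversed(range(seq_len)):
    # State fed into step t: previous step's output, or state0 at t == 0.
    prev_state = state0 if t == 0 else acc_state[t - 1]
    # Fold in the gradient flowing into this step's accumulated output.
    d_state1 = [a + b for a, b in zip(d_acc_state[t], d_state1)]
    d_theta_t, d_state1, d_inputs[t] = step_grad_fn(
        theta, prev_state, inputs[t], d_state1)
    d_theta = [a + b for a, b in zip(d_theta, d_theta_t)]
  # After the loop, d_state1 holds the gradient w.r.t. the initial state.
  return d_theta, d_state1, d_inputs
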
def _FlatOutputProcessor(inputs):
  """Returns a flattened list of 'processor(inputs)'."""
  outputs = processor(inputs)
  tf.logging.debug('Processor outputs=%s', outputs)
  assert len(outputs) > 1, outputs
  # Add 'outputs' as a list so that each element will be flattened.
  output_tmpl.values = list(outputs)
  flat_outputs = output_tmpl.Flatten()
  tf.logging.debug('Processor flat outputs=%s', flat_outputs)
  tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                   function.get_extra_inputs(), function.get_extra_args(),
                   function.get_extra_vars())
  assert not function.get_extra_args(), (
      'fns {} is not pure: extra_args={}'.format(processor,
                                                 function.get_extra_args()))
  return flat_outputs

def _FlatOutputProcessor(source_id, record):
  """Returns a flattened list of 'processor(inputs)'."""
  processor_spec = tf_inspect.getargspec(processor)
  tf.logging.debug('GenericInput.processor.argspec=%s', processor_spec)
  processor_args = set(processor_spec.args) - set(['self'])
  if len(processor_args) == 1:
    output, bucketing_key = processor(record)
  elif processor_args == set(['source_id', 'record']):
    output, bucketing_key = processor(source_id=source_id, record=record)
  else:
    raise ValueError(
        'GenericInput: processor should take either a single arg '
        'or two args named as "source_id" and "record". '
        'Actual: %s' % processor_args)

  if isinstance(output, list):
    assert output
    assert all(isinstance(x, tf.Tensor) for x in output), '{}'.format(output)
  else:
    assert isinstance(output, py_utils.NestedMap), '{}'.format(output)
    assert output
    assert all(isinstance(x, tf.Tensor)
               for x in output.Flatten()), '{}'.format(output.DebugString())
  bucketing_key = tf.cast(bucketing_key, tf.int32)
  tf.logging.debug('Processor outputs=%s bucketing_key=%s', output,
                   bucketing_key)
  output_tmpl.out_values = output
  flat_output_tmpl = output_tmpl.Flatten()
  tf.logging.debug('Processor flat outputs=%s', flat_output_tmpl)
  tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                   function.get_extra_inputs(), function.get_extra_args(),
                   function.get_extra_vars())
  assert not function.get_extra_args(), (
      'fns {} is not pure: extra_args={}'.format(processor,
                                                 function.get_extra_args()))
  return flat_output_tmpl + [bucketing_key]

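# For reference, a processor with the two-argument signature that the
# dispatch above accepts. The decoding and bucketing rule are made up for
# illustration; any function returning (list of Tensors or NestedMap,
# scalar bucketing key) works.
def _ExampleProcessor(source_id, record):
  del source_id  # Unused in this sketch.
  values = tf.decode_raw(record, tf.uint8)
  output = [tf.cast(values, tf.int32), tf.size(values)]
  bucketing_key = tf.size(values)  # Bucket records by their byte length.
  return output, bucketing_key
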
def BackwardLoopBody(t, limit, *args):
  """Backward loop body function."""
  (
      theta,
      orig_state0,
      inputs,
      acc_state,
      acc_extras,
      # End of forward params
      d_theta,
      d_state1,
      d_inputs,
      d_acc_state,
      d_captured) = _Pack(args, bakloop_sig)

  # The input recurrent state for time step t is previous time step's
  # output, or the original state0 when on time step 0.
  state_from_acc = _Index(acc_state,
                          tf.maximum(tf.constant(0, t.dtype), t - 1))
  state0 = functional_ops.If(
      tf.equal(t, tf.constant(0, t.dtype)),
      _Flatten([state_from_acc, orig_state0]), ReturnOrigState0,
      ReturnAccState)
  state0 = orig_state0.Pack(state0)

  # The external inputs for time step t.
  inputs_t = _Index(inputs, t)
  # The extras for time step t.
  extras_t = _Index(acc_extras, t)

  d_state1 = _Add(_Index(d_acc_state, t), d_state1)
  (d_theta_t, d_state0, d_inputs_t, d_captured_t) = _Pack(
      Bak(*_Flatten([theta, state0, inputs_t, extras_t, d_state1])), [
          self._theta,
          self._state,
          self._inputs,
          self._implicit_captures,
      ])

  if self._unused_acc_state:
    # XLA IF op requires the same shape for if and else branches.
    d_state0 = d_state0.Transform(tf.reduce_sum)
  d_theta = _Add(d_theta, d_theta_t)
  d_inputs = _Update(d_inputs, d_inputs_t, t)
  d_captured = _Add(d_captured, d_captured_t)

  # Make sure this function didn't capture anything different than the
  # cell_fn when reflected on at the beginning. Must come after the call
  # to Bak() which adds to the captured list.
  _AssertSameTensors(function.get_extra_inputs(),
                     self._implicit_captures.Flatten())

  return [tf.subtract(t, 1), limit] + _Flatten([
      theta,
      orig_state0,
      inputs,
      acc_state,
      acc_extras,
      # End of forward params
      d_theta,
      d_state0,
      d_inputs,
      d_acc_state,
      d_captured,
  ])

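# The t == 0 state selection above, sketched with plain tf.cond on a single
# state tensor instead of the fused functional_ops.If (whose branches XLA
# requires to have identical signatures). Hypothetical single-tensor state;
# the real code packs/unpacks whole NestedMaps around the branch.
def _SelectState0Sketch(t, acc_state, orig_state0):
  state_from_acc = acc_state[tf.maximum(tf.constant(0, t.dtype), t - 1)]
  return tf.cond(
      tf.equal(t, tf.constant(0, t.dtype)),
      lambda: orig_state0,     # On step 0, use the original state0.
      lambda: state_from_acc)  # Otherwise, the previous step's output.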