def Fwd(*args):
  """Forward step."""
  (theta, state0, inputs) = _Pack(args, fwd_sig)
  state1, extras = self._cell_fn(theta, state0, inputs)
  assert not function.get_extra_args(), (
      'cell_fn is not pure with extra args: %s.' %
      (function.get_extra_args()))
  _AssertIsCompatible(state1, self._state)
  _AssertIsCompatible(extras, self._extras)
  return _Flatten([state1, extras])
def Bak(*args):
  """Backward step."""
  (theta, state0, inputs, extras, d_state1) = _Pack(args, bak_sig)
  (dtheta, dstate0, dinputs) = self._cell_grad(theta, state0, inputs,
                                               extras, d_state1)
  assert not function.get_extra_args(), (
      'cell_grad is not pure with extra args: %s.' %
      (function.get_extra_args()))
  _AssertIsCompatible(dtheta, self._theta)
  _AssertIsCompatible(dstate0, self._state)
  _AssertIsCompatible(dinputs, self._inputs)
  return _Flatten(
      _ConvertNoneGradientToZeros([theta, state0, inputs],
                                  [dtheta, dstate0, dinputs]))
def CellGrad(theta, state0, inputs, extras, dstate1):
  """Default gradient function for cell_fn."""
  state1, extras = cell_fn(theta, state0, inputs)
  assert isinstance(state1, py_utils.NestedMap), ('%s' % state1)
  assert isinstance(extras, py_utils.NestedMap), ('%s' % extras)
  # NOTE: The default grad function recomputes the forward
  # function and does not take advantage of 'extras' returned by
  # the forward function.

  # Assert that if captured inputs were given, they match the actual
  # tensors passed to the function we are compiled into. Must come after
  # the call to cell_fn, which does the capture.
  _AssertSameTensors(function.get_extra_inputs(), implicit_captures.Flatten())

  # Extract the internal captured tensor placeholders within the Defun
  # we are running in.
  (captured,) = _Pack(function.get_extra_args(), [implicit_captures])

  ys = _Flatten([state1])
  xs = _Flatten([theta, state0, inputs, captured])
  grad_ys = _Flatten([dstate1])
  grads = tf.gradients(ys=ys, xs=xs, grad_ys=grad_ys)
  return _ConvertNoneGradientToZeros(
      [theta, state0, inputs, captured],
      _Pack(grads, [theta, state0, inputs, captured]))
def Bak(*args):
  """Backward step."""
  (theta, state0, inputs, extras, d_state1) = _Pack(args, bak_sig)
  (dtheta, dstate0, dinputs, dcaptures) = self._cell_grad(
      theta, state0, inputs, extras, d_state1)
  _AssertIsCompatible(dtheta, self._theta)
  _AssertIsCompatible(dstate0, self._state)
  _AssertIsCompatible(dinputs, self._inputs)
  if dcaptures is None:
    # NOTE: Custom gradient fns can return None if they do not support
    # captured tensors. The return value is reserved for the future when
    # that may be supported.
    dcaptures = _EmptyLike(self._implicit_captures)
  _AssertIsCompatible(dcaptures, self._implicit_captures)

  # Make sure this function didn't capture anything different than the
  # cell_fn when reflected on at the beginning. Must come after the call
  # to cell_grad() which adds to the captured list.
  _AssertSameTensors(function.get_extra_inputs(),
                     self._implicit_captures.Flatten())

  (captured,) = _Pack(function.get_extra_args(), [self._implicit_captures])
  return _Flatten(
      _ConvertNoneGradientToZeros(
          [theta, state0, inputs, captured],
          [dtheta, dstate0, dinputs, dcaptures]))
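# --- Illustrative sketch (not from the source): the cell_fn contract that the
# Fwd/Bak/CellGrad wrappers above enforce. Assumes lingvo's
# recurrent.Recurrent and py_utils.NestedMap; the name AccumulateCell and the
# exact example values are assumptions, not code from the source.
import tensorflow.compat.v1 as tf
from lingvo.core import py_utils
from lingvo.core import recurrent

def AccumulateCell(theta, state0, inputs):
  # A pure function of its arguments: it captures no outer tensors, so the
  # `assert not function.get_extra_args()` purity checks above would pass.
  state1 = py_utils.NestedMap(total=state0.total + theta.scale * inputs.x)
  extras = py_utils.NestedMap()
  return state1, extras

theta = py_utils.NestedMap(scale=tf.constant(2.0))
state0 = py_utils.NestedMap(total=tf.constant(0.0))
inputs = py_utils.NestedMap(x=tf.ones([5]))  # time-major: 5 steps
acc_states, final_state = recurrent.Recurrent(theta, state0, inputs,
                                              AccumulateCell)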
def _FlatOutputProcessor(inputs):
  """Returns a flattened list of 'processor(inputs)'."""
  outputs = processor(inputs)
  tf.logging.debug('Processor outputs=%s', outputs)
  assert len(outputs) > 1, outputs
  # Add 'outputs' as a list so that each element will be flattened.
  output_tmpl.values = list(outputs)
  flat_outputs = output_tmpl.Flatten()
  tf.logging.debug('Processor flat outputs=%s', flat_outputs)
  tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                   function.get_extra_inputs(), function.get_extra_args(),
                   function.get_extra_vars())
  assert not function.get_extra_args(), (
      'fns {} is not pure: extra_args={}'.format(processor,
                                                 function.get_extra_args()))
  return flat_outputs
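# --- Illustrative sketch (not from the source): how get_extra_args() exposes
# implicit captures inside TF1's Defun tracing, which is what the purity
# assertion above relies on. The name `outer` is an assumption.
import tensorflow.compat.v1 as tf
from tensorflow.python.framework import function

outer = tf.constant(3.0)  # lives outside the Defun below

@function.Defun(tf.float32)
def Impure(x):
  y = x * outer  # referencing `outer` implicitly captures it
  # At trace time, get_extra_args() returns the in-function placeholders
  # standing in for captured tensors; a pure function yields an empty tuple,
  # so an `assert not function.get_extra_args()` would fail for this one.
  tf.logging.debug('extra_args=%s', function.get_extra_args())
  return y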
def Grad(x, y0):
  if use_forward_func:
    y = Model(x)
  else:
    y = _Model(x)
  loss = tf.reduce_mean(tf.reduce_sum(y0 * tf.log(y), 1), 0)
  # get_extra_args() holds the in-function placeholders for the captured
  # weight and bias, so the loss can be differentiated w.r.t. them directly.
  dw, db = tf.gradients(loss, function.get_extra_args())
  cvars.extend(function.get_extra_vars())
  return loss, dw, db
def WhileBody(i, n, start, delta, *args):
  """A While wrapper for forbody that handles loop-carried captured inputs."""
  for_result = forbody(start + i * delta, *args)
  # Nullary functions return an Operation. Normal functions can't do this
  # because their return values are converted to Tensors.
  if isinstance(for_result, ops.Operation):
    for_result = ()
  # Unary functions return a single Tensor value.
  elif isinstance(for_result, ops.Tensor):
    for_result = (for_result,)
  extra_args = tuple(function.get_extra_args())
  return (i + 1, n, start, delta) + tuple(for_result) + extra_args
def Grad(x, y0):
  if use_forward_func:
    y = Model(x)
  else:
    y = _Model(x)
  loss = tf.reduce_mean(tf.reduce_sum(y0 * tf.log(y), 1), 0)
  arg_w, arg_b = function.get_extra_args()
  self.assertEqual(arg_w.get_shape(), tf.TensorShape([64, 64]))
  self.assertEqual(arg_b.get_shape(), tf.TensorShape([64]))
  dw, db = tf.gradients(loss, [arg_w, arg_b])
  cvars.extend(function.get_extra_vars())
  return loss, dw, db
def _FlatOutputProcessor(source_id, record):
  """Returns a flattened list of 'processor(inputs)'."""
  processor_spec = tf_inspect.getargspec(processor)
  tf.logging.debug('GenericInput.processor.argspec=%s', processor_spec)
  processor_args = set(processor_spec.args) - set(['self'])
  if len(processor_args) == 1:
    output, bucketing_key = processor(record)
  elif processor_args == set(['source_id', 'record']):
    output, bucketing_key = processor(source_id=source_id, record=record)
  else:
    raise ValueError(
        'GenericInput: processor should take either a single arg '
        'or two args named as "source_id" and "record". '
        'Actual: %s' % processor_args)

  if isinstance(output, list):
    assert output
    assert all(isinstance(x, tf.Tensor) for x in output), '{}'.format(output)
  else:
    assert isinstance(output, py_utils.NestedMap), '{}'.format(output)
    assert output
    assert all(isinstance(x, tf.Tensor) for x in output.Flatten()), (
        '{}'.format(output.DebugString()))

  bucketing_key = tf.cast(bucketing_key, tf.int32)
  tf.logging.debug('Processor outputs=%s bucketing_key=%s', output,
                   bucketing_key)
  output_tmpl.out_values = output
  flat_output_tmpl = output_tmpl.Flatten()
  tf.logging.debug('Processor flat outputs=%s', flat_output_tmpl)
  tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                   function.get_extra_inputs(), function.get_extra_args(),
                   function.get_extra_vars())
  assert not function.get_extra_args(), (
      'fns {} is not pure: extra_args={}'.format(processor,
                                                 function.get_extra_args()))
  return flat_output_tmpl + [bucketing_key]
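# --- Illustrative sketch (not from the source): a processor satisfying the
# contract checked above. Assumes lingvo's generic_input.GenericInput; the
# file_pattern and keyword arguments are assumptions and vary by version.
import tensorflow.compat.v1 as tf
from lingvo.core import generic_input
from lingvo.core import py_utils

def Proc(record):
  # Pure: builds outputs only from `record`, captures nothing, and returns
  # (NestedMap of tf.Tensor, bucketing_key).
  val = tf.strings.to_number(record, out_type=tf.float32)
  return py_utils.NestedMap(val=val), tf.constant(1, tf.int32)

batch, _ = generic_input.GenericInput(
    processor=Proc,
    file_pattern='text:/path/to/data',  # assumed path
    file_random_seed=0,
    file_buffer_size=32,
    file_parallelism=1,
    bucket_upper_bound=[8],
    bucket_batch_limit=[8])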
def BodyWrapper(*args):
  """A wrapper for body that handles loop-carried captured inputs."""
  body_result = body(*args)
  extra_args = tuple(function.get_extra_args())
  # Nullary functions return an Operation. Normal functions can't do this
  # because their return values are converted to Tensors.
  if isinstance(body_result, ops.Operation):
    return extra_args
  # Unary functions return a single Tensor value.
  elif not isinstance(body_result, tuple):
    return (body_result,) + extra_args
  # N-ary functions return a tuple of Tensors.
  else:
    return body_result + extra_args
def Wrapper(*args):
  """A wrapper that handles loop-carried captured inputs."""
  result = func(*args)
  extra_args = tuple(function.get_extra_args())
  # Nullary functions return an Operation. Normal functions can't do this
  # because their return values are converted to Tensors.
  if isinstance(result, ops.Operation):
    return extra_args
  # Unary functions return a single Tensor value.
  elif not isinstance(result, tuple):
    return (result,) + extra_args
  # N-ary functions return a tuple of Tensors.
  else:
    return result + extra_args
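# --- Illustrative sketch (not from the source): why the wrappers above append
# get_extra_args() to the body's outputs. Functional For/While ops require the
# body's outputs to line up one-to-one with its inputs, and captured tensors
# become trailing inputs that must be passed back through unchanged. Assumes
# TF1's functional_ops.For; `limit` and `Body` are assumed names.
import tensorflow.compat.v1 as tf
from tensorflow.python.ops import functional_ops

limit = tf.constant(3.0)  # captured by Body below

def Body(i, acc):
  # `limit` is an implicit capture; the For wrapper forwards it as a
  # trailing extra arg and threads it back out via get_extra_args().
  return acc + tf.minimum(tf.cast(i, tf.float32), limit)

(acc,) = functional_ops.For(
    start=0, limit=5, delta=1, inputs=[tf.constant(0.0)], body=Body)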