import tensorflow as tf
from lingvo.core import py_utils

# GetOverWriteGlobalStep is assumed to be defined elsewhere in this module.


def FProp(self, theta, inputs, *extra_inputs):
  initial_step_seed = py_utils.GetStepSeed()
  final_step_seed = py_utils.GenerateSeedFromName(
      tf.no_op(name='new_step_seed').name)
  num_layers = len(self.sub_layers)

  def Bak(inputs, outputs, d_outputs):
    """Backward step."""
    del inputs  # unused
    output_acts, step_seeds = outputs
    d_outputs = d_outputs[0]

    d_layer_thetas = []
    for layer_idx in reversed(range(num_layers)):
      f_seed, g_seed = step_seeds[layer_idx]
      layer = self.sub_layers[layer_idx]
      layer_theta = theta.sub_layers[layer_idx]
      input_acts, d_inputs, d_theta = layer.ReverseAndGrad(
          layer_theta, output_acts, d_outputs, f_seed, g_seed, *extra_inputs)
      d_layer_thetas.append(d_theta)
      # Passes reconstructed inputs to the previous layer.
      output_acts = input_acts
      d_outputs = d_inputs
    py_utils.ResetStepSeed(final_step_seed)
    d_theta = py_utils.NestedMap()
    d_theta.sub_layers = list(reversed(d_layer_thetas))
    extra_grads = [tf.zeros_like(t) for t in extra_inputs]
    return [tf.zeros_like(initial_step_seed), d_theta, d_inputs, extra_grads]

  def Fwd(xs):
    """Forward pass."""
    initial_step_seed, theta, acts, extra_inputs = xs
    py_utils.ResetStepSeed(initial_step_seed)
    layer_step_seeds = []
    for layer_theta, layer in zip(theta.sub_layers, self.sub_layers):
      acts, f_seed, g_seed = layer.FProp(layer_theta, acts, *extra_inputs)
      layer_step_seeds += [(f_seed, g_seed)]
    return [acts, layer_step_seeds]

  if self.params.custom_gradient:
    acts, _ = py_utils.CallDefun(
        Fwd, [initial_step_seed, theta, inputs, extra_inputs], Bak)
    py_utils.ResetStepSeed(final_step_seed)
    return acts
  else:
    acts = inputs
    for layer_theta, layer in zip(theta.sub_layers, self.sub_layers):
      acts, _, _ = layer.FProp(layer_theta, acts, *extra_inputs)
    return acts
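
# A minimal, self-contained sketch (not Lingvo's implementation) of the
# additive-coupling identity that ReverseAndGrad above relies on: the inputs
# of a reversible block can be reconstructed exactly from its outputs, so
# Bak can recompute activations instead of storing them. The helpers f and g
# below are hypothetical stand-ins for the block's F and G sub-networks.
def _revnet_coupling_sketch():
  import numpy as np

  def f(x):  # stand-in for the F sub-network
    return np.tanh(x)

  def g(x):  # stand-in for the G sub-network
    return 0.5 * x

  x1, x2 = np.random.randn(4), np.random.randn(4)
  # Forward: additive coupling.
  y1 = x1 + f(x2)
  y2 = x2 + g(y1)
  # Reverse: invert the coupling using only the outputs.
  r2 = y2 - g(y1)
  r1 = y1 - f(r2)
  assert np.allclose(x1, r1) and np.allclose(x2, r2)
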
def GetOpSeedPair(op_seed=None):
  with tf.name_scope('op_seed') as scope:
    mb_tensor = GetOverWriteGlobalStep()
    seeds = tf.stack(
        [tf.cast(mb_tensor, tf.int32),
         py_utils.GenerateSeedFromName(scope)])
    if op_seed is not None:
      seeds += op_seed
    return seeds
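
# Hypothetical usage sketch: the [2]-element seed pair built by GetOpSeedPair
# is the format consumed by TF's stateless random ops, which are pure
# functions of (shape, seed). The constant seeds below are illustrative only.
def _stateless_seed_sketch():
  seeds = tf.stack(
      [tf.constant(7, dtype=tf.int32),
       tf.constant(42, dtype=tf.int32)])
  a = tf.random.stateless_uniform([3], seed=seeds)
  b = tf.random.stateless_uniform([3], seed=seeds)
  # a == b elementwise: the same seed pair always yields the same sample.
  return a, b
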
def GenerateStepSeedPair(p, unused_global_step=None, op_seed=None):
  """Override py_utils.GenerateStepSeedPair to use GetOverWriteGlobalStep."""
  seed_dtype = tf.int32 if py_utils.use_tpu() else tf.int64
  if p.is_inference and p.random_seed is None:
    # Unlike tf.random*, stateless random ops are completely determined by
    # the passed-in seeds. This means at inference time the same inputs will
    # produce the same outputs, even if the model is supposed to have
    # randomness such as dropout during inference. We inject additional
    # randomness only during inference if the graph is exported with
    # random_seed=None as a workaround.
    return tf.random_uniform([2], maxval=seed_dtype.max, dtype=seed_dtype)

  with tf.name_scope('op_seed') as scope:
    global_step = tf.cast(GetOverWriteGlobalStep(), seed_dtype)
    step_seed = tf.cast(py_utils.GenerateSeedFromName(scope), seed_dtype)
    seeds = tf.stack([global_step, step_seed])

    if p.random_seed is not None:
      seeds += p.random_seed
    if op_seed is not None:
      seeds += op_seed
    return seeds
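
# Sketch of the determinism property GenerateStepSeedPair preserves: the seed
# pair is a pure function of (global step, name-scope hash), optionally
# shifted by p.random_seed / op_seed, so re-running a step replays the same
# randomness. The crc32-based hash below is an assumed, illustrative analogue
# of py_utils.GenerateSeedFromName, not its actual implementation.
def _step_seed_pair_sketch(global_step, scope_name, random_seed=None,
                           op_seed=None):
  import zlib
  name_seed = zlib.crc32(scope_name.encode('utf-8')) % (2**31)
  seeds = [global_step, name_seed]
  for extra in (random_seed, op_seed):
    if extra is not None:
      seeds = [s + extra for s in seeds]
  return seeds
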
def GetStepSeed():
  """Override py_utils.GetIncStepSeed to use a seed generated from the name scope."""
  with tf.name_scope('op_seed') as scope:
    return py_utils.GenerateSeedFromName(scope)