def FProp(self, theta, inputs, *extra_inputs):

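        # Capture the current step seed so Fwd can reset to it before running the
        # layer FProps, and derive a fresh seed used to leave the global step-seed
        # state in a consistent place once the forward/backward pair is done.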
        initial_step_seed = py_utils.GetStepSeed()
        final_step_seed = py_utils.GenerateSeedFromName(
            tf.no_op(name='new_step_seed').name)
        num_layers = len(self.sub_layers)

        def Bak(inputs, outputs, d_outputs):
            """Backward step."""
            del inputs  # unused
            output_acts, step_seeds = outputs
            d_outputs = d_outputs[0]

            d_layer_thetas = []
            for layer_idx in reversed(range(num_layers)):
                f_seed, g_seed = step_seeds[layer_idx]
                layer = self.sub_layers[layer_idx]
                layer_theta = theta.sub_layers[layer_idx]

                input_acts, d_inputs, d_theta = layer.ReverseAndGrad(
                    layer_theta, output_acts, d_outputs, f_seed, g_seed,
                    *extra_inputs)

                d_layer_thetas.append(d_theta)
                # Passes reconstructed inputs to the previous layer.
                output_acts = input_acts
                d_outputs = d_inputs
            py_utils.ResetStepSeed(final_step_seed)
            d_theta = py_utils.NestedMap()
            d_theta.sub_layers = list(reversed(d_layer_thetas))

            extra_grads = [tf.zeros_like(t) for t in extra_inputs]
            return [
                tf.zeros_like(initial_step_seed), d_theta, d_inputs,
                extra_grads
            ]

        def Fwd(xs):
            """Forward pass."""
            initial_step_seed, theta, acts, extra_inputs = xs

            py_utils.ResetStepSeed(initial_step_seed)
            layer_step_seeds = []

            for layer_theta, layer in zip(theta.sub_layers, self.sub_layers):
                acts, f_seed, g_seed = layer.FProp(layer_theta, acts,
                                                   *extra_inputs)
                layer_step_seeds += [(f_seed, g_seed)]
            return [acts, layer_step_seeds]

        if self.params.custom_gradient:
            acts, _ = py_utils.CallDefun(
                Fwd, [initial_step_seed, theta, inputs, extra_inputs], Bak)
            py_utils.ResetStepSeed(final_step_seed)
            return acts
        else:
            acts = inputs
            for layer_theta, layer in zip(theta.sub_layers, self.sub_layers):
                acts, _, _ = layer.FProp(layer_theta, acts, *extra_inputs)
            return acts
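
The memory saving in the custom-gradient path comes from never storing per-layer activations: Bak reconstructs each layer's inputs from its outputs with ReverseAndGrad and recomputes gradients from there. Below is a minimal, graph-mode sketch of the same idea for a single additive-coupling block, using tf.custom_gradient and tf.gradients rather than the lingvo CallDefun/ReverseAndGrad machinery; f and g are assumed to be plain tensor-to-tensor callables without trainable variables, so the d_theta and step-seed bookkeeping above is omitted.

import tensorflow as tf

def ReversibleBlock(f, g):
  """Additive coupling block: y1 = x1 + f(x2), y2 = x2 + g(y1)."""

  @tf.custom_gradient
  def Forward(x1, x2):
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)

    def Grad(dy1, dy2):
      # Reconstruct the block input x2 from the outputs instead of storing it;
      # a stacked version would also reconstruct x1 and hand it to the previous
      # block, as `output_acts = input_acts` does above.
      x2_rec = y2 - g(y1)
      # Chain rule for y1 = x1 + f(x2), y2 = x2 + g(y1), recomputing f and g.
      dx1 = dy1 + tf.gradients(g(y1), y1, grad_ys=dy2)[0]
      dx2 = dy2 + tf.gradients(f(x2_rec), x2_rec, grad_ys=dx1)[0]
      return dx1, dx2

    return (y1, y2), Grad

  return Forward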
Example #2
def GetOpSeedPair(op_seed=None):
  with tf.name_scope('op_seed') as scope:
    mb_tensor = GetOverWriteGlobalStep()
    seeds = tf.stack(
        [tf.cast(mb_tensor, tf.int32),
         py_utils.GenerateSeedFromName(scope)])
    if op_seed is not None:
      seeds += op_seed
    return seeds
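
The [global_step, name_hash] pair above has the shape expected by TF's stateless random ops, so a given graph position and global step always produce the same draw. A small usage sketch, assuming a TF build that exposes tf.random.stateless_uniform; the mask shape and threshold are illustrative only:

op_seeds = GetOpSeedPair(op_seed=7)
keep_mask = tf.cast(
    tf.random.stateless_uniform([8, 128], seed=op_seeds) > 0.1, tf.float32)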
Example #3
def GenerateStepSeedPair(p, unused_global_step=None, op_seed=None):
  """Override py_utils.GenerateStepSeedPair to use GetOverWriteGlobalStep."""
  seed_dtype = tf.int32 if py_utils.use_tpu() else tf.int64
  if p.is_inference and p.random_seed is None:
    # Unlike tf.random*, stateless random ops are fully determined by the
    # seeds passed in, so at inference time the same inputs would produce the
    # same outputs even when the model is meant to be stochastic (e.g. dropout
    # at inference). As a workaround, we inject additional randomness at
    # inference when the graph is exported with random_seed=None.
    return tf.random_uniform([2], maxval=seed_dtype.max, dtype=seed_dtype)

  with tf.name_scope('op_seed') as scope:
    global_step = tf.cast(GetOverWriteGlobalStep(), seed_dtype)
    step_seed = tf.cast(py_utils.GenerateSeedFromName(scope), seed_dtype)
    seeds = tf.stack([global_step, step_seed])

    if p.random_seed is not None:
      seeds += p.random_seed
    if op_seed is not None:
      seeds += op_seed
    return seeds
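
A sketch of the two branches, using py_utils.NestedMap as a hypothetical stand-in for the layer params object (only the is_inference and random_seed fields are read). With a fixed random_seed the pair is a deterministic function of the overwritten global step and the call's name scope; an inference graph exported with random_seed=None instead falls back to a stateful draw so dropout-like ops stay random across requests.

train_p = py_utils.NestedMap(is_inference=False, random_seed=1234)
infer_p = py_utils.NestedMap(is_inference=True, random_seed=None)

deterministic_seeds = GenerateStepSeedPair(train_p, op_seed=3)
stateful_seeds = GenerateStepSeedPair(infer_p)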
Example #4
def GetStepSeed():
    """Override py_utils.GetIncStepSeed to use seed generated by name scope."""
    with tf.name_scope('op_seed') as scope:
        return py_utils.GenerateSeedFromName(scope)
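
Since the seed is a hash of the active (and uniquified) name scope, each call site gets its own value, and rebuilding the same graph reproduces the same seeds. A small sketch of that behavior, assuming TF1-style graph construction as in the snippets above:

with tf.Graph().as_default():
  seed_a = GetStepSeed()  # hashes scope 'op_seed/'
  seed_b = GetStepSeed()  # hashes scope 'op_seed_1/', so a different value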