def call_replica_local_fn(fn, *args, **kwargs):
  """Call a function that uses replica-local variables.

  This function correctly handles calling `fn` in a cross-replica
  context.

  Arguments:
    fn: The function to call.
    *args: Positional arguments to pass to `fn`.
    **kwargs: Keyword arguments to pass to `fn`.

  Returns:
    The result of calling `fn`.
  """
  # TODO(b/120571621): We want to avoid reductions here since
  # TPUStrategy does not implement replica-local variables.
  # Remove this hack once we support TPUReplicaLocalVariables.
  strategy = None
  if 'strategy' in kwargs:
    strategy = kwargs.pop('strategy')
  else:
    if ds_context.has_strategy():
      strategy = ds_context.get_strategy()

  is_tpu = is_tpu_strategy(strategy)
  if ((not is_tpu) and strategy and ds_context.in_cross_replica_context()):
    with strategy.scope():
      return strategy.extended.call_for_each_replica(fn, args, kwargs)
  return fn(*args, **kwargs)
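A minimal usage sketch of the dispatch above, written against the public `tf.distribute` API rather than the internal `ds_context` module (the helper name `call_in_replica_context` is illustrative):

import tensorflow as tf

def call_in_replica_context(fn, *args, **kwargs):
  # In a cross-replica context with a strategy set, fan out via strategy.run;
  # otherwise call fn directly, mirroring call_replica_local_fn above.
  if tf.distribute.has_strategy() and tf.distribute.in_cross_replica_context():
    strategy = tf.distribute.get_strategy()
    return strategy.run(fn, args=args, kwargs=kwargs)
  return fn(*args, **kwargs)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  result = call_in_replica_context(lambda x: x + 1.0, tf.constant(1.0))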
Example 2
  def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm."""
    beta = self.beta if self.center else self._beta_const
    gamma = self.gamma if self.scale else self._gamma_const

    def _fused_batch_norm_training():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          epsilon=self.epsilon,
          data_format=self._data_format)

    def _fused_batch_norm_inference():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          mean=self.moving_mean,
          variance=self.moving_variance,
          epsilon=self.epsilon,
          is_training=False,
          data_format=self._data_format)

    output, mean, variance = tf_utils.smart_cond(
        training, _fused_batch_norm_training, _fused_batch_norm_inference)
    if not self._bessels_correction_test_only:
      # Remove Bessel's correction to be consistent with non-fused batch norm.
      # Note that the variance computed by fused batch norm is
      # with Bessel's correction.
      sample_size = math_ops.cast(
          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
      factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
      variance *= factor

    training_value = tf_utils.constant_value(training)
    if training_value is None:
      momentum = tf_utils.smart_cond(training,
                                     lambda: self.momentum,
                                     lambda: 1.0)
    else:
      momentum = ops.convert_to_tensor(self.momentum)
    if training_value or training_value is None:
      if distribution_strategy_context.in_cross_replica_context():
        strategy = distribution_strategy_context.get_strategy()
        mean_update = strategy.extended.update(
            self.moving_mean, self._assign_moving_average,
            (mean, self.momentum))
        variance_update = strategy.extended.update(
            self.moving_variance, self._assign_moving_average,
            (variance, self.momentum))
      else:
        mean_update = self._assign_moving_average(self.moving_mean, mean,
                                                  momentum)
        variance_update = self._assign_moving_average(self.moving_variance,
                                                      variance, momentum)
      self.add_update(mean_update, inputs=True)
      self.add_update(variance_update, inputs=True)

    return output
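The factor applied above just converts the unbiased variance (divisor n - 1, which is what fused batch norm returns) into the biased one (divisor n). A quick standalone check of that factor, assuming only NumPy:

import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
n = x.size
unbiased = x.var(ddof=1)                   # with Bessel's correction (n - 1)
biased = unbiased * (n - 1.0) / n          # same factor as the code above
assert np.isclose(biased, x.var(ddof=0))   # non-fused convention (divisor n)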
def _assert_in_default_state(t):
  t.assertIs(ds_context._get_default_replica_context(),
             ds_context.get_replica_context())
  t.assertIs(None, ds_context.get_cross_replica_context())
  t.assertFalse(ds_context.in_cross_replica_context())
  t.assertIs(ds_context._get_default_strategy(), ds_context.get_strategy())
  t.assertFalse(ds_context.has_strategy())
 def merge_fn(dist, s):
   self.assertIs(ds_context._get_default_strategy(), dist)
   self.assertIs(None, ds_context.get_replica_context())
   self.assertIs(dist, ds_context.get_cross_replica_context())
   self.assertTrue(ds_context.in_cross_replica_context())
   self.assertIs(dist, ds_context.get_strategy())
   self.assertFalse(ds_context.has_strategy())
   return "foo_" + s
Example 5
 def set_non_tensor_output(self, name, output):
   """Set `output` with `name` to be captured as a non tensor output."""
   if distribution_strategy_context.in_cross_replica_context():
     self._non_tensor_outputs[name] = output
   else:
     def merge_fn(distribution, value):
       # NOTE(priyag): For non tensor outputs, we simply return all the values
       # in a list as reduction doesn't make sense on non tensors.
       self._non_tensor_outputs[name] = distribution.unwrap(value)
     distribution_strategy_context.get_replica_context().merge_call(
         merge_fn, args=(output,))
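A hedged standalone sketch of the same `merge_call` pattern using today's public API; note that `Strategy.unwrap`, used above, has since been superseded by `experimental_local_results` (the `collected` dict here is illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
collected = {}

def merge_fn(strategy, value):
  # Runs once in cross-replica context; keep all per-replica values as a
  # tuple, since reduction makes no sense for non-tensor outputs.
  collected['out'] = strategy.experimental_local_results(value)

def replica_fn():
  ctx = tf.distribute.get_replica_context()
  ctx.merge_call(merge_fn, args=(ctx.replica_id_in_sync_group,))

strategy.run(replica_fn)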
Example 6
def _create_keras_history_helper(tensors, processed_ops, created_layers):
  """Helper method for `create_keras_history`.

  Arguments:
    tensors: A structure of Tensors for which to create Keras metadata.
    processed_ops: Set. TensorFlow operations that have already been wrapped in
      `TensorFlowOpLayer` instances.
    created_layers: List. The `TensorFlowOpLayer` instances created.

  Returns:
    Tuple. First element is the updated set of TensorFlow Operations that
    have been wrapped in `TensorFlowOpLayer` instances. Second element is
    a list of the `TensorFlowOpLayer` instances created.
  """
  # Import of `base_layer` needed in order to create `TensorFlowOpLayer`.
  # Cannot be imported at top because of circular dependencies.
  # TODO(omalleyt): Resolve circular dependency.
  from tensorflow.python.keras.engine import base_layer  # pylint: disable=g-import-not-at-top
  tensor_list = nest.flatten(tensors)
  for tensor in tensor_list:
    if getattr(tensor, '_keras_history', None) is not None:
      continue
    op = tensor.op  # The Op that created this Tensor.
    if op not in processed_ops:
      # Recursively set `_keras_history`.
      op_inputs = list(op.inputs)
      constants = {}
      layer_inputs = []
      for i, op_input in enumerate(op_inputs):
        if uses_keras_history(op_input):
          layer_inputs.append(op_input)
        else:
          # Treat any value not originating from a `keras.Input` as
          # a constant. Variables cannot be supported.
          if (distribution_strategy_context.in_cross_replica_context() and
              not ops.executing_eagerly_outside_functions()):
            # In Legacy Graph mode, evaluating here makes Session be
            # configured improperly.
            constants[i] = op_input
          else:
            constants[i] = backend.function([], op_input)([])
      processed_ops, created_layers = _create_keras_history_helper(
          layer_inputs, processed_ops, created_layers)
      name = op.name
      node_def = op.node_def.SerializeToString()
      op_layer = base_layer.TensorFlowOpLayer(
          node_def, constants=constants, name=name)
      created_layers.append(op_layer)
      op_layer._add_inbound_node(  # pylint: disable=protected-access
          layer_inputs, op.outputs)
      processed_ops.update([op])
  return processed_ops, created_layers
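For context, this code path is what fires when a raw TensorFlow op is applied to a Keras symbolic tensor. A minimal trigger using the public API (in recent TF versions the wrapping layer goes by a different name, but the mechanism is the same idea):

import tensorflow as tf

inp = tf.keras.Input(shape=(4,))
out = tf.abs(inp)  # raw op on a symbolic tensor gets wrapped into an op layer
model = tf.keras.Model(inp, out)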
 def run_fn():
   replica_context = ds_context.get_replica_context()
   self.assertTrue(replica_context is not None)
   self.assertIs(None, ds_context.get_cross_replica_context())
   self.assertFalse(ds_context.in_cross_replica_context())
   self.assertTrue(ds_context.has_strategy())
   self.assertIs(dist, ds_context.get_strategy())
   self.assertEqual("foo", replica_context.merge_call(None, test_arg="foo"))
   expected_value = _get_test_variable(
       "bar", variable_scope.VariableSynchronization.AUTO,
       variable_scope.VariableAggregation.NONE)
   self.assertDictEqual(expected_value,
                        variable_scope.variable(1.0, name="bar"))
 def testScope(self):
   _assert_in_default_state(self)
   dist = _TestStrategy()
   with dist.scope():
     self.assertIs(None, ds_context.get_replica_context())
     self.assertIs(dist, ds_context.get_cross_replica_context())
     self.assertTrue(ds_context.in_cross_replica_context())
     self.assertTrue(ds_context.has_strategy())
     self.assertIs(dist, ds_context.get_strategy())
     expected_value = _get_test_variable(
         "baz", variable_scope.VariableSynchronization.AUTO,
         variable_scope.VariableAggregation.NONE)
     self.assertDictEqual(expected_value,
                          variable_scope.variable(1.0, name="baz"))
   _assert_in_default_state(self)
Example 9
  def set_last_step_output(self, name, output, reduce_op=None):
    """Set `output` with `name` to be outputted from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be outputted with `name`. See below for
        actual types supported.
      reduce_op: Reduction method to use to reduce outputs from multiple
        replicas. Required if `set_last_step_output` is called in a replica
        context. Optional in cross_replica_context.
        When present, the outputs from all the replicas are reduced using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce` method.
        For example, if using MirroredStrategy and a reduction is set, the
        output must be a `PerReplica` value.
        The reduce method is also recorded in the
        `_last_step_outputs_reduce_ops` dictionary, so that we can later tell
        whether the outputs were already reduced or not.
    """
    if distribution_strategy_context.in_cross_replica_context():
      self._last_step_outputs_reduce_ops[name] = reduce_op
      if reduce_op is None:
        self._last_step_outputs[name] = output
      else:
        distribution = distribution_strategy_context.get_strategy()
        self._last_step_outputs[name] = distribution.reduce(reduce_op, output,
                                                            axis=None)
    else:
      assert reduce_op is not None
      def merge_fn(distribution, value):
        self._last_step_outputs[name] = distribution.reduce(reduce_op, value,
                                                            axis=None)
        # Setting this inside the `merge_fn` because all replicas share the same
        # context object, so it's more robust to set it only once (even if all
        # the replicas are trying to set the same value).
        self._last_step_outputs_reduce_ops[name] = reduce_op

      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))
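A minimal sketch of the cross-replica reduction used above, via the public `Strategy.reduce` (MirroredStrategy is just one strategy that exercises it):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
per_replica = strategy.run(lambda: tf.constant(2.0))
# Sum per-replica values into a single tensor, as the reduce_op branch does.
total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)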
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    conditionally applies gradients if all gradient values are finite.
    Otherwise no update is performed (nor is `global_step` incremented).

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the variables
        have been updated.
      name: Optional name for the returned operation.  Default to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that conditionally applies the specified gradients. If
      `global_step` was not None, that operation also increments `global_step`.

    Raises:
      ValueError: If called in a cross-replica context.
    """
        if distribution_strategy_context.in_cross_replica_context():
            raise ValueError(
                'apply_gradients() must be called in a replica context.')

        if not self._doing_dynamic_loss_scaling():
            return self._optimizer.apply_gradients(grads_and_vars, global_step,
                                                   name)

        replica_context = distribution_strategy_context.get_replica_context()
        grads_and_vars = tuple(grads_and_vars)

        # TODO(nluehr) cleanup GraphKeys.TRAIN_OP
        return replica_context.merge_call(self._distributed_apply,
                                          args=(grads_and_vars, global_step,
                                                name))
Example 11
def _centering_weights(weight, weight_identity):
    # When using group norm, centering and normalizing the weight variable
    # may give better results.
    centered_weight = weight_identity - tf.reduce_mean(weight_identity,
                                                       axis=[0, 1, 2])
    weight_norm = linalg_ops.norm(tf.cast(tf.reshape(
        centered_weight, [-1, centered_weight.shape[-1]]),
                                          dtype=tf.float32),
                                  ord=2,
                                  axis=-2)
    normed_weight = centered_weight / tf.cast(weight_norm, dtype=weight.dtype)

    if distribute_ctx.has_strategy():
        # Handle DistributionStrategy case.
        if distribute_ctx.in_cross_replica_context():
            raise RuntimeError(
                "Use `_distributed_apply()` instead of "
                "`apply_gradients()` in a cross-replica context.")

        assign_op = distribute_ctx.get_replica_context().merge_call(
            assign_vars, args=(weight, normed_weight))[0]
    else:
        assign_op = weight.assign(normed_weight)
    return [assign_op]
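A self-contained sketch of the centering-and-normalizing step above, outside any distribution strategy (the HWIO kernel shape is illustrative):

import tensorflow as tf

weight = tf.random.normal([3, 3, 16, 32])  # HWIO conv kernel
centered = weight - tf.reduce_mean(weight, axis=[0, 1, 2])
flat = tf.reshape(centered, [-1, centered.shape[-1]])
norm = tf.norm(tf.cast(flat, tf.float32), ord=2, axis=-2)  # per out-channel
normed = centered / tf.cast(norm, weight.dtype)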
Example 12
 def g(strategy, z):
   g_traces.append(None)  # Only happens on trace.
   self.assertIs(strategy, self._strategy)
   self.assertTrue(distribution_strategy_context.in_cross_replica_context())
   self.assertIsInstance(z, mirrored_function_strategy.FnMergedValue)
   return z
    def gradient(self,
                 target,
                 sources,
                 output_gradients=None,
                 unconnected_gradients=UnconnectedGradients.NONE):
        """Computes the gradient using operations recorded in context of this tape.

    Uses the `LossScale` object provided in the constructor to scale `target`
    and then to unscale the resulting gradients.

    Args:
      target: a list or nested structure of Tensors or Variables to be
        differentiated.
      sources: a list or nested structure of Tensors or Variables. `target` will
        be differentiated against elements in `sources`.
      output_gradients: a list of gradients, one for each element of target.
        Defaults to None.
      unconnected_gradients: a value which can either hold 'none' or 'zero' and
        alters the value which will be returned if the target and sources are
        unconnected. The possible values and effects are detailed in
        'UnconnectedGradients' and it defaults to 'none'.

    Returns:
      a list or nested structure of Tensors (or IndexedSlices, or None),
      one for each element in `sources`. Returned structure is the same as
      the structure of `sources`. If non-finite gradients are encountered
      after dynamic scaling, the loss scale will be updated and the gradients
      recomputed until either finite gradients are encountered or the loss scale
      becomes 1.

    Raises:
      RuntimeError: if called inside the context of the tape, or if called more
       than once on a non-persistent tape.
      ValueError: if the target is a variable or if unconnected gradients is
       called with an unknown value.
    """
        if self._tape is None:  # pylint: disable=access-member-before-definition
            raise RuntimeError(
                "GradientTape.gradient can only be called once on "
                "non-persistent tapes.")
        if distribution_strategy_context.in_cross_replica_context():
            raise ValueError(
                "LossScaleGradientTape.gradient() must be called in a "
                "replica context.")

        if context.executing_eagerly():
            compute_gradients_until_finite = _compute_gradients_until_finite
        else:
            compute_gradients_until_finite = _compute_gradients_until_finite_autograph

        # Note: DistributionStrategy does not support running a while loop in a
        # replica context. So, we call `compute_gradients_until_finite` in a cross-
        # replica context.
        replica_context = distribution_strategy_context.get_replica_context()
        grads = replica_context.merge_call(
            compute_gradients_until_finite,
            args=(self, self._loss_scale, target, sources, output_gradients,
                  unconnected_gradients))

        if not self._outer_persistent:
            self._tape = None  # free up resources if a persistent tape was not needed
        return grads
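The `LossScaleGradientTape` shown above is an experimental/internal class; the supported public route to dynamic loss scaling is `tf.keras.mixed_precision.LossScaleOptimizer`. A minimal sketch under that assumption:

import tensorflow as tf

opt = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.SGD(0.1))
var = tf.Variable(1.0)

with tf.GradientTape() as tape:
  loss = var * var
  scaled_loss = opt.get_scaled_loss(loss)           # scale up before backprop
scaled_grads = tape.gradient(scaled_loss, [var])
grads = opt.get_unscaled_gradients(scaled_grads)    # scale back down
opt.apply_gradients(zip(grads, [var]))              # also updates the loss scale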
Example 14
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """See base class."""
        assignments = []
        if self.debiasing:
            if global_step is None:
                local_step = variable_scope.get_variable(
                    name="local_step",
                    shape=[],
                    dtype=tf.float32,
                    initializer=init_ops.zeros_initializer(),
                    trainable=False)
                update_local_step = local_step.assign_add(1)
            else:
                update_local_step = global_step
        for (grad, param) in grads_and_vars:
            if grad is None or param is None:
                continue

            param_name = _get_variable_name(param.name)
            # We use fp32 for calculating the update.
            grad_fp32 = tf.cast(grad, tf.float32)
            # We divide the gradient by the loss scaling
            grad_fp32 = grad_fp32 / self.loss_scaling

            m = tf.get_variable(name=param_name + "/adam_m",
                                shape=param.shape.as_list(),
                                dtype=tf.float32,
                                trainable=False,
                                initializer=tf.zeros_initializer())

            v = tf.get_variable(name=param_name + "/adam_v",
                                shape=param.shape.as_list(),
                                dtype=tf.float32,
                                trainable=False,
                                initializer=tf.zeros_initializer())

            # Standard Adam update.
            next_m = (tf.multiply(self.beta_1, m) +
                      tf.multiply(1.0 - self.beta_1, grad_fp32))

            next_v = (tf.multiply(self.beta_2, v) +
                      tf.multiply(1.0 - self.beta_2, tf.square(grad_fp32)))
            if self.debiasing:
                next_m_debiase = next_m / (
                    1.0 - tf.pow(self.beta_1, update_local_step))
                next_v_debiase = next_v / (
                    1.0 - tf.pow(self.beta_2, update_local_step))
            else:
                next_m_debiase = next_m
                next_v_debiase = next_v

            update = tf.cast(
                next_m_debiase / (tf.sqrt(next_v_debiase) + self.epsilon),
                param.dtype)

            # Just adding the square of the weights to the loss function is *not*
            # the correct way of using L2 regularization/weight decay with Adam,
            # since that will interact with the m and v parameters in strange ways.
            #
            # Instead we want to decay the weights in a manner that doesn't interact
            # with the m/v parameters. This is equivalent to adding the square
            # of the weights to the loss with plain (non-momentum) SGD.
            if _do_use_weight_decay(param_name, self.weight_decay_rate,
                                    self.exclude_from_weight_decay):
                update += self.weight_decay_rate * param

            update_with_lr = self.learning_rate * update

            next_param = param - update_with_lr

            if distribute_ctx.has_strategy():
                # Handle DistributionStrategy case.
                if distribute_ctx.in_cross_replica_context():
                    raise RuntimeError(
                        "Use `_distributed_apply()` instead of "
                        "`apply_gradients()` in a cross-replica context.")

                assign_params = distribute_ctx.get_replica_context(
                ).merge_call(assign_vars,
                             args=((param, m, v), (next_param, next_m,
                                                   next_v)))
            else:
                assign_params = [
                    param.assign(next_param),
                    m.assign(next_m),
                    v.assign(next_v)
                ]
            assignments.extend(assign_params)

            if _need_centering(param_name, self.darknet_gn, self.upsample_gn):
                with tf.control_dependencies(assign_params):
                    param_identity = tf.identity(param)
                centering_op = _centering_weights(param, param_identity)
                assignments.extend(centering_op)  # _centering_weights returns a list

        if self.use_moving_avg:
            # Using tf.train.ExponentialMovingAverage makes the compiler produce
            # many executables and run "load executable" many times, so we
            # implement our own moving average here; we will switch back to
            # tf.train.ExponentialMovingAverage once that is fixed.
            assignments.extend(
                _create_moving_avg(grads_and_vars, self.moving_avg_decay))
        return tf.group(*assignments, name=name)
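The decoupled weight-decay comment above is exactly the AdamW idea; a scalar sketch of one update step in plain Python (all hyperparameter values are illustrative):

beta1, beta2, eps, lr, wd = 0.9, 0.999, 1e-8, 1e-3, 0.01
m = v = 0.0
param, grad, t = 0.5, 0.2, 1

m = beta1 * m + (1 - beta1) * grad
v = beta2 * v + (1 - beta2) * grad * grad
m_hat = m / (1 - beta1 ** t)   # bias correction, the "debiasing" above
v_hat = v / (1 - beta2 ** t)
update = m_hat / (v_hat ** 0.5 + eps) + wd * param  # decay added to the update,
param -= lr * update                                # not mixed into m and v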
Example 15
  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.  Default to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
      RuntimeError: If you should use `_distributed_apply()` instead.
    """
    # This is a default implementation of apply_gradients() that can be shared
    # by most optimizers.  It relies on the subclass implementing the following
    # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().

    # TODO(isaprykin): Get rid of `has_strategy()` check by
    # always calling _distributed_apply(), using the default distribution
    # as needed.
    if distribute_ctx.has_strategy():
      # Handle DistributionStrategy case.
      if distribute_ctx.in_cross_replica_context():
        raise RuntimeError("Use `_distributed_apply()` instead of "
                           "`apply_gradients()` in a cross-replica context.")

      grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
      return distribute_ctx.get_replica_context().merge_call(
          self._distributed_apply, args=(grads_and_vars, global_step, name))

    # No DistributionStrategy case.
    grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
    if not grads_and_vars:
      raise ValueError("No variables provided.")
    converted_grads_and_vars = []
    for g, v in grads_and_vars:
      if g is not None:
        try:
          # Convert the grad to Tensor or IndexedSlices if necessary.
          g = ops.convert_to_tensor_or_indexed_slices(g)
        except TypeError:
          raise TypeError(
              "Gradient must be convertible to a Tensor"
              " or IndexedSlices, or None: %s" % g)
        if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
          raise TypeError(
              "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
      p = _get_processor(v)
      converted_grads_and_vars.append((g, v, p))

    converted_grads_and_vars = tuple(converted_grads_and_vars)
    var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
    if not var_list:
      raise ValueError("No gradients provided for any variable: %s." %
                       ([str(v) for _, v, _ in converted_grads_and_vars],))
    with ops.init_scope():
      self._create_slots(var_list)
    update_ops = []
    with ops.name_scope(name, self._name) as name:
      self._prepare()
      for grad, var, processor in converted_grads_and_vars:
        if grad is None:
          continue
        # We colocate all ops created in _apply_dense or _apply_sparse
        # on the same device as the variable.
        # TODO(apassos): figure out how to get the variable name here.
        if context.executing_eagerly() or isinstance(
            var,
            resource_variable_ops.ResourceVariable) and not var._in_graph_mode:  # pylint: disable=protected-access
          scope_name = ""
        else:
          scope_name = var.op.name
        with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
          update_ops.append(processor.update_op(self, grad))
      if global_step is None:
        apply_updates = self._finish(update_ops, name)
      else:
        with ops.control_dependencies([self._finish(update_ops, "update")]):
          with ops.colocate_with(global_step):
            if isinstance(global_step, resource_variable_ops.ResourceVariable):
              # TODO(apassos): the implicit read in assign_add is slow; consider
              # making it less so.
              apply_updates = resource_variable_ops.assign_add_variable_op(
                  global_step.handle,
                  ops.convert_to_tensor(1, dtype=global_step.dtype),
                  name=name)
            else:
              apply_updates = state_ops.assign_add(global_step, 1, name=name)

      if not context.executing_eagerly():
        if isinstance(apply_updates, ops.Tensor):
          apply_updates = apply_updates.op
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if apply_updates not in train_op:
          train_op.append(apply_updates)

      return apply_updates
Example 16
  def call(self, inputs, training=None):
    if training is None:
      training = K.learning_phase()

    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.shape
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]

    scale, offset = self.gamma, self.beta


    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = tf_utils.constant_value(training)
    if training_value is not False:

      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      keep_dims = len(self.axis) > 1
      mean, variance = self._moments(
          math_ops.cast(inputs, inputs.dtype),
          reduction_axes,
          keep_dims=keep_dims)

      moving_mean = self.moving_mean
      moving_variance = self.moving_variance

      mean = tf_utils.smart_cond(training,
                                 lambda: mean,
                                 lambda: ops.convert_to_tensor(moving_mean))
      variance = tf_utils.smart_cond(
          training,
          lambda: variance,
          lambda: ops.convert_to_tensor(moving_variance))

      new_mean, new_variance = mean, variance

      if ops.executing_eagerly_outside_functions(
      ) and distribution_strategy_context.has_strategy():
        inputs_size = array_ops.size(inputs)
      else:
        inputs_size = None

      if distribution_strategy_context.in_cross_replica_context():
        strategy = distribution_strategy_context.get_strategy()

        def _do_update(var, value):
          """Compute the updates for mean and variance."""
          return strategy.extended.update(
              var,
              self._assign_moving_average, (value, self.momentum, inputs_size),
              group=False)
        # We need to unwrap the moving_mean or moving_variance in the case of
        # training being false to match the output of true_fn and false_fn
        # in the smart cond.
        def mean_update():
          true_branch = lambda: _do_update(self.moving_mean, new_mean)
          false_branch = lambda: strategy.unwrap(self.moving_mean)
          return tf_utils.smart_cond(training, true_branch, false_branch)

        def variance_update():
          return tf_utils.smart_cond(
              training, lambda: _do_update(self.moving_variance, new_variance),
              lambda: strategy.unwrap(self.moving_variance))
      else:
        def _do_update(var, value):
          """Compute the updates for mean and variance."""
          return self._assign_moving_average(var, value, self.momentum,
                                             inputs_size)


        def mean_update():
          true_branch = lambda: _do_update(self.moving_mean, new_mean)
          false_branch = lambda: self.moving_mean
          return tf_utils.smart_cond(training, true_branch, false_branch)

        def variance_update():
          true_branch = lambda: _do_update(self.moving_variance, new_variance)
          false_branch = lambda: self.moving_variance
          return tf_utils.smart_cond(training, true_branch, false_branch)

      self.add_update(mean_update, inputs=True)
      self.add_update(variance_update, inputs=True)

    else:
      mean, variance = self.moving_mean, self.moving_variance

    mean = math_ops.cast(mean, inputs.dtype)
    variance = math_ops.cast(variance, inputs.dtype)
    if offset is not None:
      offset = math_ops.cast(offset, inputs.dtype)
    if scale is not None:
      scale = math_ops.cast(scale, inputs.dtype)

    outputs = nn.batch_normalization(inputs, mean, variance,
                                     offset,
                                     scale,
                                     self.epsilon)

    return outputs
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """Apply gradients to variables.

        This is the second part of `minimize()`. It returns an `Operation` that
        applies gradients.

        Args:
          grads_and_vars: List of (gradient, variable) pairs as returned by
            `compute_gradients()`.
          global_step: Optional `Variable` to increment by one after the
            variables have been updated.
          name: Optional name for the returned operation.  Default to the
            name passed to the `Optimizer` constructor.

        Returns:
          An `Operation` that applies the specified gradients. If `global_step`
          was not None, that operation also increments `global_step`.

        Raises:
          TypeError: If `grads_and_vars` is malformed.
          ValueError: If none of the variables have gradients.
          RuntimeError: If you should use `_distributed_apply()` instead.
        """
        # This is a default implementation of apply_gradients() that can be shared
        # by most optimizers.  It relies on the subclass implementing the following
        # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().

        # TODO(isaprykin): Get rid of `has_strategy()` check by
        # always calling _distributed_apply(), using the default distribution
        # as needed.

        self.epochStart = self.params.epochStart
        self.params.epochStart = False

        if distribute_ctx.has_strategy():
          # Handle DistributionStrategy case.
          if distribute_ctx.in_cross_replica_context():
            raise RuntimeError("Use `_distributed_apply()` instead of "
                               "`apply_gradients()` in a cross-replica context.")

          grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
          return distribute_ctx.get_replica_context().merge_call(
              self._distributed_apply, args=(grads_and_vars, global_step, name))

        # No DistributionStrategy case.
        grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
        if not grads_and_vars:
          raise ValueError("No variables provided.")
        converted_grads_and_vars = []

        for g, v in grads_and_vars:
          if g is not None:
            try:
              # Convert the grad to Tensor or IndexedSlices if necessary.
              g = ops.convert_to_tensor_or_indexed_slices(g)
            except TypeError:
              raise TypeError(
                  "Gradient must be convertible to a Tensor"
                  " or IndexedSlices, or None: %s" % g)
            if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
              raise TypeError(
                  "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
          p = _get_processor(v)
          converted_grads_and_vars.append((g, v, p))

        converted_grads_and_vars = tuple(converted_grads_and_vars)
        grad_list = [g for g, v, _ in converted_grads_and_vars if g is not None]
        var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
        if not var_list:
          raise ValueError("No gradients provided for any variable: %s." %
                           ([str(v) for _, v, _ in converted_grads_and_vars],))
        with ops.init_scope():
          self._create_slots(var_list)

        # Key step: compute the updated (grad, var, processor, value) tuples.
        compGV = self.CDGrads(converted_grads_and_vars, global_step)

        update_ops = []
        with ops.name_scope(name, self._name, skip_on_eager=False) as name:
          self._prepare()

          for grad, var, processor, val in compGV:
            if grad is None:
              continue
            # We colocate all ops created in _apply_dense or _apply_sparse
            # on the same device as the variable.
            # TODO(apassos): figure out how to get the variable name here.
            if (context.executing_eagerly() or
                resource_variable_ops.is_resource_variable(var)
                and not var._in_graph_mode):  # pylint: disable=protected-access
              scope_name = ""
            else:
              scope_name = var.op.name
            with ops.name_scope(
                "update_" + scope_name,
                skip_on_eager=False), ops.colocate_with(var):

              if self.epochStart and False:  # branch intentionally disabled
                update_ops.append(self.set_slot(var, "epoch_var", val))

              update_ops.append(var.assign(val))

          if global_step is None:
            apply_updates = self._finish(update_ops, name)
          else:
            with ops.control_dependencies([self._finish(update_ops, "update")]):
              with ops.colocate_with(global_step):
                if isinstance(global_step,
                              resource_variable_ops.BaseResourceVariable):
                  apply_updates = resource_variable_ops.assign_add_variable_op(
                      global_step.handle,
                      ops.convert_to_tensor(1, dtype=global_step.dtype),
                      name=name)
                else:
                  apply_updates = state_ops.assign_add(global_step, 1, name=name)

          if not context.executing_eagerly():
            if isinstance(apply_updates, ops.Tensor):
              apply_updates = apply_updates.op
            train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
            if apply_updates not in train_op:
              train_op.append(apply_updates)

          if self.epochStart:
            self.epochStart = False
          return apply_updates
Example 18
def _create_keras_history_helper(tensors, processed_ops, created_layers):
  """Helper method for `create_keras_history`.

  Args:
    tensors: A structure of Tensors for which to create Keras metadata.
    processed_ops: Set. TensorFlow operations that have already been wrapped in
      `TensorFlowOpLayer` instances.
    created_layers: List. The `TensorFlowOpLayer` instances created.

  Returns:
    Tuple. First element is the updated set of TensorFlow Operations that
    have been wrapped in `TensorFlowOpLayer` instances. Second element is
    a list of the `TensorFlowOpLayer` instances created.
  """
  if ops.executing_eagerly_outside_functions():
    raise ValueError(
        '`create_keras_history` should only be called if eager is disabled!')
  # Import of `base_layer` needed in order to create `TensorFlowOpLayer`.
  # Cannot be imported at top because of circular dependencies.
  # TODO(omalleyt): Resolve circular dependency.
  from tensorflow.python.keras.engine import base_layer  # pylint: disable=g-import-not-at-top
  tensor_list = nest.flatten(tensors)
  sparse_ops = []
  ragged_tensors = []
  for tensor in tensor_list:
    if getattr(tensor, '_keras_history', None) is not None:
      continue
    if isinstance(
        tensor, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
      sparse_ops.append(tensor.op)
      continue
    if tf_utils.is_ragged(tensor):
      # Ragged tensors don't have an op property
      ragged_tensors.append(tensor)
      continue
    op = tensor.op  # The Op that created this Tensor.
    if op not in processed_ops:
      # Recursively set `_keras_history`.
      op_inputs = list(op.inputs)
      constants = {}
      layer_inputs = []
      for i, op_input in enumerate(op_inputs):
        if uses_keras_history(op_input):
          layer_inputs.append(op_input)
        else:
          # Treat any value not originating from a `keras.Input` as
          # a constant. Variables cannot be supported.
          ds_with_session = (
              distribution_strategy_context.in_cross_replica_context() and
              not ops.executing_eagerly_outside_functions())
          using_xla = control_flow_util.GraphOrParentsInXlaContext(
              ops.get_default_graph())
          if ds_with_session or using_xla or _UNSAFE_GRAPH_OP_LAYER_CREATION:
            # In Legacy Graph mode, evaluating here makes Session be
            # configured improperly. The downside of this is that saving
            # via `get_config` breaks, but SavedModel still works.
            constants[i] = op_input
          else:
            with ops.init_scope():
              constants[i] = backend.function([], op_input)([])
      layer_inputs = unnest_if_single_tensor(layer_inputs)
      processed_ops, created_layers = _create_keras_history_helper(
          layer_inputs, processed_ops, created_layers)
      name = op.name
      node_def = op.node_def.SerializeToString()
      op_layer = base_layer.TensorFlowOpLayer(
          node_def, constants=constants, name=name)
      created_layers.append(op_layer)
      op_layer._set_connectivity_metadata(  # pylint: disable=protected-access
          args=(layer_inputs,),
          kwargs={},
          outputs=op.outputs)
      processed_ops.update([op])
  if sparse_ops or ragged_tensors:
    lambda_example = """
    weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
    output = tf.keras.layers.Lambda(weights_mult)(input)
    """
    raise ValueError(
        'Tensorflow ops that generate ragged or sparse tensor '
        'outputs are currently not supported by Keras automatic '
        'op wrapping. Please wrap these ops in a Lambda layer: '
        '\n\n```\n{example}\n```\n'
        'Sparse ops encountered: {sparse_ops}\n'
        'Ragged tensors encountered: {ragged_tensors}\n'.format(
            example=lambda_example,
            sparse_ops=str(sparse_ops),
            ragged_tensors=str(ragged_tensors)))
  return processed_ops, created_layers
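The `Lambda` workaround recommended by that error message, in runnable form (the dense `weights` tensor here is illustrative):

import tensorflow as tf

weights = tf.random.normal([4, 3])
inp = tf.keras.Input(shape=(4,), sparse=True)
weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
out = tf.keras.layers.Lambda(weights_mult)(inp)
model = tf.keras.Model(inp, out)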
Example 19
 def apply_gradients(self, grads_and_vars, name=None):
   if distribution_strategy_context.in_cross_replica_context():
     raise ValueError('apply_gradients() must be called in a replica context.')
   grads_and_vars = tuple(grads_and_vars)
   return distribution_strategy_context.get_replica_context().merge_call(
       self._apply_gradients_cross_replica, args=(grads_and_vars, name))
 def replica_fn(input_tensor):
   # Within `replica_fn`, it has to be in a replica context.
   self.assertFalse(
       distribution_strategy_context.in_cross_replica_context())
   return input_tensor + v, input_tensor - v
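A minimal sketch of entering a replica context before calling such an `apply_gradients`, via the public `Strategy.run`:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  var = tf.Variable(1.0)
  opt = tf.keras.optimizers.SGD(0.1)

def step():
  with tf.GradientTape() as tape:
    loss = var * var
  grads = tape.gradient(loss, [var])
  opt.apply_gradients(zip(grads, [var]))  # now runs in a replica context

strategy.run(step)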
Example 21
  def call(self, inputs, training=None):
    if training is None:
      training = K.learning_phase()

    if self.virtual_batch_size is not None:
      # Virtual batches (aka ghost batches) can be simulated by reshaping the
      # Tensor and reusing the existing batch norm implementation
      original_shape = [-1] + inputs.shape.as_list()[1:]
      expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

      # Will cause errors if virtual_batch_size does not divide the batch size
      inputs = array_ops.reshape(inputs, expanded_shape)

      def undo_virtual_batching(outputs):
        outputs = array_ops.reshape(outputs, original_shape)
        return outputs

    if self.fused:
      outputs = self._fused_batch_norm(inputs, training=training)
      if self.virtual_batch_size is not None:
        # Currently never reaches here since fused_batch_norm does not support
        # virtual batching
        outputs = undo_virtual_batching(outputs)
      return outputs

    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.shape
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]
    if self.virtual_batch_size is not None:
      del reduction_axes[1]     # Do not reduce along virtual batch dim

    # Broadcasting only necessary for single-axis batch norm where the axis is
    # not the last dimension
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value
    def _broadcast(v):
      if (v is not None and len(v.shape) != ndims and
          reduction_axes != list(range(ndims - 1))):
        return array_ops.reshape(v, broadcast_shape)
      return v

    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    def _compose_transforms(scale, offset, then_scale, then_offset):
      if then_scale is not None:
        scale *= then_scale
        offset *= then_scale
      if then_offset is not None:
        offset += then_offset
      return (scale, offset)

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = tf_utils.constant_value(training)
    if training_value is not False:
      if self.adjustment:
        adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
        # Adjust only during training.
        adj_scale = tf_utils.smart_cond(training,
                                        lambda: adj_scale,
                                        lambda: array_ops.ones_like(adj_scale))
        adj_bias = tf_utils.smart_cond(training,
                                       lambda: adj_bias,
                                       lambda: array_ops.zeros_like(adj_bias))
        scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)

      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
      mean, variance = self._moments(
          math_ops.cast(inputs, self._param_dtype),
          reduction_axes,
          keep_dims=keep_dims)

      moving_mean = self.moving_mean
      moving_variance = self.moving_variance

      mean = tf_utils.smart_cond(training,
                                 lambda: mean,
                                 lambda: moving_mean)
      variance = tf_utils.smart_cond(training,
                                     lambda: variance,
                                     lambda: moving_variance)

      if self.virtual_batch_size is not None:
        # This isn't strictly correct since in ghost batch norm, you are
        # supposed to sequentially update the moving_mean and moving_variance
        # with each sub-batch. However, since the moving statistics are only
        # used during evaluation, it is more efficient to just update in one
        # step and should not make a significant difference in the result.
        new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
        new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True)
      else:
        new_mean, new_variance = mean, variance

      if self.renorm:
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            new_mean, new_variance, training)
        # When training, the normalized values (say, x) will be transformed as
        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
        # = x * (r * gamma) + (d * gamma + beta) with renorm.
        r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
        d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
        scale, offset = _compose_transforms(r, d, scale, offset)

      if distribution_strategy_context.in_cross_replica_context():
        strategy = distribution_strategy_context.get_strategy()

        def _do_update(var, value):
          """Compute the updates for mean and variance."""
          return strategy.extended.update(
              var, self._assign_moving_average, (value, self.momentum),
              group=False)
        # We need to unwrap the moving_mean or moving_variance in the case of
        # training being false to match the output of true_fn and false_fn
        # in the smart cond.
        def mean_update():
          true_branch = lambda: _do_update(self.moving_mean, new_mean)
          false_branch = lambda: strategy.unwrap(self.moving_mean)
          return tf_utils.smart_cond(training, true_branch, false_branch)

        def variance_update():
          return tf_utils.smart_cond(
              training, lambda: _do_update(self.moving_variance, new_variance),
              lambda: strategy.unwrap(self.moving_variance))
      else:
        def _do_update(var, value):
          """Compute the updates for mean and variance."""
          return self._assign_moving_average(var, value, self.momentum)

        def mean_update():
          true_branch = lambda: _do_update(self.moving_mean, new_mean)
          false_branch = lambda: self.moving_mean
          return tf_utils.smart_cond(training, true_branch, false_branch)

        def variance_update():
          true_branch = lambda: _do_update(self.moving_variance, new_variance)
          false_branch = lambda: self.moving_variance
          return tf_utils.smart_cond(training, true_branch, false_branch)

      self.add_update(mean_update, inputs=True)
      self.add_update(variance_update, inputs=True)

    else:
      mean, variance = self.moving_mean, self.moving_variance

    mean = math_ops.cast(mean, inputs.dtype)
    variance = math_ops.cast(variance, inputs.dtype)
    if offset is not None:
      offset = math_ops.cast(offset, inputs.dtype)
    if scale is not None:
      scale = math_ops.cast(scale, inputs.dtype)
    # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
    # math in float16 hurts validation accuracy of popular models like resnet.
    outputs = nn.batch_normalization(inputs,
                                     _broadcast(mean),
                                     _broadcast(variance),
                                     offset,
                                     scale,
                                     self.epsilon)
    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    if self.virtual_batch_size is not None:
      outputs = undo_virtual_batching(outputs)
    return outputs
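A quick standalone check of the virtual-batch ("ghost batch") reshape trick used above, with illustrative shapes:

import tensorflow as tf

x = tf.random.normal([8, 4])  # batch of 8, 4 features
virtual_batch_size = 2
expanded = tf.reshape(x, [virtual_batch_size, -1] + x.shape.as_list()[1:])
# expanded.shape == [2, 4, 4]: statistics get computed per ghost batch,
# then the output is reshaped back to the original layout.
restored = tf.reshape(expanded, [-1] + x.shape.as_list()[1:])
assert restored.shape == x.shape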
  def testVariableCaching(self):
    self.assertFalse(distribution_strategy_context.in_cross_replica_context())
    with self.strategy.scope():
      self.assertTrue(distribution_strategy_context.in_cross_replica_context())
      v = variables.Variable(
          initial_value=1.,
          aggregation=variable_scope.VariableAggregation.ONLY_FIRST_REPLICA)

      # Test read value inside caching scope
      with distribute_utils.cache_variable_reads():
        v.read_value()  # Reads value 1.0
        v.assign(constant_op.constant(5.0))  # v changes to 5.0
        self.assertEqual(v.read_value(), 1.0)  # should be cached 1.0 value.

      # Reset v to 2.0
      v.assign(2.0)

      # Test convert to tensor value inside caching scope
      with distribute_utils.cache_variable_reads():
        t = v * 3.0
        self.assertEqual(t, 6.0)
        v.assign(3.0)
        t1 = v * 3.0
        self.assertEqual(t1, 6.0)  # should be cached 2.0 * 3.0 value.

      # Reset v to 1.0
      v.assign(1.0)

      # Verify caching scope inside tf.function
      @def_function.function
      def worker_fn():
        with distribute_utils.cache_variable_reads():
          def replica_fn():
            t = v.read_value()  # Reads value 1.0
            v.assign(constant_op.constant(5.0))  # v changes to 5.0
            t = v.read_value()  # should return 1.0
            return t  # Should be 1.0 instead of 5.0

          return self.strategy.run(replica_fn)

      result = self.coordinator.schedule(worker_fn)
      result = result.fetch()
      expected_result = 1.
      self.assertEqual(result, expected_result)

      # Verify that v.read_value works as expected outside of scope.
      v.assign(4.0)
      self.assertEqual(v.read_value(), 4.0)

      v.assign(constant_op.constant(2.0))  # v changes to 2.0
      # Check with scope outside of tf function and check that cache is reset
      @def_function.function
      def worker_fn1():
        def replica_fn():
          t = v.read_value()  # Reads value 2.0 ==> Should be cached
          v.assign(constant_op.constant(5.0))  # v changes to 5.0
          t = v.read_value()  # should return cached value 2.0
          return t  # Should be 2.0 instead of 5.0

        return self.strategy.run(replica_fn)

      with distribute_utils.cache_variable_reads():
        result = self.coordinator.schedule(worker_fn1)
      result = result.fetch()
      expected_result = 2.
      self.assertEqual(result, expected_result)

    # Verify scope nesting is not permitted.
    with self.assertRaises(ValueError):
      with distribute_utils.cache_variable_reads():
        with distribute_utils.cache_variable_reads():
          v.read_value()
Example 23
    def apply_gradients(self,
                        grads_and_vars,
                        name=None,
                        experimental_aggregate_gradients=True):
        """Apply gradients to variables.

    Only the last two lines are different from optimizer_v2.OptimizerV2.

    Args:
      grads_and_vars: List of (gradient, variable) pairs.
      name: Optional name for the returned operation. Default to the name passed
        to the `Optimizer` constructor.
      experimental_aggregate_gradients: Whether to sum gradients from different
        replicas in the presence of `tf.distribute.Strategy`. If False, it is
        the user's responsibility to aggregate the gradients. Defaults to True.

    Returns:
      An `Operation` that applies the specified gradients. The `iterations`
      will be automatically increased by 1.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
      RuntimeError: If called in cross-replica context.
    """
        # pylint: disable=protected-access
        grads_and_vars = optimizer_v2._filter_grads(grads_and_vars)
        # pylint: enable=protected-access
        var_list = [v for (_, v) in grads_and_vars]

        with backend.name_scope(self._name):
            # Create iteration if necessary.
            with ops.init_scope():
                self._create_all_weights(var_list)

            if not grads_and_vars:
                # Distribution strategy does not support reducing an empty list of
                # gradients
                return control_flow_ops.no_op()

            if distribute_ctx.in_cross_replica_context():
                raise RuntimeError(
                    "`apply_gradients() cannot be called in cross-replica context. "
                    "Use `tf.distribute.Strategy.run` to enter replica "
                    "context.")

            strategy = distribute_ctx.get_strategy()
            if (not experimental_aggregate_gradients and strategy
                    and isinstance(
                        strategy.extended, parameter_server_strategy.
                        ParameterServerStrategyExtended)):
                raise NotImplementedError(
                    "`experimental_aggregate_gradients=False is not supported for "
                    "ParameterServerStrategy and CentralStorageStrategy")

            apply_state = self._prepare(var_list)
            if experimental_aggregate_gradients:
                reduced_grads = self._aggregate_gradients(grads_and_vars)
                var_list = [v for _, v in grads_and_vars]
                grads_and_vars = list(zip(reduced_grads, var_list))

            self._distributed_apply(None, grads_and_vars, name, apply_state)
            return self._iterations.assign_add(1, read_value=False)
Example 24
 def _as_graph_element(self):
   # pylint: disable=protected-access
   with ds_context.enter_or_assert_strategy(self._distribute_strategy):
     if ds_context.in_cross_replica_context():
       return ops.convert_to_tensor(self._get_cross_replica())
   return self._get()._as_graph_element()
Example 25
  def apply_gradients(self,
                      grads_and_vars,
                      name=None,
                      experimental_aggregate_gradients=True):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    The method sums gradients from all replicas in the presence of
    `tf.distribute.Strategy` by default. You can aggregate gradients yourself by
    passing `experimental_aggregate_gradients=False`.

    Example:

    ```python
    grads = tape.gradient(loss, vars)
    grads = tf.distribute.get_replica_context().all_reduce('sum', grads)
    # Processing aggregated gradients.
    optimizer.apply_gradients(zip(grads, vars),
        experimental_aggregate_gradients=False)

    ```

    Args:
      grads_and_vars: List of (gradient, variable) pairs.
      name: Optional name for the returned operation. Default to the name passed
        to the `Optimizer` constructor.
      experimental_aggregate_gradients: Whether to sum gradients from different
        replicas in the presence of `tf.distribute.Strategy`. If False, it is
        the user's responsibility to aggregate the gradients. Defaults to True.

    Returns:
      An `Operation` that applies the specified gradients. The `iterations`
      will be automatically increased by 1.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
    """
    grads_and_vars = _filter_grads(grads_and_vars)
    var_list = [v for (_, v) in grads_and_vars]

    with backend.name_scope(self._name):
      # Create iteration if necessary.
      with ops.init_scope():
        _ = self.iterations
        self._create_hypers()
        self._create_slots(var_list)

      if not grads_and_vars:
        # Distribution strategy does not support reducing an empty list of
        # gradients
        return control_flow_ops.no_op()

      if distribute_ctx.in_cross_replica_context():
        raise RuntimeError(
            "`apply_gradients() cannot be called in cross-replica context. "
            "Use `tf.distribute.Strategy.run` to enter replica "
            "context.")

      strategy = distribute_ctx.get_strategy()
      if (not experimental_aggregate_gradients and strategy and isinstance(
          strategy.extended,
          parameter_server_strategy.ParameterServerStrategyExtended)):
        raise NotImplementedError(
            "`experimental_aggregate_gradients=False is not supported for "
            "ParameterServerStrategy and CentralStorageStrategy")

      apply_state = self._prepare(var_list)
      if experimental_aggregate_gradients:
        reduced_grads = self._aggregate_gradients(grads_and_vars)
        var_list = [v for _, v in grads_and_vars]
        grads_and_vars = list(zip(reduced_grads, var_list))
      return distribute_ctx.get_replica_context().merge_call(
          functools.partial(self._distributed_apply, apply_state=apply_state),
          args=(grads_and_vars,),
          kwargs={
              "name": name,
          })
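
A hedged end-to-end sketch of the `experimental_aggregate_gradients=False` path documented above; the model, optimizer, and data are hypothetical placeholders.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

@tf.function
def train_step(x, y):

    def step_fn(x, y):
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean(tf.square(model(x) - y))
        grads = tape.gradient(loss, model.trainable_variables)
        # Sum gradients across replicas ourselves, then tell the optimizer
        # not to aggregate them a second time.
        grads = tf.distribute.get_replica_context().all_reduce('sum', grads)
        optimizer.apply_gradients(
            zip(grads, model.trainable_variables),
            experimental_aggregate_gradients=False)
        return loss

    return strategy.run(step_fn, args=(x, y))

train_step(tf.ones((8, 4)), tf.ones((8, 1)))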
Example n. 26
def _create_keras_history_helper(tensors, processed_ops, created_layers):
    """Helper method for `create_keras_history`.

    Arguments:
      tensors: A structure of Tensors for which to create Keras metadata.
      processed_ops: Set. TensorFlow operations that have already been wrapped
        in `TensorFlowOpLayer` instances.
      created_layers: List. The `TensorFlowOpLayer` instances created.

    Returns:
      Tuple. First element is the updated set of TensorFlow Operations that
      have been wrapped in `TensorFlowOpLayer` instances. Second element is
      a list of the `TensorFlowOpLayer` instances created.
    """
    # Import of `base_layer` needed in order to create `TensorFlowOpLayer`.
    # Cannot be imported at top because of circular dependencies.
    # TODO(omalleyt): Resolve circular dependency.
    from tensorflow.python.keras.engine import base_layer  # pylint: disable=g-import-not-at-top
    tensor_list = nest.flatten(tensors)
    for tensor in tensor_list:
        if getattr(tensor, '_keras_history', None) is not None:
            continue
        op = tensor.op  # The Op that created this Tensor.
        if op not in processed_ops:
            if op.type.startswith('Sparse'):
                lambda_example = """
        weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
        output = tf.keras.layers.Lambda(weights_mult)(input)
        """
                raise ValueError(
                    'Sparse ops are not supported with functional models with '
                    'built-in layer wrapping. Please wrap the sparse ops in a '
                    'Lambda layer like:\n{lambda_example}\n'.format(
                        lambda_example=lambda_example))

            # Recursively set `_keras_history`.
            op_inputs = list(op.inputs)
            constants = {}
            layer_inputs = []
            for i, op_input in enumerate(op_inputs):
                if uses_keras_history(op_input):
                    layer_inputs.append(op_input)
                else:
                    # Treat any value not originating from a `keras.Input` as
                    # a constant. Variables cannot be supported.
                    ds_with_session = (
                        distribution_strategy_context.in_cross_replica_context(
                        ) and not ops.executing_eagerly_outside_functions())
                    using_xla = control_flow_util.GraphOrParentsInXlaContext(
                        ops.get_default_graph())
                    if ds_with_session or using_xla:
                        # In Legacy Graph mode, evaluating here makes Session be
                        # configured improperly. The downside of this is that saving
                        # via `get_config` breaks, but SavedModel still works.
                        constants[i] = op_input
                    else:
                        with ops.init_scope():
                            if ops.executing_eagerly_outside_functions():
                                constants[
                                    i] = backend.eval_in_eager_or_function(
                                        op_input)
                            else:
                                constants[i] = backend.function([],
                                                                op_input)([])
            layer_inputs = unnest_if_single_tensor(layer_inputs)
            processed_ops, created_layers = _create_keras_history_helper(
                layer_inputs, processed_ops, created_layers)
            name = op.name
            node_def = op.node_def.SerializeToString()
            op_layer = base_layer.TensorFlowOpLayer(node_def,
                                                    constants=constants,
                                                    name=name)
            created_layers.append(op_layer)
            op_layer._set_connectivity_metadata(  # pylint: disable=protected-access
                args=(layer_inputs, ),
                kwargs={},
                outputs=op.outputs)
            processed_ops.update([op])
    return processed_ops, created_layers
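
For reference, a minimal sketch of the Lambda workaround suggested by the error message above; `sparse_input` and `weights` are hypothetical names.

import tensorflow as tf

sparse_input = tf.keras.Input(shape=(10,), sparse=True)
weights = tf.random.normal((10, 4))

# Wrap the sparse op in a Lambda layer so the functional model does not
# have to build a TensorFlowOpLayer around it.
weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
output = tf.keras.layers.Lambda(weights_mult)(sparse_input)
model = tf.keras.Model(sparse_input, output)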
Example n. 27
    def _fused_batch_norm(self, inputs, training):
        """Returns the output of fused batch norm."""
        beta = self.beta if self.center else self._beta_const
        gamma = self.gamma if self.scale else self._gamma_const

        def _fused_batch_norm_training():
            return nn.fused_batch_norm(inputs,
                                       gamma,
                                       beta,
                                       epsilon=self.epsilon,
                                       data_format=self._data_format)

        def _fused_batch_norm_inference():
            return nn.fused_batch_norm(inputs,
                                       gamma,
                                       beta,
                                       mean=self.moving_mean,
                                       variance=self.moving_variance,
                                       epsilon=self.epsilon,
                                       is_training=False,
                                       data_format=self._data_format)

        output, mean, variance = tf_utils.smart_cond(
            training, _fused_batch_norm_training, _fused_batch_norm_inference)
        if not self._bessels_correction_test_only:
            # Remove Bessel's correction to be consistent with non-fused batch
            # norm. Note that fused batch norm computes the variance with
            # Bessel's correction applied.
            sample_size = math_ops.cast(
                array_ops.size(inputs) / array_ops.size(variance),
                variance.dtype)
            factor = (sample_size -
                      math_ops.cast(1.0, variance.dtype)) / sample_size
            variance *= factor

        training_value = tf_utils.constant_value(training)
        if training_value is None:
            momentum = tf_utils.smart_cond(training, lambda: self.momentum,
                                           lambda: 1.0)
        else:
            momentum = ops.convert_to_tensor(self.momentum)
        if training_value or training_value is None:
            if distribution_strategy_context.in_cross_replica_context():
                strategy = distribution_strategy_context.get_strategy()

                def mean_update():
                    return strategy.extended.update(
                        self.moving_mean, self._assign_moving_average,
                        (mean, self.momentum))

                def variance_update():
                    return strategy.extended.update(
                        self.moving_variance, self._assign_moving_average,
                        (variance, self.momentum))
            else:

                def mean_update():
                    return self._assign_moving_average(self.moving_mean, mean,
                                                       momentum)

                def variance_update():
                    return self._assign_moving_average(self.moving_variance,
                                                       variance, momentum)

            self.add_update(mean_update, inputs=True)
            self.add_update(variance_update, inputs=True)

        return output
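
A small numeric check (an illustration, not part of the snippet) of the correction factor used above: multiplying an unbiased variance by (n - 1) / n recovers the biased variance that non-fused batch norm reports.

import numpy as np

x = np.random.randn(1000).astype(np.float32)
unbiased = x.var(ddof=1)  # with Bessel's correction, as fused batch norm computes
n = x.size
biased = unbiased * (n - 1) / n  # correction removed
assert np.isclose(biased, x.var(ddof=0))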
Example n. 28
    def call(self, inputs, training=None):
        if training is None:
            training = K.learning_phase()

        in_eager_mode = context.executing_eagerly()
        if self.virtual_batch_size is not None:
            # Virtual batches (aka ghost batches) can be simulated by reshaping the
            # Tensor and reusing the existing batch norm implementation
            original_shape = [-1] + inputs.shape.as_list()[1:]
            expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

            # This reshape will raise an error if virtual_batch_size does not
            # evenly divide the batch size.
            inputs = array_ops.reshape(inputs, expanded_shape)

            def undo_virtual_batching(outputs):
                outputs = array_ops.reshape(outputs, original_shape)
                return outputs

        if self.fused:
            outputs = self._fused_batch_norm(inputs, training=training)
            if self.virtual_batch_size is not None:
                # This branch is currently unreachable, since fused_batch_norm
                # does not support virtual batching.
                outputs = undo_virtual_batching(outputs)
            return outputs

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.get_shape()
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]
        if self.virtual_batch_size is not None:
            del reduction_axes[1]  # Do not reduce along virtual batch dim

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

        def _broadcast(v):
            if (v is not None and len(v.get_shape()) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        def _compose_transforms(scale, offset, then_scale, then_offset):
            if then_scale is not None:
                scale *= then_scale
                offset *= then_scale
            if then_offset is not None:
                offset += then_offset
            return (scale, offset)

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = tf_utils.constant_value(training)
        if training_value is not False:
            if self.adjustment:
                adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
                # Adjust only during training.
                adj_scale = tf_utils.smart_cond(
                    training, lambda: adj_scale,
                    lambda: array_ops.ones_like(adj_scale))
                adj_bias = tf_utils.smart_cond(
                    training, lambda: adj_bias,
                    lambda: array_ops.zeros_like(adj_bias))
                scale, offset = _compose_transforms(adj_scale, adj_bias, scale,
                                                    offset)

            # Some of the computations here are not necessary when
            # training==False but training is not a constant. However, keeping
            # a single code path makes the code simpler.
            keep_dims = self.virtual_batch_size is not None or len(
                self.axis) > 1
            mean, variance = self._moments(inputs,
                                           reduction_axes,
                                           keep_dims=keep_dims)

            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            mean = tf_utils.smart_cond(training, lambda: mean,
                                       lambda: moving_mean)
            variance = tf_utils.smart_cond(training, lambda: variance,
                                           lambda: moving_variance)

            if self.virtual_batch_size is not None:
                # This isn't strictly correct since in ghost batch norm, you are
                # supposed to sequentially update the moving_mean and moving_variance
                # with each sub-batch. However, since the moving statistics are only
                # used during evaluation, it is more efficient to just update in one
                # step and should not make a significant difference in the result.
                new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
                new_variance = math_ops.reduce_mean(variance,
                                                    axis=1,
                                                    keepdims=True)
            else:
                new_mean, new_variance = mean, variance

            if self.renorm:
                r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                    new_mean, new_variance, training)
                # When training, the normalized values (say, x) will be transformed as
                # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
                # = x * (r * gamma) + (d * gamma + beta) with renorm.
                r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
                d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
                scale, offset = _compose_transforms(r, d, scale, offset)

            if distribution_strategy_context.in_cross_replica_context():
                strategy = distribution_strategy_context.get_strategy()

                def _do_update(var, value):
                    """Compute the updates for mean and variance."""
                    if in_eager_mode and not self.trainable:
                        return
                    return strategy.extended.update(
                        var,
                        self._assign_moving_average, (value, self.momentum),
                        group=False)

                # We need to unwrap the moving_mean or moving_variance when
                # training is False so that the outputs of true_fn and false_fn
                # in the smart cond match.
                mean_update = tf_utils.smart_cond(
                    training, lambda: _do_update(self.moving_mean, new_mean),
                    lambda: strategy.unwrap(self.moving_mean))
                variance_update = tf_utils.smart_cond(
                    training,
                    lambda: _do_update(self.moving_variance, new_variance),
                    lambda: strategy.unwrap(self.moving_variance))
            else:

                def _do_update(var, value):
                    """Compute the updates for mean and variance."""
                    if in_eager_mode and not self.trainable:
                        return
                    return self._assign_moving_average(var, value,
                                                       self.momentum)

                mean_update = tf_utils.smart_cond(
                    training, lambda: _do_update(self.moving_mean, new_mean),
                    lambda: self.moving_mean)
                variance_update = tf_utils.smart_cond(
                    training,
                    lambda: _do_update(self.moving_variance, new_variance),
                    lambda: self.moving_variance)
            if not context.executing_eagerly():
                self.add_update(mean_update, inputs=True)
                self.add_update(variance_update, inputs=True)

        else:
            mean, variance = self.moving_mean, self.moving_variance

        mean = math_ops.cast(mean, inputs.dtype)
        variance = math_ops.cast(variance, inputs.dtype)
        if offset is not None:
            offset = math_ops.cast(offset, inputs.dtype)
        outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                         _broadcast(variance), offset, scale,
                                         self.epsilon)
        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        if self.virtual_batch_size is not None:
            outputs = undo_virtual_batching(outputs)
        return outputs
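
A hedged sketch of the virtual (ghost) batching described above, assuming a batch of 32 split into sub-batches of 8; the shapes are illustrative only.

import tensorflow as tf

bn = tf.keras.layers.BatchNormalization(virtual_batch_size=8)
x = tf.random.normal((32, 16))  # 32 must be divisible by virtual_batch_size
y = bn(x, training=True)        # statistics are computed per ghost batch
assert y.shape == x.shape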
Example n. 29
  def apply_gradients(self, grads_and_vars, name=None):
    if distribution_strategy_context.in_cross_replica_context():
      raise ValueError('apply_gradients() must be called in a replica context.')
    return distribution_strategy_context.get_replica_context().merge_call(
        self._apply_gradients_cross_replica, args=(grads_and_vars, name))
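
A minimal sketch (under an assumed MirroredStrategy) of the replica-to-cross-replica hand-off that `merge_call` performs in the snippet above.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def merge_fn(strategy_arg, value):
    # Runs once, in cross-replica context, with the per-replica values merged.
    return strategy_arg.reduce(tf.distribute.ReduceOp.SUM, value, axis=None)

def replica_fn():
    ctx = tf.distribute.get_replica_context()
    return ctx.merge_call(merge_fn, args=(tf.constant(1.0),))

result = strategy.run(replica_fn)  # sums 1.0 across all replicas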