Example #1
  def _create_non_slot_variable(self, initial_value, name, colocate_with):
    """Add an extra variable, not associated with a slot."""
    # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
    eager = context.executing_eagerly()
    graph = None if eager else colocate_with.graph

    key = (name, graph)
    v = self._non_slot_dict.get(key, None)
    if v is None:
      self._maybe_initialize_checkpointable()
      distribution_strategy = distribute_lib.get_distribution_strategy()
      with distribution_strategy.colocate_vars_with(colocate_with):
        if eager:
          restored_initial_value = self._preload_simple_restoration(
              name=name, shape=None)
          if restored_initial_value is not None:
            initial_value = restored_initial_value
        v = variable_scope.variable(initial_value, name=name, trainable=False)
      # Restore this variable by name if necessary, but don't add a
      # Checkpointable dependency. Optimizers return the current graph's
      # non-slot variables from _checkpoint_dependencies explicitly rather
      # than unconditionally adding dependencies (since there may be multiple
      # non-slot variables with the same name in different graphs, trying to
      # save all of them would result in errors).
      self._handle_deferred_dependencies(name=name, checkpointable=v)
      self._non_slot_dict[key] = v

    return v
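A hedged sketch of a call site, based on how an Adam-style optimizer tracks its beta-power accumulators; the method body and hyperparameter names below are illustrative assumptions, not part of the example above.

  def _create_slots(self, var_list):
    # Create the beta accumulators as non-slot variables, colocated with
    # the first (lexically smallest-named) variable for determinism.
    first_var = min(var_list, key=lambda x: x.name)
    self._create_non_slot_variable(
        initial_value=self._beta1, name="beta1_power", colocate_with=first_var)
    self._create_non_slot_variable(
        initial_value=self._beta2, name="beta2_power", colocate_with=first_var)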
Example #2
def _assert_in_default_state(t):
    t.assertIs(distribute._default_tower_context,
               distribute.get_tower_context())
    t.assertIs(None, distribute.get_cross_tower_context())
    t.assertIs(distribute._default_distribution_strategy,
               distribute.get_distribution_strategy())
    t.assertFalse(distribute.has_distribution_strategy())
Example #3
 def merge_fn(dist, s):
   self.assertIs(distribute._default_distribution_strategy, dist)
   self.assertIs(None, distribute.get_tower_context())
   self.assertIs(dist, distribute.get_cross_tower_context())
   self.assertIs(dist, distribute.get_distribution_strategy())
   self.assertFalse(distribute.has_distribution_strategy())
   return "foo_" + s
Example #4
def create_slot(primary, val, name, colocate_with_primary=True):
  """Create a slot initialized to the given value.

  The type of the slot is determined by the given value.

  Args:
    primary: The primary `Variable` or `Tensor`.
    val: A `Tensor` specifying the initial value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean.  If True the slot is located
      on the same device as `primary`.

  Returns:
    A `Variable` object.
  """
  # Scope the slot name in the namespace of the primary variable.
  # Set "primary.op.name + '/' + name" as default name, so the scope name of
  # optimizer can be shared when reuse is True. Meanwhile when reuse is False
  # and the same name has been previously used, the scope name will add '_N'
  # as suffix for unique identifications.
  validate_shape = val.get_shape().is_fully_defined()
  if context.executing_eagerly():
    prefix = primary._shared_name  # pylint: disable=protected-access
  else:
    prefix = primary.op.name
  with variable_scope.variable_scope(None, prefix + "/" + name):
    if colocate_with_primary:
      distribution_strategy = distribute_lib.get_distribution_strategy()
      with distribution_strategy.colocate_vars_with(primary):
        return _create_slot_var(primary, val, "", validate_shape, None, None)
    else:
      return _create_slot_var(primary, val, "", validate_shape, None, None)
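A minimal usage sketch, assuming an existing variable `var` and the `variable_scope`/`init_ops` modules imported as in the surrounding code; the slot name "accumulator" is arbitrary.

  # Hypothetical usage: create a slot holding a copy of `var`'s initial
  # value, placed on the same device as `var`.
  var = variable_scope.get_variable(
      "w", shape=[10], initializer=init_ops.zeros_initializer())
  accum = create_slot(var, var.initialized_value(), "accumulator",
                      colocate_with_primary=True)
  # `accum` shares `var`'s shape and dtype and lives under the scope
  # "w/accumulator".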
Example #5
  def _assign_func(self, *args, **kwargs):
    f = kwargs.pop("f")
    if distribute_lib.get_cross_tower_context():
      update_device = distribute_lib.get_update_device()
      # We are calling update on the mirrored variable in cross tower context.
      if update_device is not None:
        # We are calling an assign function on the mirrored variable in cross
        # tower context.
        v = self.get(device=update_device)
        return f(v, *args, **kwargs)

      return distribute_lib.get_distribution_strategy().update(
          self, f, *args, **kwargs)
    else:
      _assert_tower_context()
      # We are calling an assign function on the mirrored variable in tower
      # context.
      # We reduce the value we want to assign/add/sub. More details about how we
      # handle the different use cases can be found in the _reduce method.
      # We call the function on each of the mirrored variables with the reduced
      # value.
      if self._aggregation == vs.VariableAggregation.NONE:
        raise ValueError("You must specify an aggregation method to update a "
                         "MirroredVariable in Tower Context.")

      def merge_fn(strategy, value, *other_args, **other_kwargs):
        return strategy.update(
            self, f,
            strategy.reduce(
                aggregation=self._aggregation, value=value, destinations=self),
            *other_args, **other_kwargs)

      return distribute_lib.get_tower_context().merge_call(merge_fn, *args,
                                                           **kwargs)
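The public assign methods are presumably thin wrappers that select the underlying state op through the `f` keyword consumed above; a sketch of that pattern (the exact wrapper bodies are an assumption):

  def assign_sub(self, *args, **kwargs):
    return self._assign_func(f=state_ops.assign_sub, *args, **kwargs)

  def assign_add(self, *args, **kwargs):
    return self._assign_func(f=state_ops.assign_add, *args, **kwargs)

  def assign(self, *args, **kwargs):
    return self._assign_func(f=state_ops.assign, *args, **kwargs)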
Example #6
 def run_fn():
   tower_context = distribute.get_tower_context()
   self.assertTrue(tower_context is not None)
   self.assertIs(None, distribute.get_cross_tower_context())
   self.assertTrue(distribute.has_distribution_strategy())
   self.assertIs(dist, distribute.get_distribution_strategy())
   self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
   self.assertEqual("bar", variable_scope.variable(1.0, name="bar"))
Example #7
 def testScope(self):
   _assert_in_default_state(self)
   dist = _TestStrategy()
   with dist.scope():
     self.assertIs(None, distribute.get_tower_context())
     self.assertIs(dist, distribute.get_cross_tower_context())
     self.assertTrue(distribute.has_distribution_strategy())
     self.assertIs(dist, distribute.get_distribution_strategy())
     self.assertEqual("baz", variable_scope.variable(1.0, name="baz"))
   _assert_in_default_state(self)
Example #8
 def run_fn():
   tower_context = distribute.get_tower_context()
   self.assertTrue(tower_context is not None)
   self.assertIs(None, distribute.get_cross_tower_context())
   self.assertTrue(distribute.has_distribution_strategy())
   self.assertIs(dist, distribute.get_distribution_strategy())
   self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
   expected_value = _get_test_variable(
       "bar", variable_scope.VariableSynchronization.AUTO,
       variable_scope.VariableAggregation.NONE)
   self.assertDictEqual(expected_value,
                        variable_scope.variable(1.0, name="bar"))
Example #9
 def testScope(self):
   _assert_in_default_state(self)
   dist = _TestStrategy()
   with dist.scope():
     self.assertIs(None, distribute.get_tower_context())
     self.assertIs(dist, distribute.get_cross_tower_context())
     self.assertTrue(distribute.has_distribution_strategy())
     self.assertIs(dist, distribute.get_distribution_strategy())
     expected_value = _get_test_variable(
         "baz", variable_scope.VariableSynchronization.AUTO,
         variable_scope.VariableAggregation.NONE)
     self.assertDictEqual(expected_value,
                          variable_scope.variable(1.0, name="baz"))
   _assert_in_default_state(self)
Example #10
  def set_last_step_output(self, name, output,
                           aggregation=variables_lib.VariableAggregation.NONE):
    """Set `output` with `name` to be outputted from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be outputted with `name`. See below for
        actual types supported.
      aggregation: Method used to aggregate outputs from multiple towers.
        Required if `set_last_step_output` is called in a tower context;
        optional in cross-tower context. When present, the outputs from all
        the towers are aggregated using the current distribution strategy's
        `reduce` method, so the type of `output` must be supported by the
        corresponding `reduce` method. For example, if using MirroredStrategy
        and aggregation is set, `output` must be a `PerDevice` value.
        The aggregation method is also recorded in the dictionary
        `_last_step_outputs_aggregations` so the outputs can later be
        interpreted as already reduced or not.

    """
    if distribute_lib.get_cross_tower_context():
      self._last_step_outputs_aggregations[name] = aggregation
      if aggregation is variables_lib.VariableAggregation.NONE:
        self._last_step_outputs[name] = output
      else:
        distribution = distribute_lib.get_distribution_strategy()
        self._last_step_outputs[name] = distribution.reduce(
            aggregation, output, destinations="/device:CPU:0")
    else:
      assert aggregation is not variables_lib.VariableAggregation.NONE
      def merge_fn(distribution, value):
        self._last_step_outputs[name] = distribution.reduce(
            aggregation, value, destinations="/device:CPU:0")
        # Setting this inside the `merge_fn` because all towers share the same
        # context object, so it's more robust to set it only once (even if all
        # the towers are trying to set the same value).
        self._last_step_outputs_aggregations[name] = aggregation
      distribute_lib.get_tower_context().merge_call(merge_fn, output)
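A hedged sketch of a step function that records a per-tower loss through this context object; `ctx`, `compute_loss`, and the surrounding strategy setup are assumptions for illustration.

  def step_fn(ctx, inputs):
    # Runs in tower context, so an aggregation method is required; MEAN
    # averages the per-tower losses via the strategy's `reduce`.
    loss = compute_loss(inputs)
    ctx.set_last_step_output(
        name="mean_loss", output=loss,
        aggregation=variables_lib.VariableAggregation.MEAN)
    return loss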
Example #11
def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides given variable's initialization op.

  Sets the variable's initializer to an assign op that initializes it from the
  tensor's value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]

    # TODO(priyag, allenl): Use `SaveableObject.restore` instead here.
    if resource_variable_ops.is_resource_variable(variable):
      init_op = variable.assign(restore_op, read_value=False)
      # TODO(priyag): Remove this when using `SaveableObject.restore` instead.
      if hasattr(init_op, "_index"):
        init_op = distribute_lib.get_distribution_strategy().group(init_op)
    else:
      init_op = state_ops.assign(variable, restore_op)

    # pylint:disable=protected-access
    variable._initializer_op = init_op
    restore_op.set_shape(variable.shape)
    variable._initial_value = restore_op
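This helper is internal; the usual public entry point that ends up here is `tf.train.init_from_checkpoint`, roughly as below (the checkpoint path and scope names are placeholders):

  # Hypothetical call site: initialize variables under "new_scope/" from
  # values saved under "old_scope/". Each matched variable's initializer
  # is overridden via _set_checkpoint_initializer.
  tf.train.init_from_checkpoint("/tmp/model.ckpt",
                                {"old_scope/": "new_scope/"})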
Example #12
 def read_value(self):
   return distribute_lib.get_distribution_strategy().read_var(self)
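A brief hedged sketch: in cross-tower context, reading a tower-local variable dispatches to the strategy's `read_var`, which may aggregate across towers; `strategy` and `metric_var` below are assumptions.

  # Hypothetical: `metric_var` is a tower-local variable created under a
  # distribution strategy scope (e.g. with SUM aggregation).
  with strategy.scope():
    total = metric_var.read_value()  # -> strategy.read_var(metric_var)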
Example #13
  def build(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
    if not input_shape.ndims:
      raise ValueError('Input has undefined rank:', input_shape)
    ndims = len(input_shape)

    # Convert axis to list and resolve negatives
    if isinstance(self.axis, int):
      self.axis = [self.axis]

    if not isinstance(self.axis, list):
      raise TypeError('axis must be int or list, type given: %s'
                      % type(self.axis))

    for idx, x in enumerate(self.axis):
      if x < 0:
        self.axis[idx] = ndims + x

    # Validate axes
    for x in self.axis:
      if x < 0 or x >= ndims:
        raise ValueError('Invalid axis: %d' % x)
    if len(self.axis) != len(set(self.axis)):
      raise ValueError('Duplicate axis: %s' % self.axis)

    if self.virtual_batch_size is not None:
      if self.virtual_batch_size <= 0:
        raise ValueError('virtual_batch_size must be a positive integer that '
                         'divides the true batch size of the input Tensor')
      # If using virtual batches, the first dimension must be the batch
      # dimension and cannot be the batch norm axis
      if 0 in self.axis:
        raise ValueError('When using virtual_batch_size, the batch dimension '
                         'must be 0 and thus axis cannot include 0')
      if self.adjustment is not None:
        raise ValueError('When using virtual_batch_size, adjustment cannot '
                         'be specified')

    if self.fused:
      # Currently fused batch norm doesn't support renorm. It also only supports
      # an input tensor of rank 4 and a channel dimension on axis 1 or 3.
      # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
      # output back to its original shape accordingly.
      self.fused = (not self.renorm and
                    ndims == 4 and
                    self.axis in [[1], [3]] and
                    self.virtual_batch_size is None and
                    self.adjustment is None)
      # TODO(chrisying): fused batch norm is currently not supported for
      # multi-axis batch norm and by extension virtual batches. In some cases,
      # it might be possible to use fused batch norm but would require reshaping
      # the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is
      # particularly tricky. A compromise might be to just support the most
      # common use case (turning 5D w/ virtual batch to NCHW)

    if self.fused:
      if self.axis == [1]:
        self._data_format = 'NCHW'
      elif self.axis == [3]:
        self._data_format = 'NHWC'
      else:
        raise ValueError('Unsupported axis, fused batch norm only supports '
                         'axis == [1] or axis == [3]')

    # Promote fp16/bfloat16 batch norm parameters to fp32.
    if self.dtype == dtypes.float16 or self.dtype == dtypes.bfloat16:
      param_dtype = dtypes.float32
    else:
      param_dtype = self.dtype or dtypes.float32

    axis_to_dim = {x: input_shape[x].value for x in self.axis}
    for x in axis_to_dim:
      if axis_to_dim[x] is None:
        raise ValueError('Input has undefined `axis` dimension. Input shape: ',
                         input_shape)
    self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)

    if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
      # Single axis batch norm (most common/default use-case)
      param_shape = (list(axis_to_dim.values())[0],)
    else:
      # Parameter shape is the original shape but with 1 in all non-axis dims
      param_shape = [axis_to_dim[i] if i in axis_to_dim
                     else 1 for i in range(ndims)]
      if self.virtual_batch_size is not None:
        # When using virtual batches, add an extra dim at index 1
        param_shape.insert(1, 1)
        for idx, x in enumerate(self.axis):
          self.axis[idx] = x + 1      # Account for added dimension

    if self.scale:
      self.gamma = self.add_weight(
          name='gamma',
          shape=param_shape,
          dtype=param_dtype,
          initializer=self.gamma_initializer,
          regularizer=self.gamma_regularizer,
          constraint=self.gamma_constraint,
          trainable=True)
    else:
      self.gamma = None
      if self.fused:
        self._gamma_const = array_ops.constant(
            1.0, dtype=param_dtype, shape=param_shape)

    if self.center:
      self.beta = self.add_weight(
          name='beta',
          shape=param_shape,
          dtype=param_dtype,
          initializer=self.beta_initializer,
          regularizer=self.beta_regularizer,
          constraint=self.beta_constraint,
          trainable=True)
    else:
      self.beta = None
      if self.fused:
        self._beta_const = array_ops.constant(
            0.0, dtype=param_dtype, shape=param_shape)

    try:
      # Disable variable partitioning when creating the moving mean and variance
      if hasattr(self, '_scope') and self._scope:
        partitioner = self._scope.partitioner
        self._scope.set_partitioner(None)
      else:
        partitioner = None
      self.moving_mean = self._add_tower_local_variable(
          name='moving_mean',
          shape=param_shape,
          dtype=param_dtype,
          initializer=self.moving_mean_initializer,
          trainable=False)

      self.moving_variance = self._add_tower_local_variable(
          name='moving_variance',
          shape=param_shape,
          dtype=param_dtype,
          initializer=self.moving_variance_initializer,
          trainable=False)

      if self.renorm:
        # Create variables to maintain the moving mean and standard deviation.
        # These are used in training and thus are different from the moving
        # averages above. The renorm variables are colocated with moving_mean
        # and moving_variance.
        # NOTE: below, the outer `with device` block causes the current device
        # stack to be cleared. The nested ones use a `lambda` to set the desired
        # device and ignore any devices that may be set by the custom getter.
        def _renorm_variable(name, shape):
          var = self._add_tower_local_variable(
              name=name,
              shape=shape,
              dtype=param_dtype,
              initializer=init_ops.zeros_initializer(),
              trainable=False)
          return var

        with distribute_lib.get_distribution_strategy().colocate_vars_with(
            self.moving_mean):
          self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
          self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
        # We initialize renorm_stddev to 0, and maintain the (0-initialized)
        # renorm_stddev_weight. This allows us to (1) mix the average
        # stddev with the minibatch stddev early in training, and (2) compute
        # the unbiased average stddev by dividing renorm_stddev by the weight.
        with distribute_lib.get_distribution_strategy().colocate_vars_with(
            self.moving_variance):
          self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
          self.renorm_stddev_weight = _renorm_variable('renorm_stddev_weight',
                                                       ())
    finally:
      if partitioner:
        self._scope.set_partitioner(partitioner)
    self.built = True
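A hedged construction sketch for the `build` logic above: a rank-4 input with channel axis 1 takes the fused NCHW path, while setting `virtual_batch_size` would force the unfused path (layer and placeholder APIs per this TF 1.x vintage).

  # Hypothetical usage: NCHW input, channel axis 1 -> fused batch norm.
  bn = tf.layers.BatchNormalization(axis=1)
  x = tf.placeholder(tf.float32, shape=[8, 16, 32, 32])  # N, C, H, W
  y = bn(x)  # build() runs here; self._data_format becomes 'NCHW'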
Example #14
File: values.py  Project: Utsal20/poGANmon
 def tensor():
   return distribute_lib.get_distribution_strategy().fetch(
       tower_local_variable)
Example #15
  def compute_gradients(self, loss, var_list=None,
                        gate_gradients=GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`.  It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable".  Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking
        no arguments which returns the value to minimize. When eager execution
        is enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`.  Defaults to the list of variables collected in the graph
        under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients.  Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything other than `Variable` objects.
      ValueError: If some arguments are invalid.
      RuntimeError: If called with eager execution enabled and `loss` is
        not callable.

    @compatibility(eager)
    When eager execution is enabled, `gate_gradients`, `aggregation_method`,
    and `colocate_gradients_with_ops` are ignored.
    @end_compatibility
    """
    if callable(loss):
      with backprop.GradientTape() as tape:
        if var_list is not None:
          tape.watch(var_list)
        loss_value = loss()

        # Scale loss if using a "mean" loss reduction and multiple towers.
        # Have to be careful to call distribute_lib.get_loss_reduction()
        # *after* loss() is evaluated, so we know what loss reduction it uses.
        # TODO(josh11b): Test that we handle weight decay in a reasonable way.
        if distribute_lib.get_loss_reduction() == "mean":
          num_towers = distribute_lib.get_distribution_strategy().num_towers
          if num_towers > 1:
            loss_value *= (1. / num_towers)

      if var_list is None:
        var_list = tape.watched_variables()
      grads = tape.gradient(loss_value, var_list, grad_loss)
      return list(zip(grads, var_list))

    # Non-callable/Tensor loss case
    if context.executing_eagerly():
      raise RuntimeError(
          "`loss` passed to Optimizer.compute_gradients should "
          "be a function when eager execution is enabled.")

    # Scale loss if using a "mean" loss reduction and multiple towers.
    if distribute_lib.get_loss_reduction() == "mean":
      num_towers = distribute_lib.get_distribution_strategy().num_towers
      if num_towers > 1:
        loss *= (1. / num_towers)

    if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
                              Optimizer.GATE_GRAPH]:
      raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
                       "Optimizer.GATE_OP, Optimizer.GATE_GRAPH.  Not %s" %
                       gate_gradients)
    self._assert_valid_dtypes([loss])
    if grad_loss is not None:
      self._assert_valid_dtypes([grad_loss])
    if var_list is None:
      var_list = (
          variables.trainable_variables() +
          ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    else:
      var_list = nest.flatten(var_list)
    # pylint: disable=protected-access
    var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
    # pylint: enable=protected-access
    processors = [_get_processor(v) for v in var_list]
    if not var_list:
      raise ValueError("No variables to optimize.")
    var_refs = [p.target() for p in processors]
    grads = gradients.gradients(
        loss, var_refs, grad_ys=grad_loss,
        gate_gradients=(gate_gradients == Optimizer.GATE_OP),
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops)
    if gate_gradients == Optimizer.GATE_GRAPH:
      grads = control_flow_ops.tuple(grads)
    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes(
        [v for g, v in grads_and_vars
         if g is not None and v.dtype != dtypes.resource])
    return grads_and_vars
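The standard split usage of the API above, i.e. `minimize()` decomposed into its two halves with an optional gradient transform in between; the optimizer choice and the `loss` tensor are placeholders.

  # Hypothetical usage: compute gradients explicitly, clip them, then apply.
  opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
  grads_and_vars = opt.compute_gradients(loss)
  clipped = [(tf.clip_by_norm(g, 5.0), v)
             for g, v in grads_and_vars if g is not None]
  train_op = opt.apply_gradients(clipped)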
Example #16
  def compute_gradients(self, loss, var_list=None,
                        gate_gradients=GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`.  It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable".  Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking
        no arguments which returns the value to minimize. When eager execution
        is enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`.  Defaults to the list of variables collected in the graph
        under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients.  Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything other than `Variable` objects.
      ValueError: If some arguments are invalid.
      RuntimeError: If called with eager execution enabled and `loss` is
        not callable.

    @compatibility(eager)
    When eager execution is enabled, `gate_gradients`, `aggregation_method`,
    and `colocate_gradients_with_ops` are ignored.
    @end_compatibility
    """
    if callable(loss):
      with backprop.GradientTape() as tape:
        if var_list is not None:
          tape.watch(var_list)
        loss_value = loss()

        # Scale loss if using a "mean" loss reduction and multiple towers.
        # Have to be careful to call distribute_lib.get_loss_reduction()
        # *after* loss() is evaluated, so we know what loss reduction it uses.
        # TODO(josh11b): Test that we handle weight decay in a reasonable way.
        if (distribute_lib.get_loss_reduction() ==
            variable_scope.VariableAggregation.MEAN):
          num_towers = distribute_lib.get_distribution_strategy().num_towers
          if num_towers > 1:
            loss_value *= (1. / num_towers)

      if var_list is None:
        var_list = tape.watched_variables()
      grads = tape.gradient(loss_value, var_list, grad_loss)
      return list(zip(grads, var_list))

    # Non-callable/Tensor loss case
    if context.executing_eagerly():
      raise RuntimeError(
          "`loss` passed to Optimizer.compute_gradients should "
          "be a function when eager execution is enabled.")

    # Scale loss if using a "mean" loss reduction and multiple towers.
    if (distribute_lib.get_loss_reduction() ==
        variable_scope.VariableAggregation.MEAN):
      num_towers = distribute_lib.get_distribution_strategy().num_towers
      if num_towers > 1:
        loss *= (1. / num_towers)

    if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
                              Optimizer.GATE_GRAPH]:
      raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
                       "Optimizer.GATE_OP, Optimizer.GATE_GRAPH.  Not %s" %
                       gate_gradients)
    self._assert_valid_dtypes([loss])
    if grad_loss is not None:
      self._assert_valid_dtypes([grad_loss])
    if var_list is None:
      var_list = (
          variables.trainable_variables() +
          ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    else:
      var_list = nest.flatten(var_list)
    # pylint: disable=protected-access
    var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
    # pylint: enable=protected-access
    processors = [_get_processor(v) for v in var_list]
    if not var_list:
      raise ValueError("No variables to optimize.")
    var_refs = [p.target() for p in processors]
    grads = gradients.gradients(
        loss, var_refs, grad_ys=grad_loss,
        gate_gradients=(gate_gradients == Optimizer.GATE_OP),
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops)
    if gate_gradients == Optimizer.GATE_GRAPH:
      grads = control_flow_ops.tuple(grads)
    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes(
        [v for g, v in grads_and_vars
         if g is not None and v.dtype != dtypes.resource])
    return grads_and_vars