Example #1
def _assert_in_default_state(t):
  t.assertIs(distribution_strategy_context._get_default_tower_context(),
             distribution_strategy_context.get_tower_context())
  t.assertIs(None, distribution_strategy_context.get_cross_tower_context())
  t.assertIs(distribution_strategy_context._get_default_distribution_strategy(),
             distribution_strategy_context.get_distribution_strategy())
  t.assertFalse(distribution_strategy_context.has_distribution_strategy())
Example #2
def create_slot(primary, val, name, colocate_with_primary=True):
  """Create a slot initialized to the given value.

  The type of the slot is determined by the given value.

  Args:
    primary: The primary `Variable` or `Tensor`.
    val: A `Tensor` specifying the initial value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean.  If True the slot is located
      on the same device as `primary`.

  Returns:
    A `Variable` object.
  """
  # Scope the slot name in the namespace of the primary variable.
  # Set "primary.op.name + '/' + name" as default name, so the scope name of
  # optimizer can be shared when reuse is True. Meanwhile when reuse is False
  # and the same name has been previously used, the scope name will add '_N'
  # as suffix for unique identifications.
  validate_shape = val.get_shape().is_fully_defined()
  if context.executing_eagerly():
    prefix = primary._shared_name  # pylint: disable=protected-access
  else:
    prefix = primary.op.name
  with variable_scope.variable_scope(None, prefix + "/" + name):
    if colocate_with_primary:
      distribution_strategy = (
          distribution_strategy_context.get_distribution_strategy())
      with distribution_strategy.colocate_vars_with(primary):
        return _create_slot_var(primary, val, "", validate_shape, None, None)
    else:
      return _create_slot_var(primary, val, "", validate_shape, None, None)
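The helper above is how TF 1.x-style optimizers create slot variables (momentum accumulators, Adam moments, etc.) colocated with their primary variable. A minimal usage sketch, assuming a TF 1.x-compatible build where tf.compat.v1 is available and slot_creator is importable from tensorflow.python.training (an internal module whose path may differ across versions):

import tensorflow.compat.v1 as tf
from tensorflow.python.training import slot_creator  # internal module, path assumed

tf.disable_eager_execution()

weights = tf.get_variable("weights", shape=[4, 4])
# The slot takes its dtype/shape from `val` and is placed on the same device
# as `weights` because colocate_with_primary=True.
momentum = slot_creator.create_slot(
    primary=weights,
    val=tf.zeros([4, 4]),
    name="momentum",
    colocate_with_primary=True)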
Example #3
  def _create_non_slot_variable(self, initial_value, name, colocate_with):
    """Add an extra variable, not associated with a slot."""
    # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
    eager = context.executing_eagerly()
    graph = None if eager else colocate_with.graph

    key = (name, graph)
    v = self._non_slot_dict.get(key, None)
    if v is None:
      self._maybe_initialize_checkpointable()
      distribution_strategy = distribute_ctx.get_distribution_strategy()
      with distribution_strategy.colocate_vars_with(colocate_with):
        if eager:
          restored_initial_value = self._preload_simple_restoration(
              name=name, shape=None)
          if restored_initial_value is not None:
            initial_value = restored_initial_value
        v = variable_scope.variable(initial_value, name=name, trainable=False)
      # Restore this variable by name if necessary, but don't add a
      # Checkpointable dependency. Optimizers return the current graph's
      # non-slot variables from _checkpoint_dependencies explicitly rather
      # than unconditionally adding dependencies (since there may be multiple
      # non-slot variables with the same name in different graphs, trying to
      # save all of them would result in errors).
      self._handle_deferred_dependencies(name=name, checkpointable=v)
      self._non_slot_dict[key] = v

    return v
Example #4
  def _assign_func(self, *args, **kwargs):
    f = kwargs.pop("f")
    if distribution_strategy_context.get_cross_tower_context():
      update_device = distribute_lib.get_update_device()
      if update_device is not None:
        # We are calling an assign function in an update context.
        return f(self._v, *args, **kwargs)

      # We are calling an assign function in cross tower context, wrap it in an
      # update call.
      return distribution_strategy_context.get_distribution_strategy().update(
          self, f, *args, **kwargs)
    else:
      assert distribution_strategy_context.get_tower_context()
      # We are calling an assign function in tower context.
      # We reduce the value we want to assign/add/sub. More details about how we
      # handle the different use cases can be found in the _reduce method.
      # We call the function with the reduced value.
      if self._aggregation == vs.VariableAggregation.NONE:
        raise ValueError("You must specify an aggregation method to update a "
                         "a variable in Tower Context.")

      def merge_fn(strategy, value, *other_args, **other_kwargs):
        return strategy.update(
            self, f,
            strategy.reduce(
                aggregation=self._aggregation, value=value, destinations=self),
            *other_args, **other_kwargs)

      return distribution_strategy_context.get_tower_context().merge_call(
          merge_fn, *args, **kwargs)
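Stripped to its essence, the tower-context branch above reduces the per-tower values to a single value and then applies one update in cross-tower context. A toy sketch of that reduce-then-update flow in plain Python (stub classes and hypothetical names, no TensorFlow):

class ToyStrategy:
  def reduce(self, values):          # stand-in for strategy.reduce(aggregation, value, destinations)
    return sum(values)

  def update(self, var, fn, value):  # stand-in for strategy.update(self, f, reduced, ...)
    return fn(var, value)

def assign_add(var, value):
  var["value"] += value
  return var["value"]

strategy = ToyStrategy()
per_tower_values = [1.0, 2.0, 3.0]   # one value produced on each tower
variable = {"value": 10.0}

# merge_fn runs once with all per-tower values: reduce, then a single update.
reduced = strategy.reduce(per_tower_values)
print(strategy.update(variable, assign_add, reduced))  # -> 16.0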
Example #5
    def _assign_func(self, *args, **kwargs):
        f = kwargs.pop("f")
        if distribution_strategy_context.get_cross_tower_context():
            update_device = distribute_lib.get_update_device()
            if update_device is not None:
                # We are calling an assign function in an update context.
                return f(self._v, *args, **kwargs)

            # We are calling an assign function in cross tower context, wrap it in an
            # update call.
            return distribution_strategy_context.get_distribution_strategy(
            ).update(self, f, *args, **kwargs)
        else:
            assert distribution_strategy_context.get_tower_context()
            # We are calling an assign function in tower context.
            # We reduce the value we want to assign/add/sub. More details about how we
            # handle the different use cases can be found in the _reduce method.
            # We call the function with the reduced value.
            if self._aggregation == vs.VariableAggregation.NONE:
                raise ValueError(
                    "You must specify an aggregation method to update a "
                    "a variable in Tower Context.")

            def merge_fn(strategy, value, *other_args, **other_kwargs):
                return strategy.update(
                    self, f,
                    strategy.reduce(aggregation=self._aggregation,
                                    value=value,
                                    destinations=self), *other_args,
                    **other_kwargs)

            return distribution_strategy_context.get_tower_context(
            ).merge_call(merge_fn, *args, **kwargs)
Example #6
def create_slot(primary, val, name, colocate_with_primary=True):
    """Create a slot initialized to the given value.

  The type of the slot is determined by the given value.

  Args:
    primary: The primary `Variable` or `Tensor`.
    val: A `Tensor` specifying the initial value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean.  If True the slot is located
      on the same device as `primary`.

  Returns:
    A `Variable` object.
  """
    # Scope the slot name in the namespace of the primary variable.
    # Set "primary.op.name + '/' + name" as default name, so the scope name of
    # optimizer can be shared when reuse is True. Meanwhile when reuse is False
    # and the same name has been previously used, the scope name will add '_N'
    # as suffix for unique identifications.
    validate_shape = val.get_shape().is_fully_defined()
    if context.executing_eagerly():
        prefix = primary._shared_name  # pylint: disable=protected-access
    else:
        prefix = primary.op.name
    with variable_scope.variable_scope(None, prefix + "/" + name):
        if colocate_with_primary:
            distribution_strategy = (
                distribution_strategy_context.get_distribution_strategy())
            with distribution_strategy.colocate_vars_with(primary):
                return _create_slot_var(primary, val, "", validate_shape, None,
                                        None)
        else:
            return _create_slot_var(primary, val, "", validate_shape, None,
                                    None)
Example #7
 def _scale_loss(loss_value):
   if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
     num_replicas = \
       distribute_ctx.get_distribution_strategy().num_replicas_in_sync
     if num_replicas > 1:
       loss_value *= (1. / num_replicas)
   return loss_value
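A quick numeric check (plain Python, not TensorFlow) of why this scaling is correct: with a MEAN loss reduction, each replica scales its loss by 1/num_replicas, so summing the per-replica contributions reproduces the mean loss.

num_replicas = 4
per_replica_losses = [2.0, 4.0, 6.0, 8.0]
scaled = [loss * (1. / num_replicas) for loss in per_replica_losses]
assert sum(scaled) == sum(per_replica_losses) / num_replicas  # both equal 5.0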
Example #8
 def _scale_loss(loss_value):
   if (distribute_lib.get_loss_reduction() ==
       variable_scope.VariableAggregation.MEAN):
     num_replicas = \
       distribute_ctx.get_distribution_strategy().num_replicas_in_sync
     if num_replicas > 1:
       loss_value *= (1. / num_replicas)
   return loss_value
Example #9
 def merge_fn(dist, s):
   self.assertIs(
       distribution_strategy_context._get_default_distribution_strategy(),
       dist)
   self.assertIs(None, distribution_strategy_context.get_tower_context())
   self.assertIs(dist,
                 distribution_strategy_context.get_cross_tower_context())
   self.assertIs(dist,
                 distribution_strategy_context.get_distribution_strategy())
   self.assertFalse(
       distribution_strategy_context.has_distribution_strategy())
   return "foo_" + s
Example #10
 def run_fn():
   tower_context = distribution_strategy_context.get_tower_context()
   self.assertTrue(tower_context is not None)
   self.assertIs(None,
                 distribution_strategy_context.get_cross_tower_context())
   self.assertTrue(distribution_strategy_context.has_distribution_strategy())
   self.assertIs(dist,
                 distribution_strategy_context.get_distribution_strategy())
   self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
   expected_value = _get_test_variable(
       "bar", variable_scope.VariableSynchronization.AUTO,
       variable_scope.VariableAggregation.NONE)
   self.assertDictEqual(expected_value,
                        variable_scope.variable(1.0, name="bar"))
Example #11
    def _assign_func(self, *args, **kwargs):
        f = kwargs.pop("f")
        if distribution_strategy_context.get_cross_tower_context():
            update_device = distribute_lib.get_update_device()
            if update_device is not None:
                # We are calling an assign function on the mirrored variable in an
                # update context.
                v = self.get(device=update_device)
                return f(v, *args, **kwargs)

            # We are calling assign on the mirrored variable in cross tower context,
            # use update to update the variable.
            strategy = distribution_strategy_context.get_distribution_strategy(
            )
            updates = strategy.update(self, f, *args, **kwargs)
            grouped = strategy.group(updates)
            if isinstance(updates,
                          DistributedValues) and updates.is_tensor_like:
                # Make sure we run all updates. Without this, something like
                # session.run(mirrored_var.assign*(...)) may only update one tower.
                index = {}
                for d in updates.devices:
                    with ops.device(d), ops.control_dependencies([grouped]):
                        index[d] = array_ops.identity(updates.get(d))
                return Mirrored(index)
            else:
                return grouped
        else:
            _assert_tower_context()
            # We are calling an assign function on the mirrored variable in tower
            # context.
            # We reduce the value we want to assign/add/sub. More details about how we
            # handle the different use cases can be found in the _reduce method.
            # We call the function on each of the mirrored variables with the reduced
            # value.
            if self._aggregation == vs.VariableAggregation.NONE:
                raise ValueError(
                    "You must specify an aggregation method to update a "
                    "MirroredVariable in Tower Context.")

            def merge_fn(strategy, value, *other_args, **other_kwargs):
                return strategy.update(
                    self, f,
                    strategy.reduce(aggregation=self._aggregation,
                                    value=value,
                                    destinations=self), *other_args,
                    **other_kwargs)

            return distribution_strategy_context.get_tower_context(
            ).merge_call(merge_fn, *args, **kwargs)
Example #12
 def testScope(self):
   _assert_in_default_state(self)
   dist = _TestStrategy()
   with dist.scope():
     self.assertIs(None, distribution_strategy_context.get_tower_context())
     self.assertIs(dist,
                   distribution_strategy_context.get_cross_tower_context())
     self.assertTrue(distribution_strategy_context.has_distribution_strategy())
     self.assertIs(dist,
                   distribution_strategy_context.get_distribution_strategy())
     expected_value = _get_test_variable(
         "baz", variable_scope.VariableSynchronization.AUTO,
         variable_scope.VariableAggregation.NONE)
     self.assertDictEqual(expected_value,
                          variable_scope.variable(1.0, name="baz"))
   _assert_in_default_state(self)
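The same invariants this test checks can be observed through the public API. A hedged sketch using the modern tf.distribute names (which replaced the "tower"/"cross tower" terminology used in these examples); assumes TF 2.x:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
with strategy.scope():
  # Inside the scope we are in cross-replica (formerly "cross tower") context.
  assert tf.distribute.has_strategy()
  assert tf.distribute.get_strategy() is strategy
  assert tf.distribute.in_cross_replica_context()
# Outside the scope the default (no-op) strategy is current again.
assert not tf.distribute.has_strategy()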
Example #13
    def set_last_step_output(
            self,
            name,
            output,
            aggregation=variables_lib.VariableAggregation.NONE):
        """Set `output` with `name` to be outputted from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be outputted with `name`. See below for
        actual types supported.
      aggregation: Aggregation method to use to aggregate outputs from multiple
        towers. Required if `set_last_step_output` is called in a tower context.
        Optional in cross_tower_context.
        When present, the outputs from all the towers are aggregated using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce` method.
        For e.g. if using MirroredStrategy and aggregation is set, output
        must be a `PerDevice` value.
        The aggregation method is also recorded in a dictionary
        `_last_step_outputs_aggregations` for later interpreting of the
        outputs as already reduced or not.

    """
        if distribution_strategy_context.get_cross_tower_context():
            self._last_step_outputs_aggregations[name] = aggregation
            if aggregation is variables_lib.VariableAggregation.NONE:
                self._last_step_outputs[name] = output
            else:
                distribution = distribution_strategy_context.get_distribution_strategy(
                )
                self._last_step_outputs[name] = distribution.reduce(
                    aggregation, output, destinations="/device:CPU:0")
        else:
            assert aggregation is not variables_lib.VariableAggregation.NONE

            def merge_fn(distribution, value):
                self._last_step_outputs[name] = distribution.reduce(
                    aggregation, value, destinations="/device:CPU:0")
                # Setting this inside the `merge_fn` because all towers share the same
                # context object, so it's more robust to set it only once (even if all
                # the towers are trying to set the same value).
                self._last_step_outputs_aggregations[name] = aggregation

            distribution_strategy_context.get_tower_context().merge_call(
                merge_fn, output)
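A toy model in plain Python (hypothetical names, no TensorFlow) of the cross-tower branch above: outputs tagged with an aggregation other than NONE are reduced before being recorded, and the aggregation method is remembered alongside them so readers know whether the stored value is already reduced.

from enum import Enum

class Aggregation(Enum):
    NONE = 0
    MEAN = 1

class ToyStepContext:
    def __init__(self):
        self._last_step_outputs = {}
        self._last_step_outputs_aggregations = {}

    def set_last_step_output(self, name, per_tower_values, aggregation):
        self._last_step_outputs_aggregations[name] = aggregation
        if aggregation is Aggregation.NONE:
            self._last_step_outputs[name] = per_tower_values
        else:  # stand-in for distribution.reduce(aggregation, output, destinations)
            self._last_step_outputs[name] = (
                sum(per_tower_values) / len(per_tower_values))

ctx = ToyStepContext()
ctx.set_last_step_output("loss", [0.9, 1.1], aggregation=Aggregation.MEAN)
print(ctx._last_step_outputs["loss"])  # -> 1.0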
Example #14
  def _assign_func(self, *args, **kwargs):
    f = kwargs.pop("f")
    if distribution_strategy_context.get_cross_tower_context():
      update_device = distribute_lib.get_update_device()
      if update_device is not None:
        # We are calling an assign function on the mirrored variable in an
        # update context.
        v = self.get(device=update_device)
        return f(v, *args, **kwargs)

      # We are calling assign on the mirrored variable in cross tower context,
      # use update to update the variable.
      strategy = distribution_strategy_context.get_distribution_strategy()
      updates = strategy.update(self, f, *args, **kwargs)
      grouped = strategy.group(updates)
      if isinstance(updates, DistributedValues) and updates.is_tensor_like:
        # Make sure we run all updates. Without this, something like
        # session.run(mirrored_var.assign*(...)) may only update one tower.
        index = {}
        for d in updates.devices:
          with ops.device(d), ops.control_dependencies([grouped]):
            index[d] = array_ops.identity(updates.get(d))
        return Mirrored(index)
      else:
        return grouped
    else:
      _assert_tower_context()
      # We are calling an assign function on the mirrored variable in tower
      # context.
      # We reduce the value we want to assign/add/sub. More details about how we
      # handle the different use cases can be found in the _reduce method.
      # We call the function on each of the mirrored variables with the reduced
      # value.
      if self._aggregation == vs.VariableAggregation.NONE:
        raise ValueError("You must specify an aggregation method to update a "
                         "MirroredVariable in Tower Context.")

      def merge_fn(strategy, value, *other_args, **other_kwargs):
        return strategy.update(
            self, f,
            strategy.reduce(
                aggregation=self._aggregation, value=value, destinations=self),
            *other_args, **other_kwargs)

      return distribution_strategy_context.get_tower_context().merge_call(
          merge_fn, *args, **kwargs)
Example #15
  def set_last_step_output(self, name, output,
                           aggregation=variables_lib.VariableAggregation.NONE):
    """Set `output` with `name` to be outputted from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be outputted with `name`. See below for
        actual types supported.
      aggregation: Aggregation method to use to aggregate outputs from multiple
        towers. Required if `set_last_step_output` is called in a tower context.
        Optional in cross_tower_context.
        When present, the outputs from all the towers are aggregated using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce` method.
        For e.g. if using MirroredStrategy and aggregation is set, output
        must be a `PerDevice` value.
        The aggregation method is also recorded in a dictionary
        `_last_step_outputs_aggregations` for later interpreting of the
        outputs as already reduced or not.

    """
    if distribution_strategy_context.get_cross_tower_context():
      self._last_step_outputs_aggregations[name] = aggregation
      if aggregation is variables_lib.VariableAggregation.NONE:
        self._last_step_outputs[name] = output
      else:
        distribution = distribution_strategy_context.get_distribution_strategy()
        self._last_step_outputs[name] = distribution.reduce(
            aggregation, output, destinations="/device:CPU:0")
    else:
      assert aggregation is not variables_lib.VariableAggregation.NONE
      def merge_fn(distribution, value):
        self._last_step_outputs[name] = distribution.reduce(
            aggregation, value, destinations="/device:CPU:0")
        # Setting this inside the `merge_fn` because all towers share the same
        # context object, so it's more robust to set it only once (even if all
        # the towers are trying to set the same value).
        self._last_step_outputs_aggregations[name] = aggregation

      distribution_strategy_context.get_tower_context().merge_call(
          merge_fn, output)
Example #16
  def build(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
    if not input_shape.ndims:
      raise ValueError('Input has undefined rank:', input_shape)
    ndims = len(input_shape)

    # Convert axis to list and resolve negatives
    if isinstance(self.axis, int):
      self.axis = [self.axis]

    if not isinstance(self.axis, list):
      raise TypeError('axis must be int or list, type given: %s'
                      % type(self.axis))

    for idx, x in enumerate(self.axis):
      if x < 0:
        self.axis[idx] = ndims + x

    # Validate axes
    for x in self.axis:
      if x < 0 or x >= ndims:
        raise ValueError('Invalid axis: %d' % x)
    if len(self.axis) != len(set(self.axis)):
      raise ValueError('Duplicate axis: %s' % self.axis)

    if self.virtual_batch_size is not None:
      if self.virtual_batch_size <= 0:
        raise ValueError('virtual_batch_size must be a positive integer that '
                         'divides the true batch size of the input Tensor')
      # If using virtual batches, the first dimension must be the batch
      # dimension and cannot be the batch norm axis
      if 0 in self.axis:
        raise ValueError('When using virtual_batch_size, the batch dimension '
                         'must be 0 and thus axis cannot include 0')
      if self.adjustment is not None:
        raise ValueError('When using virtual_batch_size, adjustment cannot '
                         'be specified')

    if self.fused:
      # Currently fused batch norm doesn't support renorm. It also only supports
      # an input tensor of rank 4 and a channel dimension on axis 1 or 3.
      # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
      # output back to its original shape accordingly.
      self.fused = (not self.renorm and
                    ndims == 4 and
                    self.axis in [[1], [3]] and
                    self.virtual_batch_size is None and
                    self.adjustment is None)
      # TODO(chrisying): fused batch norm is currently not supported for
      # multi-axis batch norm and by extension virtual batches. In some cases,
      # it might be possible to use fused batch norm but would require reshaping
      # the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is
      # particularly tricky. A compromise might be to just support the most
      # common use case (turning 5D w/ virtual batch to NCHW)

    if self.fused:
      if self.axis == [1]:
        self._data_format = 'NCHW'
      elif self.axis == [3]:
        self._data_format = 'NHWC'
      else:
        raise ValueError('Unsupported axis, fused batch norm only supports '
                         'axis == [1] or axis == [3]')

    # Raise parameters of fp16 batch norm to fp32
    if self.dtype == dtypes.float16 or self.dtype == dtypes.bfloat16:
      param_dtype = dtypes.float32
    else:
      param_dtype = self.dtype or dtypes.float32

    axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
    for x in axis_to_dim:
      if axis_to_dim[x] is None:
        raise ValueError('Input has undefined `axis` dimension. Input shape: ',
                         input_shape)
    self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)

    if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
      # Single axis batch norm (most common/default use-case)
      param_shape = (list(axis_to_dim.values())[0],)
    else:
      # Parameter shape is the original shape but with 1 in all non-axis dims
      param_shape = [axis_to_dim[i] if i in axis_to_dim
                     else 1 for i in range(ndims)]
      if self.virtual_batch_size is not None:
        # When using virtual batches, add an extra dim at index 1
        param_shape.insert(1, 1)
        for idx, x in enumerate(self.axis):
          self.axis[idx] = x + 1      # Account for added dimension

    if self.scale:
      self.gamma = self.add_weight(
          name='gamma',
          shape=param_shape,
          dtype=param_dtype,
          initializer=self.gamma_initializer,
          regularizer=self.gamma_regularizer,
          constraint=self.gamma_constraint,
          trainable=True)
    else:
      self.gamma = None
      if self.fused:
        self._gamma_const = array_ops.constant(
            1.0, dtype=param_dtype, shape=param_shape)

    if self.center:
      self.beta = self.add_weight(
          name='beta',
          shape=param_shape,
          dtype=param_dtype,
          initializer=self.beta_initializer,
          regularizer=self.beta_regularizer,
          constraint=self.beta_constraint,
          trainable=True)
    else:
      self.beta = None
      if self.fused:
        self._beta_const = array_ops.constant(
            0.0, dtype=param_dtype, shape=param_shape)

    try:
      # Disable variable partitioning when creating the moving mean and variance
      if hasattr(self, '_scope') and self._scope:
        partitioner = self._scope.partitioner
        self._scope.set_partitioner(None)
      else:
        partitioner = None
      self.moving_mean = self.add_weight(
          name='moving_mean',
          shape=param_shape,
          dtype=param_dtype,
          initializer=self.moving_mean_initializer,
          synchronization=tf_variables.VariableSynchronization.ON_READ,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.MEAN)

      self.moving_variance = self.add_weight(
          name='moving_variance',
          shape=param_shape,
          dtype=param_dtype,
          initializer=self.moving_variance_initializer,
          synchronization=tf_variables.VariableSynchronization.ON_READ,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.MEAN)

      if self.renorm:
        # Create variables to maintain the moving mean and standard deviation.
        # These are used in training and thus are different from the moving
        # averages above. The renorm variables are colocated with moving_mean
        # and moving_variance.
        # NOTE: below, the outer `with device` block causes the current device
        # stack to be cleared. The nested ones use a `lambda` to set the desired
        # device and ignore any devices that may be set by the custom getter.
        def _renorm_variable(name, shape):
          var = self.add_weight(
              name=name,
              shape=shape,
              dtype=param_dtype,
              initializer=init_ops.zeros_initializer(),
              synchronization=tf_variables.VariableSynchronization.ON_READ,
              trainable=False,
              aggregation=tf_variables.VariableAggregation.MEAN)
          return var

        with distribution_strategy_context.get_distribution_strategy(
        ).colocate_vars_with(self.moving_mean):
          self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
          self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
        # We initialize renorm_stddev to 0, and maintain the (0-initialized)
        # renorm_stddev_weight. This allows us to (1) mix the average
        # stddev with the minibatch stddev early in training, and (2) compute
        # the unbiased average stddev by dividing renorm_stddev by the weight.
        with distribution_strategy_context.get_distribution_strategy(
        ).colocate_vars_with(self.moving_variance):
          self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
          self.renorm_stddev_weight = _renorm_variable('renorm_stddev_weight',
                                                       ())
    finally:
      if partitioner:
        self._scope.set_partitioner(partitioner)
    self.built = True
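A quick check (plain Python) of the param_shape rule used above: keep the size of each normalization axis and put 1 in every other dimension so the parameters broadcast against the input.

ndims = 4
axis_to_dim = {3: 16}                 # e.g. axis=[3] on NHWC input with 16 channels

if len(axis_to_dim) == 1:             # single-axis batch norm (the common case)
  param_shape = (list(axis_to_dim.values())[0],)
else:                                 # multi-axis: 1 in every non-axis dimension
  param_shape = [axis_to_dim[i] if i in axis_to_dim else 1 for i in range(ndims)]

print(param_shape)                    # -> (16,)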
Example #17
 def tensor():
     return distribution_strategy_context.get_distribution_strategy(
     ).read_var(tower_local_variable)
Example #18
 def read_value(self):
     return distribution_strategy_context.get_distribution_strategy(
     ).read_var(self)
Example #19
  def build(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
    if not input_shape.ndims:
      raise ValueError('Input has undefined rank:', input_shape)
    ndims = len(input_shape)

    # Convert axis to list and resolve negatives
    if isinstance(self.axis, int):
      self.axis = [self.axis]

    if not isinstance(self.axis, list):
      raise TypeError('axis must be int or list, type given: %s'
                      % type(self.axis))

    for idx, x in enumerate(self.axis):
      if x < 0:
        self.axis[idx] = ndims + x

    # Validate axes
    for x in self.axis:
      if x < 0 or x >= ndims:
        raise ValueError('Invalid axis: %d' % x)
    if len(self.axis) != len(set(self.axis)):
      raise ValueError('Duplicate axis: %s' % self.axis)

    if self.virtual_batch_size is not None:
      if self.virtual_batch_size <= 0:
        raise ValueError('virtual_batch_size must be a positive integer that '
                         'divides the true batch size of the input Tensor')
      # If using virtual batches, the first dimension must be the batch
      # dimension and cannot be the batch norm axis
      if 0 in self.axis:
        raise ValueError('When using virtual_batch_size, the batch dimension '
                         'must be 0 and thus axis cannot include 0')
      if self.adjustment is not None:
        raise ValueError('When using virtual_batch_size, adjustment cannot '
                         'be specified')

    if self.fused:
      # Currently fused batch norm doesn't support renorm. It also only supports
      # an input tensor of rank 4 and a channel dimension on axis 1 or 3.
      # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
      # output back to its original shape accordingly.
      self.fused = (not self.renorm and
                    ndims == 4 and
                    self.axis in [[1], [3]] and
                    self.virtual_batch_size is None and
                    self.adjustment is None)
      # TODO(chrisying): fused batch norm is currently not supported for
      # multi-axis batch norm and by extension virtual batches. In some cases,
      # it might be possible to use fused batch norm but would require reshaping
      # the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is
      # particularly tricky. A compromise might be to just support the most
      # common use case (turning 5D w/ virtual batch to NCHW)

    if self.fused:
      if self.axis == [1]:
        self._data_format = 'NCHW'
      elif self.axis == [3]:
        self._data_format = 'NHWC'
      else:
        raise ValueError('Unsupported axis, fused batch norm only supports '
                         'axis == [1] or axis == [3]')

    # Raise parameters of fp16 batch norm to fp32
    if self.dtype == dtypes.float16 or self.dtype == dtypes.bfloat16:
      param_dtype = dtypes.float32
    else:
      param_dtype = self.dtype or dtypes.float32

    axis_to_dim = {x: input_shape[x].value for x in self.axis}
    for x in axis_to_dim:
      if axis_to_dim[x] is None:
        raise ValueError('Input has undefined `axis` dimension. Input shape: ',
                         input_shape)
    self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)

    if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
      # Single axis batch norm (most common/default use-case)
      param_shape = (list(axis_to_dim.values())[0],)
    else:
      # Parameter shape is the original shape but with 1 in all non-axis dims
      param_shape = [axis_to_dim[i] if i in axis_to_dim
                     else 1 for i in range(ndims)]
      if self.virtual_batch_size is not None:
        # When using virtual batches, add an extra dim at index 1
        param_shape.insert(1, 1)
        for idx, x in enumerate(self.axis):
          self.axis[idx] = x + 1      # Account for added dimension

    if self.scale:
      self.gamma = self.add_weight(
          name='gamma',
          shape=param_shape,
          dtype=param_dtype,
          initializer=self.gamma_initializer,
          regularizer=self.gamma_regularizer,
          constraint=self.gamma_constraint,
          trainable=True)
    else:
      self.gamma = None
      if self.fused:
        self._gamma_const = array_ops.constant(
            1.0, dtype=param_dtype, shape=param_shape)

    if self.center:
      self.beta = self.add_weight(
          name='beta',
          shape=param_shape,
          dtype=param_dtype,
          initializer=self.beta_initializer,
          regularizer=self.beta_regularizer,
          constraint=self.beta_constraint,
          trainable=True)
    else:
      self.beta = None
      if self.fused:
        self._beta_const = array_ops.constant(
            0.0, dtype=param_dtype, shape=param_shape)

    try:
      # Disable variable partitioning when creating the moving mean and variance
      if hasattr(self, '_scope') and self._scope:
        partitioner = self._scope.partitioner
        self._scope.set_partitioner(None)
      else:
        partitioner = None
      self.moving_mean = self.add_weight(
          name='moving_mean',
          shape=param_shape,
          dtype=param_dtype,
          initializer=self.moving_mean_initializer,
          synchronization=tf_variables.VariableSynchronization.ON_READ,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.MEAN)

      self.moving_variance = self.add_weight(
          name='moving_variance',
          shape=param_shape,
          dtype=param_dtype,
          initializer=self.moving_variance_initializer,
          synchronization=tf_variables.VariableSynchronization.ON_READ,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.MEAN)

      if self.renorm:
        # Create variables to maintain the moving mean and standard deviation.
        # These are used in training and thus are different from the moving
        # averages above. The renorm variables are colocated with moving_mean
        # and moving_variance.
        # NOTE: below, the outer `with device` block causes the current device
        # stack to be cleared. The nested ones use a `lambda` to set the desired
        # device and ignore any devices that may be set by the custom getter.
        def _renorm_variable(name, shape):
          var = self.add_weight(
              name=name,
              shape=shape,
              dtype=param_dtype,
              initializer=init_ops.zeros_initializer(),
              synchronization=tf_variables.VariableSynchronization.ON_READ,
              trainable=False,
              aggregation=tf_variables.VariableAggregation.MEAN)
          return var

        with distribution_strategy_context.get_distribution_strategy(
        ).colocate_vars_with(self.moving_mean):
          self.renorm_mean = _renorm_variable('renorm_mean', param_shape)
          self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
        # We initialize renorm_stddev to 0, and maintain the (0-initialized)
        # renorm_stddev_weight. This allows us to (1) mix the average
        # stddev with the minibatch stddev early in training, and (2) compute
        # the unbiased average stddev by dividing renorm_stddev by the weight.
        with distribution_strategy_context.get_distribution_strategy(
        ).colocate_vars_with(self.moving_variance):
          self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape)
          self.renorm_stddev_weight = _renorm_variable('renorm_stddev_weight',
                                                       ())
    finally:
      if partitioner:
        self._scope.set_partitioner(partitioner)
    self.built = True
Example #20
  def compute_gradients(self, loss, var_list=None,
                        gate_gradients=GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`.  It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable".  Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking
        no arguments which returns the value to minimize. When eager execution
        is enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`.  Defaults to the list of variables collected in the graph
        under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients.  Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything else than `Variable` objects.
      ValueError: If some arguments are invalid.
      RuntimeError: If called with eager execution enabled and `loss` is
        not callable.

    @compatibility(eager)
    When eager execution is enabled, `gate_gradients`, `aggregation_method`,
    and `colocate_gradients_with_ops` are ignored.
    @end_compatibility
    """
    if callable(loss):
      with backprop.GradientTape() as tape:
        if var_list is not None:
          tape.watch(var_list)
        loss_value = loss()

        # Scale loss if using a "mean" loss reduction and multiple towers.
        # Have to be careful to call distribute_lib.get_loss_reduction()
        # *after* loss() is evaluated, so we know what loss reduction it uses.
        # TODO(josh11b): Test that we handle weight decay in a reasonable way.
        if (distribute_lib.get_loss_reduction() ==
            variable_scope.VariableAggregation.MEAN):
          num_towers = distribution_strategy_context.get_distribution_strategy(
          ).num_towers
          if num_towers > 1:
            loss_value *= (1. / num_towers)

      if var_list is None:
        var_list = tape.watched_variables()
      grads = tape.gradient(loss_value, var_list, grad_loss)
      return list(zip(grads, var_list))

    # Non-callable/Tensor loss case
    if context.executing_eagerly():
      raise RuntimeError(
          "`loss` passed to Optimizer.compute_gradients should "
          "be a function when eager execution is enabled.")

    # Scale loss if using a "mean" loss reduction and multiple towers.
    if (distribute_lib.get_loss_reduction() ==
        variable_scope.VariableAggregation.MEAN):
      num_towers = distribution_strategy_context.get_distribution_strategy(
      ).num_towers
      if num_towers > 1:
        loss *= (1. / num_towers)

    if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
                              Optimizer.GATE_GRAPH]:
      raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
                       "Optimizer.GATE_OP, Optimizer.GATE_GRAPH.  Not %s" %
                       gate_gradients)
    self._assert_valid_dtypes([loss])
    if grad_loss is not None:
      self._assert_valid_dtypes([grad_loss])
    if var_list is None:
      var_list = (
          variables.trainable_variables() +
          ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    else:
      var_list = nest.flatten(var_list)
    # pylint: disable=protected-access
    var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
    # pylint: enable=protected-access
    processors = [_get_processor(v) for v in var_list]
    if not var_list:
      raise ValueError("No variables to optimize.")
    var_refs = [p.target() for p in processors]
    grads = gradients.gradients(
        loss, var_refs, grad_ys=grad_loss,
        gate_gradients=(gate_gradients == Optimizer.GATE_OP),
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops)
    if gate_gradients == Optimizer.GATE_GRAPH:
      grads = control_flow_ops.tuple(grads)
    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes(
        [v for g, v in grads_and_vars
         if g is not None and v.dtype != dtypes.resource])
    return grads_and_vars
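A hedged usage sketch of the eager branch above (assumes TF 2.x with the tf.compat.v1 optimizers available): in eager mode `loss` must be a zero-argument callable, and when var_list is omitted the variables are recovered from the gradient tape.

import tensorflow as tf

w = tf.Variable(3.0)
opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.1)
# loss = (w - 1)^2, so dloss/dw = 2 * (w - 1) = 4.0 at w = 3.0.
grads_and_vars = opt.compute_gradients(lambda: (w - 1.0) ** 2)
print([(float(g), v is w) for g, v in grads_and_vars])  # -> [(4.0, True)]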
Example #21
 def read_value(self):
   return distribution_strategy_context.get_distribution_strategy().read_var(
       self)
Example #22
    def _restore_checkpoint(self,
                            master,
                            saver=None,
                            checkpoint_dir=None,
                            checkpoint_filename_with_path=None,
                            wait_for_checkpoint=False,
                            max_wait_secs=7200,
                            config=None):
        """Creates a `Session`, and tries to restore a checkpoint.


    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, is_restored) where 'is_restored' is `True` if
      the session could be restored, `False` otherwise.

    Raises:
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """
        self._target = master

        # This is required so that we initialize the TPU device before
        # restoring from checkpoint since we'll be placing variables on the device
        # and TPUInitialize wipes out the memory of the device.
        strategy = distribution_strategy_context.get_distribution_strategy()
        if strategy and hasattr(strategy.extended,
                                "_experimental_initialize_system"):
            strategy.extended._experimental_initialize_system()  # pylint: disable=protected-access

        sess = session.Session(self._target, graph=self._graph, config=config)
        if checkpoint_dir and checkpoint_filename_with_path:
            raise ValueError("Can not provide both checkpoint_dir and "
                             "checkpoint_filename_with_path.")
        # If either saver or checkpoint_* is not specified, cannot restore. Just
        # return.
        if not saver or not (checkpoint_dir or checkpoint_filename_with_path):
            return sess, False

        if checkpoint_filename_with_path:
            saver.restore(sess, checkpoint_filename_with_path)
            return sess, True

        # Waits up until max_wait_secs for checkpoint to become available.
        wait_time = 0
        ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
        while not ckpt or not ckpt.model_checkpoint_path:
            if wait_for_checkpoint and wait_time < max_wait_secs:
                logging.info("Waiting for checkpoint to be available.")
                time.sleep(self._recovery_wait_secs)
                wait_time += self._recovery_wait_secs
                ckpt = checkpoint_management.get_checkpoint_state(
                    checkpoint_dir)
            else:
                return sess, False

        # Loads the checkpoint.
        saver.restore(sess, ckpt.model_checkpoint_path)
        saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
        return sess, True
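The waiting logic above, reduced to a plain-Python sketch (hypothetical helper, no TensorFlow): poll for a checkpoint until one appears or the time budget runs out.

import time

def wait_for_checkpoint(poll_fn, max_wait_secs=7200, recovery_wait_secs=30):
    """Returns poll_fn()'s first truthy result, or None after max_wait_secs."""
    wait_time = 0
    ckpt = poll_fn()
    while not ckpt:
        if wait_time >= max_wait_secs:
            return None
        time.sleep(recovery_wait_secs)
        wait_time += recovery_wait_secs
        ckpt = poll_fn()
    return ckpt

print(wait_for_checkpoint(lambda: "model.ckpt-1000"))  # -> model.ckpt-1000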
Example #23
 def tensor():
   return distribution_strategy_context.get_distribution_strategy().read_var(
       tower_local_variable)
Example #24
    def build(self, input_shape):
        input_shape = tf.TensorShape(input_shape)
        if not input_shape.ndims:
            raise ValueError('Input has undefined rank:', input_shape)
        ndims = len(input_shape)

        # Convert axis to list and resolve negatives
        if isinstance(self.axis, int):
            self.axis = [self.axis]

        if not isinstance(self.axis, list):
            raise TypeError(
                'axis must be int or list, type given: %s' % type(self.axis))

        for idx, x in enumerate(self.axis):
            if x < 0:
                self.axis[idx] = ndims + x

        # Validate axes
        for x in self.axis:
            if x < 0 or x >= ndims:
                raise ValueError('Invalid axis: %d' % x)
        if len(self.axis) != len(set(self.axis)):
            raise ValueError('Duplicate axis: %s' % self.axis)

        # Raise parameters of fp16 batch norm to fp32
        if self.dtype == tf.float16 or self.dtype == tf.bfloat16:
            param_dtype = tf.float32
        else:
            param_dtype = self.dtype or tf.float32

        axis_to_dim = {x: input_shape[x] for x in self.axis}
        for x in axis_to_dim:
            if axis_to_dim[x] is None:
                raise ValueError(
                    'Input has undefined `axis` dimension. Input shape: ',
                    input_shape)
        self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)

        if self.binary_dense:
            param_shape = [1]
        elif len(axis_to_dim) == 1:
            # Single axis batch norm (most common/default use-case)
            param_shape = (list(axis_to_dim.values())[0], )
        else:
            # Parameter shape is the original shape but with 1 in all non-axis dims
            param_shape = [
                axis_to_dim[i] if i in axis_to_dim else 1 for i in range(ndims)
            ]

        if self.scale:
            self.gamma = self.add_weight(
                name='gamma',
                shape=param_shape,
                dtype=param_dtype,
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint,
                trainable=True)
        else:
            self.gamma = None

        if self.center:
            self.beta = self.add_weight(
                name='beta',
                shape=param_shape,
                dtype=param_dtype,
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint,
                trainable=True)
        else:
            self.beta = None

        try:
            # Disable variable partitioning when creating the moving mean and variance
            if hasattr(self, '_scope') and self._scope:
                partitioner = self._scope.partitioner
                self._scope.set_partitioner(None)
            else:
                partitioner = None
            self.moving_mean = self.add_weight(
                name='moving_mean',
                shape=param_shape,
                dtype=param_dtype,
                initializer=self.moving_mean_initializer,
                synchronization=tf.VariableSynchronization.ON_READ,
                trainable=False,
                aggregation=tf.VariableAggregation.MEAN)

            self.moving_variance = self.add_weight(
                name='moving_variance',
                shape=param_shape,
                dtype=param_dtype,
                initializer=self.moving_variance_initializer,
                synchronization=tf.VariableSynchronization.ON_READ,
                trainable=False,
                aggregation=tf.VariableAggregation.MEAN)

            if self.renorm:
                # Create variables to maintain the moving mean and standard deviation.
                # These are used in training and thus are different from the moving
                # averages above. The renorm variables are colocated with moving_mean
                # and moving_variance.
                # NOTE: below, the outer `with device` block causes the current device
                # stack to be cleared. The nested ones use a `lambda` to set the desired
                # device and ignore any devices that may be set by the custom getter.
                def _renorm_variable(name, shape):
                    var = self.add_weight(
                        name=name,
                        shape=shape,
                        dtype=param_dtype,
                        initializer=tf.zeros_initializer(),
                        synchronization=tf.VariableSynchronization.ON_READ,
                        trainable=False,
                        aggregation=tf.VariableAggregation.MEAN)
                    return var

                with distribution_strategy_context.get_distribution_strategy(
                ).colocate_vars_with(self.moving_mean):
                    self.renorm_mean = _renorm_variable(
                        'renorm_mean', param_shape)
                    self.renorm_mean_weight = _renorm_variable(
                        'renorm_mean_weight', ())
                # We initialize renorm_stddev to 0, and maintain the (0-initialized)
                # renorm_stddev_weight. This allows us to (1) mix the average
                # stddev with the minibatch stddev early in training, and (2) compute
                # the unbiased average stddev by dividing renorm_stddev by the weight.
                with distribution_strategy_context.get_distribution_strategy(
                ).colocate_vars_with(self.moving_variance):
                    self.renorm_stddev = _renorm_variable(
                        'renorm_stddev', param_shape)
                    self.renorm_stddev_weight = _renorm_variable(
                        'renorm_stddev_weight', ())
        finally:
            if partitioner:
                self._scope.set_partitioner(partitioner)
        self.built = True
Example #25
  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This contains most of the synchronization implementation and also wraps the
    apply_gradients() from the real optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        compute_gradients().
      global_step: Optional Variable to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.  Default to the
        name passed to the Optimizer constructor.

    Returns:
      train_op: The op to dequeue a token so the replicas can exit this batch
      and start the next one. This is executed by each replica.

    Raises:
      ValueError: If the grads_and_vars is empty.
      ValueError: If global step is not provided, the staleness cannot be
        checked.
    """
    if not grads_and_vars:
      raise ValueError("Must supply at least one variable")

    if global_step is None:
      raise ValueError("Global step is required to check staleness")

    self._global_step = global_step
    train_ops = []
    aggregated_grad = []
    var_list = []

    # local_anchor op will be placed on this worker task by default.
    local_anchor = control_flow_ops.no_op()
    # Colocating local_step variable prevents it being placed on the PS.
    distribution_strategy = (
        distribution_strategy_context.get_distribution_strategy())
    with distribution_strategy.colocate_vars_with(local_anchor):
      self._local_step = variable_scope.variable(
          initial_value=0,
          trainable=False,
          collections=[ops.GraphKeys.LOCAL_VARIABLES],
          dtype=global_step.dtype.base_dtype,
          name="sync_rep_local_step")

    self.local_step_init_op = state_ops.assign(self._local_step, global_step)
    chief_init_ops = [self.local_step_init_op]
    self.ready_for_local_init_op = variables.report_uninitialized_variables(
        variables.global_variables())

    with ops.name_scope(None, self._name):
      for grad, var in grads_and_vars:
        var_list.append(var)
        with ops.device(var.device):
          # Dense gradients.
          if grad is None:
            aggregated_grad.append(None)  # pass-through.
            continue
          elif isinstance(grad, ops.Tensor):
            grad_accum = data_flow_ops.ConditionalAccumulator(
                grad.dtype,
                shape=var.get_shape(),
                shared_name=var.name + "/grad_accum")
            train_ops.append(grad_accum.apply_grad(
                grad, local_step=self._local_step))
            aggregated_grad.append(grad_accum.take_grad(
                self._replicas_to_aggregate))
          else:
            if not isinstance(grad, ops.IndexedSlices):
              raise ValueError("Unknown grad type!")
            grad_accum = data_flow_ops.SparseConditionalAccumulator(
                grad.dtype, shape=(), shared_name=var.name + "/grad_accum")
            train_ops.append(grad_accum.apply_indexed_slices_grad(
                grad, local_step=self._local_step))
            aggregated_grad.append(grad_accum.take_indexed_slices_grad(
                self._replicas_to_aggregate))

          self._accumulator_list.append((grad_accum, var.device))

      aggregated_grads_and_vars = zip(aggregated_grad, var_list)

      # sync_op will be assigned to the same device as the global step.
      with ops.device(global_step.device), ops.name_scope(""):
        update_op = self._opt.apply_gradients(aggregated_grads_and_vars,
                                              global_step)

      # Create token queue.
      with ops.device(global_step.device), ops.name_scope(""):
        sync_token_queue = (
            data_flow_ops.FIFOQueue(-1,
                                    global_step.dtype.base_dtype,
                                    shapes=(),
                                    name="sync_token_q",
                                    shared_name="sync_token_q"))
        self._sync_token_queue = sync_token_queue

        # dummy_queue is passed to the queue runner. Don't use the real queues
        # because the queue runner doesn't automatically reopen it once it
        # closed queues in PS devices.
        dummy_queue = (
            data_flow_ops.FIFOQueue(1,
                                    types_pb2.DT_INT32,
                                    shapes=(),
                                    name="dummy_queue",
                                    shared_name="dummy_queue"))

      with ops.device(global_step.device), ops.name_scope(""):
        # Replicas have to wait until they can get a token from the token queue.
        with ops.control_dependencies(train_ops):
          token = sync_token_queue.dequeue()
        train_op = state_ops.assign(self._local_step, token)

        with ops.control_dependencies([update_op]):
          # Sync_op needs to insert tokens to the token queue at the end of the
          # step so the replicas can fetch them to start the next step.
          tokens = array_ops.fill([self._tokens_per_step], global_step)
          sync_op = sync_token_queue.enqueue_many((tokens,))

        if self._variable_averages is not None:
          with ops.control_dependencies([sync_op]), ops.name_scope(""):
            sync_op = self._variable_averages.apply(
                self._variables_to_average)

        self._chief_queue_runner = queue_runner.QueueRunner(dummy_queue,
                                                            [sync_op])
      for accum, dev in self._accumulator_list:
        with ops.device(dev):
          chief_init_ops.append(
              accum.set_global_step(
                  global_step, name="SetGlobalStep"))
      self.chief_init_op = control_flow_ops.group(*(chief_init_ops))
      self._gradients_applied = True
      return train_op
Example #26
  def _restore_checkpoint(self,
                          master,
                          saver=None,
                          checkpoint_dir=None,
                          checkpoint_filename_with_path=None,
                          wait_for_checkpoint=False,
                          max_wait_secs=7200,
                          config=None):
    """Creates a `Session`, and tries to restore a checkpoint.


    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, is_restored) where 'is_restored' is `True` if
      the session could be restored, `False` otherwise.

    Raises:
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """
    self._target = master
    sess = session.Session(self._target, graph=self._graph, config=config)
    # TODO(jhseu): Delete once tpu.initialize_system() goes away.
    initialize_ops = (
        distribution_strategy_context.get_distribution_strategy().initialize()
    )
    if initialize_ops:
      sess.run(initialize_ops)

    if checkpoint_dir and checkpoint_filename_with_path:
      raise ValueError("Can not provide both checkpoint_dir and "
                       "checkpoint_filename_with_path.")
    # If either saver or checkpoint_* is not specified, cannot restore. Just
    # return.
    if not saver or not (checkpoint_dir or checkpoint_filename_with_path):
      return sess, False

    if checkpoint_filename_with_path:
      saver.restore(sess, checkpoint_filename_with_path)
      return sess, True

    # Wait up to max_wait_secs for a checkpoint to become available.
    wait_time = 0
    ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
    while not ckpt or not ckpt.model_checkpoint_path:
      if wait_for_checkpoint and wait_time < max_wait_secs:
        logging.info("Waiting for checkpoint to be available.")
        time.sleep(self._recovery_wait_secs)
        wait_time += self._recovery_wait_secs
        ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
      else:
        return sess, False

    # Loads the checkpoint.
    saver.restore(sess, ckpt.model_checkpoint_path)
    saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
    return sess, True
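
A restore helper with this shape is normally reached through `tf.train.SessionManager`'s public entry points (e.g. `prepare_session`) rather than called directly. Below is a minimal sketch of that public usage, assuming TF 1.x graph mode; the graph, variable, and checkpoint-directory names are illustrative.

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    w = tf.get_variable("w", shape=[10], initializer=tf.zeros_initializer())
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

sm = tf.train.SessionManager(graph=graph)
# prepare_session() restores from the latest checkpoint in checkpoint_dir if
# one exists and runs init_op otherwise; internally it follows the same
# restore-and-wait logic shown in _restore_checkpoint above.
sess = sm.prepare_session(
    master="",                          # in-process, local master
    init_op=init_op,
    saver=saver,
    checkpoint_dir="/tmp/model_ckpt",   # illustrative path
    wait_for_checkpoint=False,
    max_wait_secs=60)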
Example #27
0
        def compute_gradients(optimizer,
                              loss,
                              var_list=None,
                              gate_gradients=Optimizer.GATE_OP,
                              aggregation_method=None,
                              colocate_gradients_with_ops=False,
                              grad_loss=None):
            if callable(loss):
                from tensorflow.python.eager import backprop
                with backprop.GradientTape() as tape:
                    if var_list is not None:
                        tape.watch(var_list)
                    loss_value = loss()

                    # Scale loss if using a "mean" loss reduction and multiple towers.
                    # Have to be careful to call distribute_lib.get_loss_reduction()
                    # *after* loss() is evaluated, so we know what loss reduction it uses.
                    # TODO(josh11b): Test that we handle weight decay in a reasonable way.
                    if (distribute_lib.get_loss_reduction() ==
                            variable_scope.VariableAggregation.MEAN):
                        num_towers = distribution_strategy_context.get_distribution_strategy(
                        ).num_towers
                        if num_towers > 1:
                            loss_value *= (1. / num_towers)

                if var_list is None:
                    var_list = tape.watched_variables()
                # TODO(jhseu): Figure out why GradientTape's gradients don't require loss
                # to be executed.
                with ops.control_dependencies([loss_value]):
                    grads = tape.gradient(loss_value, var_list, grad_loss)
                return list(zip(grads, var_list))

            # Non-callable/Tensor loss case
            if context.executing_eagerly():
                raise RuntimeError(
                    "`loss` passed to Optimizer.compute_gradients should "
                    "be a function when eager execution is enabled.")

            # Scale loss if using a "mean" loss reduction and multiple towers.
            if (distribute_lib.get_loss_reduction() ==
                    variable_scope.VariableAggregation.MEAN):
                num_towers = distribution_strategy_context.get_distribution_strategy(
                ).num_towers
                if num_towers > 1:
                    loss *= (1. / num_towers)

            if gate_gradients not in [
                    Optimizer.GATE_NONE, Optimizer.GATE_OP,
                    Optimizer.GATE_GRAPH
            ]:
                raise ValueError(
                    "gate_gradients must be one of: Optimizer.GATE_NONE, "
                    "Optimizer.GATE_OP, Optimizer.GATE_GRAPH.  Not %s" %
                    gate_gradients)
            optimizer._assert_valid_dtypes([loss])
            if grad_loss is not None:
                optimizer._assert_valid_dtypes([grad_loss])
            if var_list is None:
                var_list = (variables.trainable_variables() +
                            ops.get_collection(
                                ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
            else:
                var_list = nest.flatten(var_list)
            # pylint: disable=protected-access
            var_list += ops.get_collection(
                ops.GraphKeys._STREAMING_MODEL_PORTS)
            # pylint: enable=protected-access
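            # _get_processor picks a per-variable-type processor; its target()
            # returns the tensor that gradients are computed with respect to.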
            from tensorflow.python.training.optimizer import _get_processor
            processors = [_get_processor(v) for v in var_list]
            if not var_list:
                raise ValueError("No variables to optimize.")
            var_refs = [p.target() for p in processors]
            # Original gradient computation (kept for reference):
            # grads = tf.gradients(
            #     loss, var_refs, grad_ys=grad_loss,
            #     gate_gradients=(gate_gradients == Optimizer.GATE_OP),
            #     aggregation_method=aggregation_method,
            #     colocate_gradients_with_ops=colocate_gradients_with_ops)
            # Use gradient checkpointing instead: recompute intermediate
            # activations during the backward pass to trade compute for memory.
            from memory_saving_gradients import gradients
            # Collect the output tensors of the different sub-networks as
            # checkpoint candidates (`self` here is captured from the enclosing
            # scope); note that this value is not actually passed to
            # `gradients` below.
            tensors_to_checkpoint = self.get_tensors_to_checkpoint()

            # Specifying checkpoints='memory' fails here, so the 'speed'
            # heuristic is used instead.
            grads = gradients(
                loss,
                var_refs,
                grad_ys=grad_loss,
                gate_gradients=(gate_gradients == Optimizer.GATE_OP),
                aggregation_method=aggregation_method,
                colocate_gradients_with_ops=colocate_gradients_with_ops,
                checkpoints='speed')

            if gate_gradients == Optimizer.GATE_GRAPH:
                grads = control_flow_ops.tuple(grads)
            grads_and_vars = list(zip(grads, var_list))
            optimizer._assert_valid_dtypes([
                v for g, v in grads_and_vars
                if g is not None and v.dtype != dtypes.resource
            ])
            return grads_and_vars
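
A minimal sketch of how this checkpointed compute_gradients might be wired into a training step, assuming TF 1.x graph mode, that `memory_saving_gradients` (from the openai gradient-checkpointing project) is importable, and that the nested function above has been exposed to the training script (how it is exposed depends on the surrounding class, which is not shown here). The placeholder shapes, layers, and optimizer are illustrative.

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
hidden = tf.layers.dense(x, 256, activation=tf.nn.relu)
logits = tf.layers.dense(hidden, 10)
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits))

opt = tf.train.AdamOptimizer(1e-3)
# The customized compute_gradients swaps tf.gradients for
# memory_saving_gradients.gradients(..., checkpoints='speed') and returns the
# usual list of (gradient, variable) pairs.
grads_and_vars = compute_gradients(opt, loss)  # exposed from the snippet above
train_op = opt.apply_gradients(grads_and_vars)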