Example #1
    def model_fn(features, labels, mode):
        """model_fn for keras Estimator."""
        # Raise an error when users use DistributionStrategy with native Keras
        # optimizers. Currently we only support native TensorFlow optimizers.
        if distribution_strategy_context.has_distribution_strategy() and \
            not isinstance(keras_model.optimizer,
                           (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
            raise ValueError(
                'Only TensorFlow native optimizers are supported with '
                'DistributionStrategy.')

        model = _clone_and_build_model(mode, keras_model, custom_objects,
                                       features, labels)
        model_output_names = []
        # We need to make sure that the output names of the last layer in the model
        # are the same for each of the cloned models. This is required for mirrored
        # strategy when we call regroup.
        if distribution_strategy_context.has_distribution_strategy():
            for name in model.output_names:
                name = re.compile(r'_\d$').sub('', name)
                model_output_names.append(name)
        else:
            model_output_names = model.output_names

        # Get inputs to EstimatorSpec
        predictions = dict(zip(model_output_names, model.outputs))

        loss = None
        train_op = None
        eval_metric_ops = None

        # Set loss and metric only during train and evaluate.
        if mode is not model_fn_lib.ModeKeys.PREDICT:
            if mode is model_fn_lib.ModeKeys.TRAIN:
                model._make_train_function()  # pylint: disable=protected-access
            else:
                model._make_test_function()  # pylint: disable=protected-access
            loss = model.total_loss

            eval_metric_ops = _convert_keras_metrics_to_estimator(model)

        # Set train_op only during train.
        if mode is model_fn_lib.ModeKeys.TRAIN:
            train_op = model.train_function.updates_op

        if not model._is_graph_network:
            # Reset model state to original state,
            # to avoid `model_fn` being destructive for the initial model argument.
            models.in_place_subclassed_model_state_restoration(keras_model)
        return model_fn_lib.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            loss=loss,
            train_op=train_op,
            eval_metric_ops=eval_metric_ops,
            export_outputs={
                _DEFAULT_SERVING_KEY:
                export_lib.export_output.PredictOutput(predictions)
            })
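
For context, the model_fn above is the internal function built by the Keras-to-Estimator conversion path. Below is a minimal sketch of how it is typically reached from user code, assuming TF 1.x and the public tf.keras.estimator.model_to_estimator API; the Sequential model and optimizer are illustrative, not taken from the snippet.

import tensorflow as tf

keras_model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(1)
])
# A native TensorFlow optimizer is required when a DistributionStrategy is
# used, as the check at the top of the model_fn enforces.
keras_model.compile(optimizer=tf.train.AdamOptimizer(), loss='mse')

# model_to_estimator wraps the compiled model; the resulting Estimator runs a
# model_fn of the form shown in Example #1.
estimator = tf.keras.estimator.model_to_estimator(keras_model=keras_model)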
Example #2
  def apply_gradients(self, grads_and_vars, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      name: Optional name for the returned operation.  Defaults to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
    """
    grads_and_vars = _filter_grads(grads_and_vars)
    var_list = [v for (_, v) in grads_and_vars]
    if distribute_ctx.has_distribution_strategy():
      reduced_grads = merge_grads(grads_and_vars)
      grads_and_vars = zip(reduced_grads, var_list)

    self._prepare()
    with ops.init_scope():
      self._create_slots(var_list)
    update_ops = []

    def update_grad_to_var(grad, var):
      """Apply gradient to variable."""
      if isinstance(var, ops.Tensor):
        raise NotImplementedError("Trying to update a Tensor ", var)
      if isinstance(grad, ops.IndexedSlices):
        if var.constraint is not None:
          raise RuntimeError(
              "Cannot use a constraint function on a sparse variable.")
        return self._resource_apply_sparse_duplicate_indices(
            grad.values, var, grad.indices)
      update_op = self._resource_apply_dense(grad, var)
      if var.constraint is not None:
        with ops.control_dependencies([update_op]):
          return var.assign(var.constraint(var))
      else:
        return update_op

    with ops.name_scope(name, self._name) as name:
      for grad, var in grads_and_vars:
        scope_name = ("" if ops.executing_eagerly_outside_functions() else
                      "_" + var.op.name)
        with ops.name_scope("update" + scope_name):
          update_ops.append(update_grad_to_var(grad, var))
      # Control dependencies do not work in per-replica mode; change this
      # once b/118841692 is fixed.
      # with ops.control_dependencies(update_ops):
      #   apply_updates = self._iterations.assign_add(1).op
      apply_updates = merge_update_step(update_ops, self.iterations)
      return apply_updates
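
A minimal usage sketch for this v2 optimizer's apply_gradients, assuming TF 2.x-style eager execution with the public tf.keras.optimizers.SGD and tf.GradientTape (the variable names are illustrative). Under a distribution strategy the same call additionally reduces the gradients across replicas before applying them, as the merge step above shows.

import tensorflow as tf

opt = tf.keras.optimizers.SGD(learning_rate=0.1)
var = tf.Variable(2.0)

with tf.GradientTape() as tape:
    loss = var * var
grads = tape.gradient(loss, [var])

# apply_gradients takes (gradient, variable) pairs and returns the op that
# applies them and bumps the optimizer's iteration counter.
opt.apply_gradients(zip(grads, [var]))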
Example #3
    def get_updates(self, loss, params):
        if distribution_strategy_context.has_distribution_strategy():
            self.updates = []

            if not params:
                # After the model vars have been created, get_updates is called a
                # second time with params as an empty list. This ensures that we
                # call compute_gradients with params=None.
                grads = self.optimizer.compute_gradients(loss)
            else:
                grads = self.optimizer.compute_gradients(loss, params)
            global_step = training_util.get_global_step()
            opt_update = self.optimizer.apply_gradients(grads, global_step)
        else:
            if not params:
                self.updates = [state_ops.assign_add(self.iterations, 1)]
                return self.updates

            # Updates list starts out empty because the iterations variable is
            # incremented in optimizer.apply_gradients()
            self.updates = []
            grads = self.optimizer.compute_gradients(loss, params)
            opt_update = self.optimizer.apply_gradients(
                grads, global_step=self.iterations)

        self.updates.append(opt_update)
        return self.updates
Example #4
    def _assign_moving_average(self, variable, value, momentum):
        with ops.name_scope(None, 'AssignMovingAvg',
                            [variable, value, momentum]) as scope:
            # TODO(b/120571621): We want to avoid colocating the variables here
            # since TPUStrategy does not implement replica local variables.
            # Remove this hack once we support TPULocalVariables.
            is_tpu_strategy = False
            if distribution_strategy_context.has_distribution_strategy():
                distribute = distribution_strategy_context.get_distribution_strategy()
                if distribute.__class__.__name__ == 'TPUStrategy':
                    is_tpu_strategy = True

            # TODO(apassos,srbs,skyewm): the colocation constraints here are disabled
            # because of a bug that causes cond_v2/while_v2 to skip rewriting
            # them, creating conflicts.
            if (control_flow_util.EnableControlFlowV2(ops.get_default_graph())
                    or is_tpu_strategy):
                cm = contextlib.contextmanager(lambda: (yield))()
            else:
                cm = ops.colocate_with(variable)
            with cm:
                decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
                if decay.dtype != variable.dtype.base_dtype:
                    decay = math_ops.cast(decay, variable.dtype.base_dtype)
                update_delta = (variable -
                                math_ops.cast(value, variable.dtype)) * decay
                return state_ops.assign_sub(variable, update_delta, name=scope)
Example #5
def _assert_in_default_state(t):
  t.assertIs(distribution_strategy_context._get_default_replica_context(),
             distribution_strategy_context.get_replica_context())
  t.assertIs(None, distribution_strategy_context.get_cross_replica_context())
  t.assertFalse(distribution_strategy_context.in_cross_replica_context())
  t.assertIs(distribution_strategy_context._get_default_distribution_strategy(),
             distribution_strategy_context.get_distribution_strategy())
  t.assertFalse(distribution_strategy_context.has_distribution_strategy())
Example #6
 def merge_fn(dist, s):
   self.assertIs(
       distribution_strategy_context._get_default_distribution_strategy(),
       dist)
   self.assertIs(None, distribution_strategy_context.get_replica_context())
   self.assertIs(dist,
                 distribution_strategy_context.get_cross_replica_context())
   self.assertTrue(distribution_strategy_context.in_cross_replica_context())
   self.assertIs(dist,
                 distribution_strategy_context.get_distribution_strategy())
   self.assertFalse(
       distribution_strategy_context.has_distribution_strategy())
   return "foo_" + s
Example #7
 def run_fn():
   replica_context = distribution_strategy_context.get_replica_context()
   self.assertTrue(replica_context is not None)
   self.assertIs(None,
                 distribution_strategy_context.get_cross_replica_context())
   self.assertFalse(distribution_strategy_context.in_cross_replica_context())
   self.assertTrue(distribution_strategy_context.has_distribution_strategy())
   self.assertIs(dist,
                 distribution_strategy_context.get_distribution_strategy())
   self.assertEqual("foo", replica_context.merge_call(None, test_arg="foo"))
   expected_value = _get_test_variable(
       "bar", variable_scope.VariableSynchronization.AUTO,
       variable_scope.VariableAggregation.NONE)
   self.assertDictEqual(expected_value,
                        variable_scope.variable(1.0, name="bar"))
Example #8
 def testScope(self):
   _assert_in_default_state(self)
   dist = _TestStrategy()
   with dist.scope():
     self.assertIs(None, distribution_strategy_context.get_replica_context())
     self.assertIs(dist,
                   distribution_strategy_context.get_cross_replica_context())
     self.assertTrue(distribution_strategy_context.in_cross_replica_context())
     self.assertTrue(distribution_strategy_context.has_distribution_strategy())
     self.assertIs(dist,
                   distribution_strategy_context.get_distribution_strategy())
     expected_value = _get_test_variable(
         "baz", variable_scope.VariableSynchronization.AUTO,
         variable_scope.VariableAggregation.NONE)
     self.assertDictEqual(expected_value,
                          variable_scope.variable(1.0, name="baz"))
   _assert_in_default_state(self)
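
The same default-state/scope behavior can be observed with the public tf.distribute API in TF 2.x, where has_distribution_strategy, get_distribution_strategy, and in_cross_replica_context correspond to tf.distribute.has_strategy, tf.distribute.get_strategy, and tf.distribute.in_cross_replica_context. A minimal sketch under that assumption (not part of the test above):

import tensorflow as tf

# Outside any scope only the default strategy is active.
assert not tf.distribute.has_strategy()

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # Inside the scope the custom strategy is current and we are in
    # cross-replica context.
    assert tf.distribute.has_strategy()
    assert tf.distribute.in_cross_replica_context()
    assert tf.distribute.get_strategy() is strategy

# Leaving the scope restores the default state.
assert not tf.distribute.has_strategy()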
Example #9
    def _assign_moving_average(self, variable, value, momentum):
        with ops.name_scope(None, 'AssignMovingAvg',
                            [variable, value, momentum]) as scope:
            # TODO(b/120571621): We want to avoid colocating the variables here
            # since TPUStrategy does not implement replica local variables.
            # Remove this hack once we support TPULocalVariables.
            is_tpu_strategy = False
            if distribution_strategy_context.has_distribution_strategy():
                distribute = distribution_strategy_context.get_distribution_strategy()
                if distribute.__class__.__name__ == 'TPUStrategy':
                    is_tpu_strategy = True

            with ops.colocate_with(variable):
                decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
                if decay.dtype != variable.dtype.base_dtype:
                    decay = math_ops.cast(decay, variable.dtype.base_dtype)
                update_delta = (variable -
                                math_ops.cast(value, variable.dtype)) * decay
                return state_ops.assign_sub(variable, update_delta, name=scope)
Example #10
def _var_key(var):
    """Key for representing a primary variable, for looking up slots.

    In graph mode the name is derived from the var shared name.
    In eager mode the name is derived from the var unique id.
    If a distribution strategy exists, get the primary variable first.

    Args:
      var: the variable.

    Returns:
      the unique name of the variable.
    """

    # pylint: disable=protected-access
    if distribute_ctx.has_distribution_strategy() and hasattr(
            var, "_primary_var"):
        var = var._primary_var
    if hasattr(var, "op"):
        return var._shared_name
    return var._unique_id
Example #11
  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.  Defaults to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
      RuntimeError: If you should use `_distributed_apply()` instead.
    """
    # This is a default implementation of apply_gradients() that can be shared
    # by most optimizers.  It relies on the subclass implementing the following
    # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().

    # Handle DistributionStrategy case.
    if distribute_ctx.get_cross_replica_context():
      raise RuntimeError("Use `_distributed_apply()` instead of "
                         "`apply_gradients()` in a cross-replica context.")
    # TODO(isaprykin): Get rid of `has_distribution_strategy()` check by
    # always calling _distributed_apply(), using the default distribution
    # as needed.
    if distribute_ctx.has_distribution_strategy():
      grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
      return distribute_ctx.get_replica_context().merge_call(
          self._distributed_apply, args=(grads_and_vars, global_step, name))

    # No DistributionStrategy case.
    grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
    if not grads_and_vars:
      raise ValueError("No variables provided.")
    converted_grads_and_vars = []
    for g, v in grads_and_vars:
      if g is not None:
        try:
          # Convert the grad to Tensor or IndexedSlices if necessary.
          g = ops.convert_to_tensor_or_indexed_slices(g)
        except TypeError:
          raise TypeError(
              "Gradient must be convertible to a Tensor"
              " or IndexedSlices, or None: %s" % g)
        if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
          raise TypeError(
              "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
      p = _get_processor(v)
      converted_grads_and_vars.append((g, v, p))

    converted_grads_and_vars = tuple(converted_grads_and_vars)
    var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
    if not var_list:
      raise ValueError("No gradients provided for any variable: %s." %
                       ([str(v) for _, v, _ in converted_grads_and_vars],))
    with ops.init_scope():
      self._create_slots(var_list)
    update_ops = []
    with ops.name_scope(name, self._name) as name:
      self._prepare()
      for grad, var, processor in converted_grads_and_vars:
        if grad is None:
          continue
        # We colocate all ops created in _apply_dense or _apply_sparse
        # on the same device as the variable.
        # TODO(apassos): figure out how to get the variable name here.
        if context.executing_eagerly() or isinstance(
            var,
            resource_variable_ops.ResourceVariable) and not var._in_graph_mode:  # pylint: disable=protected-access
          scope_name = ""
        else:
          scope_name = var.op.name
        with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
          update_ops.append(processor.update_op(self, grad))
      if global_step is None:
        apply_updates = self._finish(update_ops, name)
      else:
        with ops.control_dependencies([self._finish(update_ops, "update")]):
          with ops.colocate_with(global_step):
            if isinstance(global_step, resource_variable_ops.ResourceVariable):
              # TODO(apassos): the implicit read in assign_add is slow; consider
              # making it less so.
              apply_updates = resource_variable_ops.assign_add_variable_op(
                  global_step.handle,
                  ops.convert_to_tensor(1, dtype=global_step.dtype),
                  name=name)
            else:
              apply_updates = state_ops.assign_add(global_step, 1, name=name)

      if not context.executing_eagerly():
        if isinstance(apply_updates, ops.Tensor):
          apply_updates = apply_updates.op
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if apply_updates not in train_op:
          train_op.append(apply_updates)

      return apply_updates
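
A minimal sketch of the documented v1 call pattern, assuming TF 1.x graph mode: compute_gradients produces the (gradient, variable) pairs and apply_gradients also increments global_step when one is passed. The variable names are illustrative.

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
var = tf.Variable(2.0)
loss = tf.square(var)

opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
grads_and_vars = opt.compute_gradients(loss, var_list=[var])
# Applying the gradients also increments global_step, per the docstring above.
train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)
    print(sess.run([var, global_step]))  # e.g. [1.6, 1]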