Example #1
  def step_fn(ctx, inputs, targets):
    """Clones the model and calls make_train_function."""
    # TODO(priyag, sourabhbajaj): Should cache this keyed on input shapes.
    clone_model_on_towers(
        model,
        current_strategy,
        make_callback_model=True,
        inputs=inputs,
        targets=targets)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_tower(
         _per_device_train_function, model._grouped_model)
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs,
         grouped_updates, grouped_session_args, with_loss_tensor=True)
    combined_fn = K.Function(
        all_inputs, all_outputs,
        updates=all_updates,
        name='distributed_train_function',
        **all_session_args)

    # TODO(priyag, sourabhbajaj): Perhaps the aggregation type needs to be
    # something else for different outputs.
    out_labels = model.metrics_names or []
    for label, output in zip(out_labels, combined_fn.outputs):
      ctx.set_last_step_output(label, output,
                               aggregation=distribute_lib.get_loss_reduction())

    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
    # feed_dict, session kwargs, run options, run_metadata for now. These should
    # be handled appropriately.
    return combined_fn.updates_op
Example #2
 def _scale_loss(loss_value):
   if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
     num_replicas = \
       distribute_ctx.get_distribution_strategy().num_replicas_in_sync
     if num_replicas > 1:
       loss_value *= (1. / num_replicas)
   return loss_value
Example #3
 def _scale_loss(loss_value):
   if (distribute_lib.get_loss_reduction() ==
       variable_scope.VariableAggregation.MEAN):
     num_replicas = \
       distribute_ctx.get_distribution_strategy().num_replicas_in_sync
     if num_replicas > 1:
       loss_value *= (1. / num_replicas)
   return loss_value
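
Examples #2 and #3 are the same helper from two TensorFlow versions; the only difference is which enum names the MEAN reduction (ds_reduce_util.ReduceOp.MEAN vs. variable_scope.VariableAggregation.MEAN). A framework-free sketch of the arithmetic (scale_loss is a hypothetical stand-in, not TensorFlow API): scaling each local loss by 1/num_replicas makes the sum of per-replica gradients equal their mean.

def scale_loss(loss_value, num_replicas, reduction="mean"):
    """Mirrors _scale_loss: divide by the replica count under MEAN reduction."""
    if reduction == "mean" and num_replicas > 1:
        loss_value *= 1.0 / num_replicas
    return loss_value

per_replica_losses = [4.0, 2.0, 6.0]
scaled = [scale_loss(loss, num_replicas=3) for loss in per_replica_losses]
assert abs(sum(scaled) - sum(per_replica_losses) / 3) < 1e-9  # sum == mean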
Example #4
  def step_fn(ctx, inputs):
    """Clones the model and calls make_fit_function."""
    # TODO(priyag, sourabhbajaj): The model gets cloned every time
    # fit/test/predict is called. We should look into caching this keyed on
    # input shapes.
    inputs, targets = inputs
    clone_model_on_replicas(
        model,
        current_strategy,
        make_callback_model=True,
        inputs=inputs,
        targets=targets,
        mode=_Mode.TRAIN)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_replica(
         _per_device_fit_function, args=(model._grouped_model_train,))
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs,
         grouped_updates, grouped_session_args)
    combined_fn = K.function(
        all_inputs,
        all_outputs,
        updates=all_updates,
        name='distributed_fit_function',
        **all_session_args)

    # Note: `out_labels` is defined in the enclosing scope of this excerpt.
    for label, output in zip(out_labels, combined_fn.outputs):
      if label == 'loss':
        reduce_op = distribute_lib.get_loss_reduction()
      else:
        # We reduce all other metrics using mean for now. This is a temporary
        # workaround until new metrics are in place.
        reduce_op = ds_reduce_util.ReduceOp.MEAN
      ctx.set_last_step_output(label, output, reduce_op)

    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
    # feed_dict, session kwargs, run options, run_metadata for now. These should
    # be handled appropriately.
    return combined_fn.updates_op
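
The label loop above chooses a reduction per output: the loss keeps whatever reduction the strategy reports, and every other metric is mean-reduced as a temporary workaround until the new metrics land. A minimal sketch of that rule, using a stand-in enum rather than the real ds_reduce_util.ReduceOp:

from enum import Enum

class ReduceOp(Enum):
    SUM = "sum"
    MEAN = "mean"

def pick_reduce_op(label, strategy_loss_reduction=ReduceOp.SUM):
    # The loss follows the strategy; all other metrics are averaged for now.
    return strategy_loss_reduction if label == "loss" else ReduceOp.MEAN

assert pick_reduce_op("loss") is ReduceOp.SUM
assert pick_reduce_op("accuracy") is ReduceOp.MEAN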
Example #5
    def step_fn(ctx, inputs):
        """Clones the model and calls make_fit_function."""
        # TODO(priyag, sourabhbajaj): The model gets cloned every time
        # fit/test/predict is called. We should look into caching this keyed on
        # input shapes.
        inputs, targets = inputs
        clone_model_on_replicas(model,
                                current_strategy,
                                make_callback_model=True,
                                inputs=inputs,
                                targets=targets,
                                mode=_Mode.TRAIN)

        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_replica(
             _per_device_fit_function, args=(model._grouped_model_train, ))
        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy, grouped_inputs, grouped_outputs,
             grouped_updates, grouped_session_args)
        combined_fn = K.function(all_inputs,
                                 all_outputs,
                                 updates=all_updates,
                                 name='distributed_fit_function',
                                 **all_session_args)

        # Note: `out_labels` is defined in the enclosing scope of this excerpt.
        for label, output in zip(out_labels, combined_fn.outputs):
            if label == 'loss':
                reduce_op = distribute_lib.get_loss_reduction()
            else:
                # We reduce all other metrics using mean for now. This is a
                # temporary workaround until new metrics are in place.
                reduce_op = ds_reduce_util.ReduceOp.MEAN
            ctx.set_last_step_output(label, output, reduce_op)

        # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
        # feed_dict, session kwargs, run options, run_metadata for now. These should
        # be handled appropriately.
        return combined_fn.updates_op
Example #6
  def step_fn(ctx, inputs, targets):
    """Clones the model and calls make_train_function."""
    # TODO(priyag, sourabhbajaj): The model gets cloned every time
    # fit/test/predict is called. We should look into caching this keyed on
    # input shapes.
    clone_model_on_towers(
        model,
        current_strategy,
        make_callback_model=True,
        inputs=inputs,
        targets=targets)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_tower(
         _per_device_train_function, model._grouped_model)
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs,
         grouped_updates, grouped_session_args)
    combined_fn = K.Function(
        all_inputs, all_outputs,
        updates=all_updates,
        name='distributed_train_function',
        **all_session_args)

    out_labels = model.metrics_names or []
    for label, output in zip(out_labels, combined_fn.outputs):
      if label == 'loss':
        aggregation = distribute_lib.get_loss_reduction()
      else:
        # We aggregate all other metrics using mean for now. This is a
        # temporary workaround until new metrics are in place.
        aggregation = variable_scope.VariableAggregation.MEAN
      ctx.set_last_step_output(label, output, aggregation)

    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
    # feed_dict, session kwargs, run options, run_metadata for now. These should
    # be handled appropriately.
    return combined_fn.updates_op
Example #7
  def step_fn(ctx, inputs, targets):
    """Clones the model and calls make_train_function."""
    # TODO(priyag, sourabhbajaj): The model gets cloned every time
    # fit/test/predict is called. We should look into caching this keyed on
    # input shapes.
    clone_model_on_towers(
        model,
        current_strategy,
        make_callback_model=True,
        inputs=inputs,
        targets=targets)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_tower(
         _per_device_train_function, model._grouped_model)
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs,
         grouped_updates, grouped_session_args)
    combined_fn = K.Function(
        all_inputs, all_outputs,
        updates=all_updates,
        name='distributed_train_function',
        **all_session_args)

    out_labels = model.metrics_names or []
    for label, output in zip(out_labels, combined_fn.outputs):
      if label == 'loss':
        aggregation = distribute_lib.get_loss_reduction()
      else:
        # We aggregate all other metrics using mean for now. This is a
        # temporary workaround until new metrics are in place.
        aggregation = variable_scope.VariableAggregation.MEAN
      ctx.set_last_step_output(label, output, aggregation)

    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
    # feed_dict, session kwargs, run options, run_metadata for now. These should
    # be handled appropriately.
    return combined_fn.updates_op
Example #8
  def step_fn(ctx, inputs, targets):
    """Clones the model and calls make_eval_function."""
    # TODO(priyag, sourabhbajaj): The model gets cloned every time
    # fit/test/predict is called. We should look into caching this keyed on
    # input shapes.
    clone_model_on_replicas(
        model,
        current_strategy,
        make_callback_model=False,
        inputs=inputs,
        targets=targets,
        mode=_Mode.TEST)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_replica(
         _per_device_eval_function, args=(model._grouped_model_test,))

    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args)

    combined_fn = K.function(
        all_inputs, all_outputs,
        updates=all_updates,
        name='distributed_test_function',
        **all_session_args)

    for label, output in zip(model.metrics_names, combined_fn.outputs):
      if label == 'loss':
        aggregation = distribute_lib.get_loss_reduction()
      else:
        # We aggregate all other metrics using mean for now. This is a
        # temporary workaround until new metrics are in place.
        aggregation = variable_scope.VariableAggregation.MEAN
      ctx.set_last_step_output(label, output, aggregation)

    return combined_fn.updates_op
Example #9
  def step_fn(ctx, inputs, targets):
    """Clones the model and calls make_eval_function."""
    # TODO(priyag, sourabhbajaj): The model gets cloned every time
    # fit/test/predict is called. We should look into caching this keyed on
    # input shapes.
    clone_model_on_replicas(
        model,
        current_strategy,
        make_callback_model=False,
        inputs=inputs,
        targets=targets,
        mode=_Mode.TEST)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.call_for_each_replica(
         _per_device_eval_function, model._grouped_model_test)

    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args)

    combined_fn = K.function(
        all_inputs, all_outputs,
        updates=all_updates,
        name='distributed_test_function',
        **all_session_args)

    for label, output in zip(model.metrics_names, combined_fn.outputs):
      if label == 'loss':
        aggregation = distribute_lib.get_loss_reduction()
      else:
        # We aggregate all other metrics using mean for now. This is a
        # temporary workaround until new metrics are in place.
        aggregation = variable_scope.VariableAggregation.MEAN
      ctx.set_last_step_output(label, output, aggregation)

    return combined_fn.updates_op
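
Each step_fn above funnels the per-replica graphs into a single backend function. A minimal sketch of what such a function does, assuming TF 1.x graph semantics (tf.compat.v1 on current installs): one callable that evaluates the outputs and applies the updates in the same session run.

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
from tensorflow.compat.v1.keras import backend as K

x = tf.placeholder(tf.float32, shape=(), name="x")
counter = tf.Variable(0.0)
step = tf.assign_add(counter, 1.0)  # side-effecting update op

fn = K.function([x], [x * 2.0], updates=[step], name="sketch_fn")
K.get_session().run(tf.global_variables_initializer())
print(fn([3.0]))  # [6.0]; counter is incremented as a side effect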
Example #10
    def step_fn(ctx, inputs, targets):
        """Clones the model and calls make_train_function."""
        # TODO(priyag, sourabhbajaj): Should cache this keyed on input shapes.
        clone_model_on_towers(model,
                              current_strategy,
                              make_callback_model=True,
                              inputs=inputs,
                              targets=targets)

        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = current_strategy.call_for_each_tower(
             _per_device_train_function, model._grouped_model)
        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy,
             grouped_inputs,
             grouped_outputs,
             grouped_updates,
             grouped_session_args,
             with_loss_tensor=True)
        combined_fn = K.Function(all_inputs,
                                 all_outputs,
                                 updates=all_updates,
                                 name='distributed_train_function',
                                 **all_session_args)

        # TODO(priyag, sourabhbajaj): Perhaps the aggregation type needs to be
        # something else for different outputs.
        out_labels = model.metrics_names or []
        for label, output in zip(out_labels, combined_fn.outputs):
            ctx.set_last_step_output(
                label, output, aggregation=distribute_lib.get_loss_reduction())

        # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
        # feed_dict, session kwargs, run options, run_metadata for now. These should
        # be handled appropriately.
        return combined_fn.updates_op
Example #11
def unwrap_values(distribution_strategy,
                  grouped_inputs,
                  grouped_outputs,
                  grouped_updates,
                  grouped_session_args,
                  with_loss_tensor=False):
    """Unwrap and return the list of values contained in the PerDevice parameters.

  This function calls `flatten_perdevice_values` to parse each of the input
  parameters into a list of values on the different devices. If we set
  `with_loss_tensor` to be True, we also call `reduce` on the list of losses on
  the different devices to give us one loss tensor.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training and
        validation.
    grouped_inputs: PerDevice inputs returned from the train or test function
        that we ran on each device.
    grouped_outputs: PerDevice outputs returned from the train or test function
        that we ran on each device.
    grouped_updates: PerDevice updates returned from the train or test function
        that we ran on each device.
    grouped_session_args: PerDevice session args returned from the train or
        test function that we ran on each device.
    with_loss_tensor: Boolean that indicates if we need to add the reduced loss
        tensor as one of the outputs.

  Returns:
    Values of each of the PerDevice parameters.

  """
    # Unwrap per device values returned from each model's train function.
    # This will be used to construct the main train function.
    all_inputs = flatten_perdevice_values(distribution_strategy,
                                          grouped_inputs)
    if with_loss_tensor:
        # reduce loss tensor before adding it to the list of fetches
        loss = distribution_strategy.unwrap(
            distribution_strategy.reduce(distribute_lib.get_loss_reduction(),
                                         grouped_outputs[0],
                                         destinations='/device:CPU:0'))[0]

        all_outputs = flatten_perdevice_values(distribution_strategy,
                                               grouped_outputs[1:])
        all_outputs = [loss] + all_outputs
    else:
        all_outputs = flatten_perdevice_values(distribution_strategy,
                                               grouped_outputs)

    all_updates = flatten_perdevice_values(distribution_strategy,
                                           grouped_updates)

    all_session_args = {}
    grouped_feed_dict = grouped_session_args.get('feed_dict')
    if grouped_feed_dict:
        all_session_args['feed_dict'] = flatten_perdevice_values(
            distribution_strategy, grouped_feed_dict)

    grouped_fetches = grouped_session_args.get('fetches')
    if grouped_fetches:
        all_session_args['fetches'] = flatten_perdevice_values(
            distribution_strategy, grouped_fetches)

    return all_inputs, all_outputs, all_updates, all_session_args
Example #12
def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs,
                  grouped_updates, grouped_session_args,
                  with_loss_tensor=False):
  """Unwrap and return the list of values contained in the PerDevice parameters.

  This function calls `flatten_perdevice_values` to parse each of the input
  parameters into a list of values on the different devices. If we set
  `with_loss_tensor` to be True, we also call `reduce` on the list of losses on
  the different devices to give us one loss tensor.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training and
        validation.
    grouped_inputs: PerDevice inputs returned from the train or test function
        that we ran on each device.
    grouped_outputs: PerDevice outputs returned from the train or test function
        that we ran on each device.
    grouped_updates: PerDevice updates returned from the train or test function
        that we ran on each device.
    grouped_session_args: PerDevice session args returned from the train or
        test function that we ran on each device.
    with_loss_tensor: Boolean that indicates if we need to add the reduced loss
        tensor as one of the outputs.

  Returns:
    Values of each of the PerDevice parameters.

  """
  # Unwrap per device values returned from each model's train function.
  # This will be used to construct the main train function.
  all_inputs = flatten_perdevice_values(distribution_strategy,
                                        grouped_inputs)
  if with_loss_tensor:
    # reduce loss tensor before adding it to the list of fetches
    loss = distribution_strategy.unwrap(
        distribution_strategy.reduce(distribute_lib.get_loss_reduction(),
                                     grouped_outputs[0],
                                     destinations='/device:CPU:0'))[0]

    all_outputs = flatten_perdevice_values(distribution_strategy,
                                           grouped_outputs[1:])
    all_outputs = [loss] + all_outputs
  else:
    all_outputs = flatten_perdevice_values(distribution_strategy,
                                           grouped_outputs)

  all_updates = flatten_perdevice_values(distribution_strategy,
                                         grouped_updates)

  all_session_args = {}
  grouped_feed_dict = grouped_session_args.get('feed_dict')
  if grouped_feed_dict:
    all_session_args['feed_dict'] = flatten_perdevice_values(
        distribution_strategy, grouped_feed_dict)

  grouped_fetches = grouped_session_args.get('fetches')
  if grouped_fetches:
    all_session_args['fetches'] = flatten_perdevice_values(
        distribution_strategy, grouped_fetches)

  return all_inputs, all_outputs, all_updates, all_session_args
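
Examples #11 and #12 are the same unwrap_values in two formatting styles. A toy, framework-free sketch of its contract for the outputs (a mean reduction is assumed here purely for the sketch): flatten the per-device values into one list, and with with_loss_tensor=True replace the per-device losses in the first slot with a single reduced loss.

def toy_unwrap_outputs(grouped_outputs, with_loss_tensor=False):
    if with_loss_tensor:
        # Reduce the per-device losses to one scalar, then flatten the rest.
        loss = sum(grouped_outputs[0]) / len(grouped_outputs[0])
        rest = [v for per_device in grouped_outputs[1:] for v in per_device]
        return [loss] + rest
    return [v for per_device in grouped_outputs for v in per_device]

# Two devices: the first entry holds per-device losses, the second a metric.
outs = [[0.5, 0.7], [0.9, 0.8]]
result = toy_unwrap_outputs(outs, with_loss_tensor=True)
assert result[1:] == [0.9, 0.8] and abs(result[0] - 0.6) < 1e-9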
Example #13
  def compute_gradients(self, loss, var_list=None,
                        gate_gradients=GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`.  It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable".  Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking
        no arguments which returns the value to minimize. When eager execution
        is enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`.  Defaults to the list of variables collected in the graph
        under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients.  Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything else than `Variable` objects.
      ValueError: If some arguments are invalid.
      RuntimeError: If called with eager execution enabled and `loss` is
        not callable.

    @compatibility(eager)
    When eager execution is enabled, `gate_gradients`, `aggregation_method`,
    and `colocate_gradients_with_ops` are ignored.
    @end_compatibility
    """
    if callable(loss):
      with backprop.GradientTape() as tape:
        if var_list is not None:
          tape.watch(var_list)
        loss_value = loss()

        # Scale loss if using a "mean" loss reduction and multiple towers.
        # Have to be careful to call distribute_lib.get_loss_reduction()
        # *after* loss() is evaluated, so we know what loss reduction it uses.
        # TODO(josh11b): Test that we handle weight decay in a reasonable way.
        if (distribute_lib.get_loss_reduction() ==
            variable_scope.VariableAggregation.MEAN):
          num_towers = distribution_strategy_context.get_distribution_strategy(
          ).num_towers
          if num_towers > 1:
            loss_value *= (1. / num_towers)

      if var_list is None:
        var_list = tape.watched_variables()
      grads = tape.gradient(loss_value, var_list, grad_loss)
      return list(zip(grads, var_list))

    # Non-callable/Tensor loss case
    if context.executing_eagerly():
      raise RuntimeError(
          "`loss` passed to Optimizer.compute_gradients should "
          "be a function when eager execution is enabled.")

    # Scale loss if using a "mean" loss reduction and multiple towers.
    if (distribute_lib.get_loss_reduction() ==
        variable_scope.VariableAggregation.MEAN):
      num_towers = distribution_strategy_context.get_distribution_strategy(
      ).num_towers
      if num_towers > 1:
        loss *= (1. / num_towers)

    if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
                              Optimizer.GATE_GRAPH]:
      raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
                       "Optimizer.GATE_OP, Optimizer.GATE_GRAPH.  Not %s" %
                       gate_gradients)
    self._assert_valid_dtypes([loss])
    if grad_loss is not None:
      self._assert_valid_dtypes([grad_loss])
    if var_list is None:
      var_list = (
          variables.trainable_variables() +
          ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    else:
      var_list = nest.flatten(var_list)
    # pylint: disable=protected-access
    var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
    # pylint: enable=protected-access
    processors = [_get_processor(v) for v in var_list]
    if not var_list:
      raise ValueError("No variables to optimize.")
    var_refs = [p.target() for p in processors]
    grads = gradients.gradients(
        loss, var_refs, grad_ys=grad_loss,
        gate_gradients=(gate_gradients == Optimizer.GATE_OP),
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops)
    if gate_gradients == Optimizer.GATE_GRAPH:
      grads = control_flow_ops.tuple(grads)
    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes(
        [v for g, v in grads_and_vars
         if g is not None and v.dtype != dtypes.resource])
    return grads_and_vars
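
A small usage sketch for this API, assuming TF 1.x graph mode: compute_gradients returns (gradient, variable) pairs that can be inspected or transformed before being handed to apply_gradients.

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

w = tf.Variable(3.0)
loss = tf.square(w)  # d(loss)/dw = 2w, i.e. 6.0 at w == 3
opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
grads_and_vars = opt.compute_gradients(loss)
train_op = opt.apply_gradients(grads_and_vars)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run([g for g, _ in grads_and_vars]))  # [6.0]
    sess.run(train_op)  # w becomes 3.0 - 0.1 * 6.0 = 2.4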
Example #14
        def compute_gradients(optimizer,
                              loss,
                              var_list=None,
                              gate_gradients=Optimizer.GATE_OP,
                              aggregation_method=None,
                              colocate_gradients_with_ops=False,
                              grad_loss=None):
            if callable(loss):
                from tensorflow.python.eager import backprop
                with backprop.GradientTape() as tape:
                    if var_list is not None:
                        tape.watch(var_list)
                    loss_value = loss()

                    # Scale loss if using a "mean" loss reduction and multiple towers.
                    # Have to be careful to call distribute_lib.get_loss_reduction()
                    # *after* loss() is evaluated, so we know what loss reduction it uses.
                    # TODO(josh11b): Test that we handle weight decay in a reasonable way.
                    if (distribute_lib.get_loss_reduction() ==
                            variable_scope.VariableAggregation.MEAN):
                        num_towers = distribution_strategy_context.get_distribution_strategy(
                        ).num_towers
                        if num_towers > 1:
                            loss_value *= (1. / num_towers)

                if var_list is None:
                    var_list = tape.watched_variables()
                # TODO(jhseu): Figure out why GradientTape's gradients don't require loss
                # to be executed.
                with ops.control_dependencies([loss_value]):
                    grads = tape.gradient(loss_value, var_list, grad_loss)
                return list(zip(grads, var_list))

            # Non-callable/Tensor loss case
            if context.executing_eagerly():
                raise RuntimeError(
                    "`loss` passed to Optimizer.compute_gradients should "
                    "be a function when eager execution is enabled.")

            # Scale loss if using a "mean" loss reduction and multiple towers.
            if (distribute_lib.get_loss_reduction() ==
                    variable_scope.VariableAggregation.MEAN):
                num_towers = distribution_strategy_context.get_distribution_strategy(
                ).num_towers
                if num_towers > 1:
                    loss *= (1. / num_towers)

            if gate_gradients not in [
                    Optimizer.GATE_NONE, Optimizer.GATE_OP,
                    Optimizer.GATE_GRAPH
            ]:
                raise ValueError(
                    "gate_gradients must be one of: Optimizer.GATE_NONE, "
                    "Optimizer.GATE_OP, Optimizer.GATE_GRAPH.  Not %s" %
                    gate_gradients)
            optimizer._assert_valid_dtypes([loss])
            if grad_loss is not None:
                optimizer._assert_valid_dtypes([grad_loss])
            if var_list is None:
                var_list = (variables.trainable_variables() +
                            ops.get_collection(
                                ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
            else:
                var_list = nest.flatten(var_list)
            # pylint: disable=protected-access
            var_list += ops.get_collection(
                ops.GraphKeys._STREAMING_MODEL_PORTS)
            # pylint: enable=protected-access
            from tensorflow.python.training.optimizer import _get_processor
            processors = [_get_processor(v) for v in var_list]
            if not var_list:
                raise ValueError("No variables to optimize.")
            var_refs = [p.target() for p in processors]
            # original gradients computation
            # grads = tf.gradients(
            #     loss, var_refs, grad_ys=grad_loss,
            #     gate_gradients=(gate_gradients == Optimizer.GATE_OP),
            #     aggregation_method=aggregation_method,
            #     colocate_gradients_with_ops=colocate_gradients_with_ops)
            # using gradient check-pointing
            from memory_saving_gradients import gradients
            # Collect outputs of the different networks to checkpoint.
            # Note: `self` here comes from the enclosing class, and
            # `tensors_to_checkpoint` goes unused below: passing
            # checkpoints='memory' fails in this setup, so 'speed' is used.
            tensors_to_checkpoint = self.get_tensors_to_checkpoint()

            grads = gradients(
                loss,
                var_refs,
                grad_ys=grad_loss,
                gate_gradients=(gate_gradients == Optimizer.GATE_OP),
                aggregation_method=aggregation_method,
                colocate_gradients_with_ops=colocate_gradients_with_ops,
                checkpoints='speed')

            if gate_gradients == Optimizer.GATE_GRAPH:
                grads = control_flow_ops.tuple(grads)
            grads_and_vars = list(zip(grads, var_list))
            optimizer._assert_valid_dtypes([
                v for g, v in grads_and_vars
                if g is not None and v.dtype != dtypes.resource
            ])
            return grads_and_vars
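
Example #14 swaps tf.gradients for the gradients function from the memory_saving_gradients package (gradient checkpointing). The underlying idea, as a framework-free toy sketch: keep only every k-th activation during the forward pass and recompute the rest during the backward pass, trading compute for memory.

def forward_all(x, layers):
    acts = [x]
    for layer in layers:
        acts.append(layer(acts[-1]))
    return acts  # keeps every activation: O(len(layers)) memory

def forward_checkpointed(x, layers, every=2):
    ckpts = {0: x}
    for i, layer in enumerate(layers):
        x = layer(x)
        if (i + 1) % every == 0:
            ckpts[i + 1] = x  # keep ~len(layers) / every activations
    return ckpts  # a backward pass recomputes each segment from its checkpoint

layers = [lambda v: v + 1] * 4
assert forward_all(0, layers)[-1] == 4
assert forward_checkpointed(0, layers)[4] == 4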
Example #15
    def _train_model_distributed(self, strategy, input_fn, hooks,
                                 saving_listeners, save_best_ckpt):
        """Initiate training with `input_fn`, using `DistributionStrategies`.

        Args:
          strategy: The `DistributionStrategy` used to distribute training and
            validation.
          input_fn: A function that provides input data for training as minibatches.
          hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
            callbacks inside the training loop.
          saving_listeners: list of `tf.train.CheckpointSaverListener` objects. Used
            for callbacks that run immediately before or after checkpoint savings.
          save_best_ckpt: If True, save checkpoints judged best by the evaluation
            results rather than on a fixed schedule.

        Returns:
            Loss from training.
        """
        strategy.configure(self._session_config)

        worker_hooks = []
        with ops.Graph().as_default() as g:
            # We want to create the iterations variable outside the distribution scope
            # as that is just stored on the host and mainly used to drive the loop
            # and doesn't need to be a Mirrored/Device variable.
            with strategy.scope():
                random_seed.set_random_seed(self._config.tf_random_seed)

                if self._train_with_eval:
                    self.handler = array_ops.placeholder(dtypes.string,
                                                         shape=(),
                                                         name="Handler")
                    iterator, self.train_iterator, self.eval_iterator, input_hooks = (
                        self._get_iterator_for_train_and_eval(
                            input_fn, self.handler, strategy))
                else:
                    self.handler, self.train_iterator, self.eval_iterator = None, None, None
                    iterator, input_hooks = self._get_iterator_from_input_fn(
                        input_fn, model_fn_lib.ModeKeys.TRAIN, strategy)
                worker_hooks.extend(input_hooks)
                global_step_tensor = self._create_and_assert_global_step(g)
                # We want to add to the global collection in the main thread, not the
                # tower threads.
                ops.add_to_collection(training_util.GLOBAL_STEP_READ_KEY,
                                      strategy.read_var(global_step_tensor))

                features, labels = estimator_util.parse_iterator_result(
                    per_device_dataset(iterator, strategy.extended._devices))
                grouped_estimator_spec = strategy.call_for_each_replica(
                    self._call_model_fn,
                    args=(features, labels, model_fn_lib.ModeKeys.TRAIN,
                          self.config))
                loss = strategy.reduce(distribute_lib.get_loss_reduction(),
                                       grouped_estimator_spec.loss)
                distributed_train_op = grouped_estimator_spec.train_op

                predictions = {}
                for key, val in grouped_estimator_spec.predictions.items():
                    if key == "GlobalStep":
                        predictions["GlobalStep"] = strategy.unwrap(val)[0]
                    elif "/" in key:
                        predictions[key] = strategy.reduce(
                            reduce_util.ReduceOp.MEAN, val)
                    else:
                        predictions[key] = array_ops.concat(
                            strategy.unwrap(val), axis=0)

                scaffold = estimator_lib._combine_distributed_scaffold(
                    grouped_estimator_spec.scaffold, strategy)

                # TODO: add a test for unwrapping per_device_hooks.
                def get_hooks_from_the_first_device(per_device_hooks):
                    # In the tensorflow-1.12 Estimator, the next line is
                    # self._distribution.unwrap(), but self._distribution is
                    # not defined there, which may be a bug.
                    return [
                        strategy.unwrap(per_device_hook)[0]
                        for per_device_hook in per_device_hooks
                    ]

                training_hooks = get_hooks_from_the_first_device(
                    grouped_estimator_spec.training_hooks)
                training_chief_hooks = get_hooks_from_the_first_device(
                    grouped_estimator_spec.training_chief_hooks)
                worker_hooks.append(
                    estimator_util.StrategyInitFinalizeHook(
                        strategy.initialize, strategy.finalize))

                estimator_spec = model_fn_lib.EstimatorSpec(
                    mode=grouped_estimator_spec.mode,
                    loss=loss,
                    train_op=strategy.group(distributed_train_op),
                    predictions=predictions,
                    training_hooks=training_hooks,
                    training_chief_hooks=training_chief_hooks,
                    scaffold=scaffold)
                return self._train_with_estimator_spec(estimator_spec,
                                                       worker_hooks, hooks,
                                                       global_step_tensor,
                                                       saving_listeners,
                                                       save_best_ckpt)
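
The predictions loop above applies a per-key merge rule across replicas. A framework-free sketch of that rule (the key conventions are specific to this code, not general Estimator behavior): "GlobalStep" takes the first replica's copy, keys containing "/" are mean-reduced scalars, and everything else is concatenated along the batch axis.

def merge_predictions(per_replica):
    """per_replica: list of prediction dicts, one per replica."""
    merged = {}
    for key in per_replica[0]:
        vals = [d[key] for d in per_replica]
        if key == "GlobalStep":
            merged[key] = vals[0]  # identical on every replica; unwrap one copy
        elif "/" in key:
            merged[key] = sum(vals) / len(vals)  # mean-reduce scalar summaries
        else:
            merged[key] = [x for v in vals for x in v]  # concatenate batches
    return merged

replicas = [{"GlobalStep": 10, "loss/mean": 0.5, "probs": [0.1, 0.9]},
            {"GlobalStep": 10, "loss/mean": 0.7, "probs": [0.3, 0.7]}]
print(merge_predictions(replicas))
# {'GlobalStep': 10, 'loss/mean': 0.6, 'probs': [0.1, 0.9, 0.3, 0.7]}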
Example #16
  def compute_gradients(self, loss, var_list=None,
                        gate_gradients=GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`.  It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable".  Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking
        no arguments which returns the value to minimize. When eager execution
        is enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`.  Defaults to the list of variables collected in the graph
        under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients.  Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything else than `Variable` objects.
      ValueError: If some arguments are invalid.
      RuntimeError: If called with eager execution enabled and `loss` is
        not callable.

    @compatibility(eager)
    When eager execution is enabled, `gate_gradients`, `aggregation_method`,
    and `colocate_gradients_with_ops` are ignored.
    @end_compatibility
    """
    if callable(loss):
      with backprop.GradientTape() as tape:
        if var_list is not None:
          tape.watch(var_list)
        loss_value = loss()

        # Scale loss if using a "mean" loss reduction and multiple towers.
        # Have to be careful to call distribute_lib.get_loss_reduction()
        # *after* loss() is evaluated, so we know what loss reduction it uses.
        # TODO(josh11b): Test that we handle weight decay in a reasonable way.
        if (distribute_lib.get_loss_reduction() ==
            variable_scope.VariableAggregation.MEAN):
          num_towers = distribute_lib.get_distribution_strategy().num_towers
          if num_towers > 1:
            loss_value *= (1. / num_towers)

      if var_list is None:
        var_list = tape.watched_variables()
      grads = tape.gradient(loss_value, var_list, grad_loss)
      return list(zip(grads, var_list))

    # Non-callable/Tensor loss case
    if context.executing_eagerly():
      raise RuntimeError(
          "`loss` passed to Optimizer.compute_gradients should "
          "be a function when eager execution is enabled.")

    # Scale loss if using a "mean" loss reduction and multiple towers.
    if (distribute_lib.get_loss_reduction() ==
        variable_scope.VariableAggregation.MEAN):
      num_towers = distribute_lib.get_distribution_strategy().num_towers
      if num_towers > 1:
        loss *= (1. / num_towers)

    if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
                              Optimizer.GATE_GRAPH]:
      raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
                       "Optimizer.GATE_OP, Optimizer.GATE_GRAPH.  Not %s" %
                       gate_gradients)
    self._assert_valid_dtypes([loss])
    if grad_loss is not None:
      self._assert_valid_dtypes([grad_loss])
    if var_list is None:
      var_list = (
          variables.trainable_variables() +
          ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    else:
      var_list = nest.flatten(var_list)
    # pylint: disable=protected-access
    var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
    # pylint: enable=protected-access
    processors = [_get_processor(v) for v in var_list]
    if not var_list:
      raise ValueError("No variables to optimize.")
    var_refs = [p.target() for p in processors]
    grads = gradients.gradients(
        loss, var_refs, grad_ys=grad_loss,
        gate_gradients=(gate_gradients == Optimizer.GATE_OP),
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops)
    if gate_gradients == Optimizer.GATE_GRAPH:
      grads = control_flow_ops.tuple(grads)
    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes(
        [v for g, v in grads_and_vars
         if g is not None and v.dtype != dtypes.resource])
    return grads_and_vars
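
For contrast with the graph branch above, a sketch of the callable-loss path in modern eager TensorFlow (TF 2.x): tf.GradientTape records the forward computation and differentiates it directly, which is why compute_gradients demands a callable loss under eager execution.

import tensorflow as tf

w = tf.Variable(3.0)

def loss_fn():
    return tf.square(w)  # the callable's body; must be a function when eager

with tf.GradientTape() as tape:
    loss_value = loss_fn()
grads = tape.gradient(loss_value, [w])
print([g.numpy() for g in grads])  # [6.0]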