Example #1
 def _scale_loss(loss_value):
   if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
     num_replicas = \
       distribute_ctx.get_distribution_strategy().num_replicas_in_sync
     if num_replicas > 1:
       loss_value *= (1. / num_replicas)
   return loss_value
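For context, under ReduceOp.MEAN the per-replica gradients are later summed across replicas, so pre-dividing each replica's loss by num_replicas makes that sum equal to the average. A minimal, framework-free sketch of the arithmetic follows; the helper name and sample values are illustrative only, not TensorFlow APIs.

def scale_loss_for_mean_reduction(loss_value, num_replicas):
  """Toy stand-in for the scaling above; plain Python, no TensorFlow."""
  if num_replicas > 1:
    loss_value *= (1. / num_replicas)
  return loss_value

per_replica_losses = [4.0, 2.0, 6.0]   # one loss value per replica
num_replicas = len(per_replica_losses)
scaled = [scale_loss_for_mean_reduction(l, num_replicas)
          for l in per_replica_losses]
# Summing the scaled losses (4.0) equals the mean of the unscaled ones.
assert abs(sum(scaled) - sum(per_replica_losses) / num_replicas) < 1e-9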
Example #2
 def _scale_loss(loss_value):
   ops.get_default_graph()._is_loss_scaled_by_optimizer = False  # pylint: disable=protected-access
   if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
     num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync
     if num_replicas > 1:
       loss_value *= (1. / num_replicas)
       ops.get_default_graph()._is_loss_scaled_by_optimizer = True  # pylint: disable=protected-access
   return loss_value
Example #3
  def step_fn(ctx, inputs):
    """Clones the model and calls make_fit_function."""
    # TODO(priyag, sourabhbajaj): The model gets cloned every time
    # fit/test/predict is called. We should look into caching this keyed on
    # input shapes.
    inputs, targets = inputs
    clone_model_on_replicas(
        model,
        current_strategy,
        make_callback_model=True,
        inputs=inputs,
        targets=targets,
        mode=_Mode.TRAIN)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.extended.call_for_each_replica(
         _per_device_fit_function, args=(model._grouped_model_train,))
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs,
         grouped_updates, grouped_session_args)
    combined_fn = K.function(
        all_inputs,
        all_outputs,
        updates=all_updates,
        name='distributed_fit_function',
        **all_session_args)

    for label, output in zip(out_labels, combined_fn.outputs):
      if label == 'loss':
        reduce_op = distribute_lib.get_loss_reduction()
      else:
        # We reduce all other metrics using mean for now. This is a temporary
        # workaround until new metrics are in place.
        reduce_op = ds_reduce_util.ReduceOp.MEAN
      ctx.set_last_step_output(label, output, reduce_op)

    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
    # feed_dict, session kwargs, run options, run_metadata for now. These should
    # be handled appropriately.
    return combined_fn.updates_op
Example #4
    def step_fn(ctx, inputs):
        """Clones the model and calls make_eval_function."""
        # TODO(priyag, sourabhbajaj): The model gets cloned every time
        # fit/test/predict is called. We should look into caching this keyed on
        # input shapes.
        inputs, targets = inputs
        clone_model_on_replicas(model,
                                current_strategy,
                                make_callback_model=False,
                                inputs=inputs,
                                targets=targets,
                                mode=_Mode.TEST)

        (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args
         ) = current_strategy.extended.call_for_each_replica(
             _per_device_eval_function, args=(model._grouped_model_test, ))

        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy, grouped_inputs, grouped_outputs,
             grouped_updates, grouped_session_args)

        combined_fn = K.function(all_inputs,
                                 all_outputs,
                                 updates=all_updates,
                                 name='distributed_test_function',
                                 **all_session_args)

        for label, output in zip(model.metrics_names, combined_fn.outputs):
            if label == 'loss':
                reduce_op = distribute_lib.get_loss_reduction()
            else:
                # We reduce all other metrics using mean for now. This is a
                # temporary workaround until new metrics are in place.
                reduce_op = ds_reduce_util.ReduceOp.MEAN
            ctx.set_last_step_output(label, output, reduce_op)

        return combined_fn.updates_op
Example #5
  def step_fn(ctx, inputs):
    """Clones the model and calls make_eval_function."""
    inputs, targets = inputs
    if model._compile_distribution:
      distributed_training_utils.clone_model_on_replicas(
          model, current_strategy,
          make_callback_model=False, inputs=inputs,
          targets=targets, mode=distributed_training_utils.ModeKeys.TEST)
    else:
      distributed_training_utils._build_distributed_network(
          model, current_strategy, inputs, targets,
          mode=distributed_training_utils.ModeKeys.TEST)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.extended.call_for_each_replica(
         _per_device_eval_function, args=(model._distributed_model_test,))

    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args)

    combined_fn = K.function(
        all_inputs, all_outputs,
        updates=all_updates,
        name='distributed_test_function',
        **all_session_args)

    for label, output in zip(model.metrics_names, combined_fn.outputs):
      if label == 'loss':
        reduce_op = distribute_lib.get_loss_reduction()
      else:
        # We reduce all other metrics using mean for now. This is a temporary
        # workaround until new metrics are in place.
        reduce_op = ds_reduce_util.ReduceOp.MEAN
      ctx.set_last_step_output(label, output, reduce_op)

    return combined_fn.updates_op
Example #6
  def step_fn(ctx, inputs):
    """Clones the model and calls make_fit_function."""
    inputs, targets = inputs
    if model._compile_distribution:
      clone_model_on_replicas(model, current_strategy,
                              make_callback_model=True, inputs=inputs,
                              targets=targets, mode=ModeKeys.TRAIN)
    else:
      _build_distributed_network(model, current_strategy, inputs,
                                 targets, mode=ModeKeys.TRAIN)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.extended.call_for_each_replica(
         _per_device_fit_function, args=(model._distributed_model_train,))
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs,
         grouped_updates, grouped_session_args)
    combined_fn = K.function(
        all_inputs,
        all_outputs,
        updates=all_updates,
        name='distributed_fit_function',
        **all_session_args)

    for label, output in zip(out_labels, combined_fn.outputs):
      if label == 'loss':
        reduce_op = distribute_lib.get_loss_reduction()
      else:
        # We reduce all other metrics using mean for now. This is a temporary
        # workaround until new metrics are in place.
        reduce_op = ds_reduce_util.ReduceOp.MEAN
      ctx.set_last_step_output(label, output, reduce_op)

    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
    # feed_dict, session kwargs, run options, run_metadata for now. These should
    # be handled appropriately.
    return combined_fn.updates_op
Example #7
 def _scale_loss(loss_value):
   if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
     num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync
     if num_replicas > 1:
       loss_value *= (1. / num_replicas)
   return loss_value
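The step_fn examples above obtain their flattened inputs, outputs, updates, and session args from distributed_training_utils.unwrap_values, whose definition is reproduced below.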
def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs,
                  grouped_updates=None, grouped_session_args=None,
                  with_loss_tensor=False):
  """Unwrap and return the list of values contained in the PerDevice parameters.

  This function calls `flatten_perdevice_values` to parse each of the input
  parameters into a list of values on the different devices. If we set
  `with_loss_tensor` to be True, we also call `reduce` on the list of losses on
  the different devices to give us one loss tensor.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training and
        validation.
    grouped_inputs: PerDevice inputs returned from the train or test function
        that we ran on each device.
    grouped_outputs: PerDevice outputs returned from the train or test function
        that we ran on each device.
    grouped_updates: PerDevice updates returned from the train or test function
        that we ran on each device.
    grouped_session_args: PerDevice session args returned from the train or
        test function that we ran on each device.
    with_loss_tensor: Boolean that indicates if we need to add the reduced loss
        tensor as one of the outputs.

  Returns:
    Values of each of the PerDevice parameters.

  """
  # Unwrap per device values returned from each model's train function.
  # This will be used to construct the main train function.
  all_inputs = flatten_perdevice_values(distribution_strategy,
                                        grouped_inputs)
  if with_loss_tensor:
    # reduce loss tensor before adding it to the list of fetches
    loss = distribution_strategy.reduce(distribute_lib.get_loss_reduction(),
                                        grouped_outputs[0])
    all_outputs = flatten_perdevice_values(distribution_strategy,
                                           grouped_outputs[1:])
    all_outputs = [loss] + all_outputs
  else:
    all_outputs = flatten_perdevice_values(distribution_strategy,
                                           grouped_outputs)

  if grouped_updates:
    all_updates = flatten_perdevice_values(distribution_strategy,
                                           grouped_updates)
  else:
    all_updates = None

  all_session_args = {}
  if grouped_session_args:
    grouped_feed_dict = grouped_session_args.get('feed_dict')
    if grouped_feed_dict:
      all_session_args['feed_dict'] = flatten_perdevice_values(
          distribution_strategy, grouped_feed_dict)

    grouped_fetches = grouped_session_args.get('fetches')
    if grouped_fetches:
      all_session_args['fetches'] = flatten_perdevice_values(
          distribution_strategy, grouped_fetches)

  # TODO(priyag): Return only non empty/None values
  return all_inputs, all_outputs, all_updates, all_session_args
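As a rough, framework-free sketch of the contract described in the docstring: with with_loss_tensor=True, the first grouped output (the per-replica losses) is reduced to a single loss tensor and prepended to the flattened remaining outputs. The mock helpers below stand in for flatten_perdevice_values and the strategy's reduce; they are illustrative only, not the real implementations.

def _mock_flatten(grouped_values):
  """Stand-in for flatten_perdevice_values: one entry per replica per value."""
  return [v for group in grouped_values for v in group]

def _mock_reduce(per_replica_losses):
  """Stand-in for the strategy's reduce with MEAN reduction."""
  return sum(per_replica_losses) / len(per_replica_losses)

grouped_outputs = [(0.5, 0.7),   # per-replica loss values
                   (0.9, 0.8)]   # per-replica values of another metric

loss = _mock_reduce(grouped_outputs[0])
all_outputs = [loss] + _mock_flatten(grouped_outputs[1:])
# all_outputs == [0.6, 0.9, 0.8]: one reduced loss, then the flattened metrics.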