Example #1
  def _step_fn(ctx, inputs):
    """A step fn that returns update ops."""
    inputs, targets = inputs
    _build_model(strategy, model, mode, inputs, targets)

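    # Run the per-replica execution function on every replica; each returned
    # value is a per-replica structure that is unwrapped below.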
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = strategy.extended.call_for_each_replica(
         _per_device_execution_function,
         args=(distributed_training_utils.get_distributed_model(model, mode),
               mode))
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         strategy, grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args)
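    # Fold the flattened per-replica tensors into one backend function that
    # drives a single step across all replicas.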
    combined_fn = K.function(
        all_inputs,
        all_outputs,
        updates=all_updates,
        name='distributed_' + str(mode) + '_function',
        **all_session_args)

    for label, output in zip(output_labels, combined_fn.outputs):
      if label == 'loss':
        reduce_op = ds_reduce_util.ReduceOp.SUM
      else:
        # We reduce all other metrics using mean for now. This is a temporary
        # workaround until new metrics are in place.
        reduce_op = ds_reduce_util.ReduceOp.MEAN
      ctx.set_last_step_output(label, output, reduce_op)

    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
    # feed_dict, session kwargs, run options, run_metadata for now. These should
    # be handled appropriately.
    return combined_fn.updates_op
Example #2
    def _get_distributed_model(self, mode):
        # get_distributed_model is not available in TF 1.13; execution should
        # not reach here on 1.13 because of the _is_not_supported check.
        from tensorflow.python.keras.distribute.distributed_training_utils import (
            get_distributed_model,
        )

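        # get_keras_mode (a helper from this codebase) maps the callback's mode
        # to the Keras ModeKeys value that get_distributed_model expects.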
        return get_distributed_model(self.model, get_keras_mode(mode))
Example #3
    def _predict_step_fn(inputs):
        """A fn that returns output of single prediction step."""

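        # merge_call switches from the replica context to the cross-replica
        # context so the model is built once for all replicas.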
        (distribution_strategy_context.get_replica_context().merge_call(
            _build_model, args=(model, mode, inputs)))

        (_, outputs, updates, _) = _per_replica_execution_function(
            dist_utils.get_distributed_model(model, mode), mode)

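        # Gate the outputs on the update ops so that state updates run before
        # the step's outputs are returned.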
        with ops.control_dependencies([updates]):
            return [array_ops.identity(out) for out in outputs]
Example #4
    def _test_step_fn(inputs):
        """A fn that returns output of single test step."""
        inputs, targets = inputs
        (distribution_strategy_context.get_replica_context().merge_call(
            _build_model, args=(model, mode, inputs, targets)))

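        # The execution function returns (inputs, outputs, updates, session
        # args); only the outputs and the update ops are needed here.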
        (_, outputs, updates, _) = (_per_device_execution_function(
            distributed_training_utils.get_distributed_model(model, mode),
            mode))
        with ops.control_dependencies([updates]):
            return outputs
Example #5
  def _test_step_fn(inputs):
    """A fn that returns output of single test step."""
    inputs, targets = inputs
    (distribution_strategy_context.get_replica_context().merge_call(
        _build_model, args=(model, mode, inputs, targets)))

    (_, outputs, updates, _) = (
        _per_device_execution_function(
            distributed_training_utils.get_distributed_model(model, mode),
            mode))
    with ops.control_dependencies([updates]):
      return outputs
Example #6
    def on_train_batch_end(self, batch, logs):
        if self.fetches_added:
            # Only remove fetches that were actually added in on_train_batch_begin.
            from tensorflow.python.keras.distribute.distributed_training_utils import (
                get_distributed_model,
            )
            from tensorflow.python.keras.utils.mode_keys import ModeKeys as KerasModeKeys

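            # Undo what on_train_batch_begin registered: drop each tensor from
            # the distributed function's fetches and delete its fetch callback.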
            for t in self.tensors:
                x = get_distributed_model(
                    self.model, KerasModeKeys.TRAIN)._distributed_function
                x.fetches.remove(t)
                del x.fetch_callbacks[t]
            self.fetches_added = False
Example #7
    def _test_step_fn(inputs):
        """A fn that returns output of single test step."""
        if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
            inputs, targets = inputs
        else:
            targets = None

        (distribution_strategy_context.get_replica_context().merge_call(
            _build_model, args=(model, mode, inputs, targets)))

        (_, outputs, updates, _) = _per_replica_execution_function(
            dist_utils.get_distributed_model(model, mode), mode)
        with ops.control_dependencies([updates]):
            return [array_ops.identity(out) for out in outputs]
Example #8
    def on_train_batch_begin(self, batch, logs):
        try:
            from tensorflow.python.keras.distribute.distributed_training_utils import (
                get_distributed_model,
            )
            from tensorflow.python.keras.utils.mode_keys import ModeKeys as KerasModeKeys

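            # Keras builds a private _distributed_function for distributed
            # models; appending to its fetches makes these tensors (and their
            # callbacks) run on every training step.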
            for t in self.tensors:
                x = get_distributed_model(
                    self.model, KerasModeKeys.TRAIN)._distributed_function
                x.fetches.append(t)
                x.fetch_callbacks[t] = self._callback_fn
            self.fetches_added = True
        except ImportError:
            pass
Example #9
  def _test_step_fn(inputs):
    """A fn that returns output of single test step."""
    if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
      inputs, targets = inputs
    else:
      targets = None

    (distribution_strategy_context.get_replica_context().merge_call(
        _build_model, args=(model, mode, inputs, targets)))

    (_, outputs, updates, _) = (
        _per_replica_execution_function(
            distributed_training_utils.get_distributed_model(model, mode),
            mode))
    with ops.control_dependencies([updates]):
      return outputs
Example #10
    def _step_fn(ctx, inputs):
        """A step fn that returns update ops."""
        if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
            inputs, targets = inputs
        else:
            targets = None

        # When the input feature is a dictionary of tensors, the dictionary is
        # flattened to an array and passed as a model input. This causes an
        # input mismatch when the model's input layer names are not sorted in
        # alphabetical order, since `nest.flatten()` sorts dictionary elements
        # by key. To avoid this, transform the input tensors into an array
        # ordered along `model._feed_input_names`.
        if isinstance(inputs, dict):
            inputs = [
                inputs[input_name] for input_name in model._feed_input_names
            ]

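        # Build the distributed copies of the model for this mode before they
        # are fetched below.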
        _build_model(strategy, model, mode, inputs, targets)

        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = strategy.extended.call_for_each_replica(
             _per_replica_execution_function,
             args=(dist_utils.get_distributed_model(model, mode), mode))
        (all_inputs, all_outputs, all_updates,
         all_session_args) = dist_utils.unwrap_values(strategy, grouped_inputs,
                                                      grouped_outputs,
                                                      grouped_updates,
                                                      grouped_session_args)
        combined_fn = K.function(all_inputs,
                                 all_outputs,
                                 updates=all_updates,
                                 name='distributed_' + str(mode) + '_function',
                                 **all_session_args)

        for label, output in zip(output_labels, combined_fn.outputs):
            if label == 'loss':
                reduce_op = ds_reduce_util.ReduceOp.SUM
            else:
                # We reduce all other metrics using mean for now. This is a
                # temporary workaround until new metrics are in place.
                reduce_op = ds_reduce_util.ReduceOp.MEAN
            ctx.set_last_step_output(label, output, reduce_op)

        # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
        # feed_dict, session kwargs, run options, run_metadata for now. These should
        # be handled appropriately.
        return combined_fn.updates_op
Example #11
  def _step_fn(ctx, inputs):
    """A step fn that returns update ops."""
    if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
      inputs, targets = inputs
    else:
      targets = None

    # When the input feature is a dictionary of tensors, the dictionary is
    # flattened to an array and passed as a model input. This causes an input
    # mismatch when the model's input layer names are not sorted in
    # alphabetical order, since `nest.flatten()` sorts dictionary elements by
    # key. To avoid this, transform the input tensors into an array ordered
    # along `model._feed_input_names`.
    if isinstance(inputs, dict):
      inputs = [inputs[input_name] for input_name in model._feed_input_names]

    _build_model(strategy, model, mode, inputs, targets)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = strategy.extended.call_for_each_replica(
         _per_replica_execution_function,
         args=(distributed_training_utils.get_distributed_model(model, mode),
               mode))
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         strategy, grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args)
    combined_fn = K.function(
        all_inputs,
        all_outputs,
        updates=all_updates,
        name='distributed_' + str(mode) + '_function',
        **all_session_args)

    for label, output in zip(output_labels, combined_fn.outputs):
      if label == 'loss':
        reduce_op = ds_reduce_util.ReduceOp.SUM
      else:
        # We reduce all other metrics using mean for now. This is a temporary
        # workaround until new metrics are in place.
        reduce_op = ds_reduce_util.ReduceOp.MEAN
      ctx.set_last_step_output(label, output, reduce_op)

    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
    # feed_dict, session kwargs, run options, run_metadata for now. These should
    # be handled appropriately.
    return combined_fn.updates_op
Example #12
  def testOptimizerWithCallbacks(self, distribution):
    with self.cached_session():
      model = get_model()

      optimizer = gradient_descent_keras.SGD(0.01)
      loss = 'mse'
      model.compile(optimizer, loss, distribute=distribution)

      dataset = get_dataset(distribution)

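      # This schedule pins the learning rate at 0.001 regardless of the epoch.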
      def schedule(_):
        return 0.001

      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
                callbacks=[keras.callbacks.LearningRateScheduler(schedule)])
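      # experimental_local_results unwraps the distributed training model into
      # its per-replica copies so each replica's optimizer can be checked.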
      grouped_models = distribution.experimental_local_results(
          distributed_training_utils.get_distributed_model(
              model, ModeKeys.TRAIN))
      with distribution.scope():
        for m in grouped_models:
          self.assertAllClose(0.001, keras.backend.get_value(
              m.optimizer.lr), atol=1e-05, rtol=1e-05)