def step_fn(ctx, inputs):
    """Builds the per-replica predict graph and merges it into one function.

    When `model._compile_distribution` is truthy the model is cloned onto the
    replicas; otherwise its distributed network is built in place. The
    per-device predict function is then run on every replica, and the
    per-replica pieces are unwrapped and combined into a single backend
    function.

    Args:
      ctx: Step context; each model output is registered on it through
        `set_last_step_output`, keyed by the matching `model.output_names`
        entry.
      inputs: Per-step inputs forwarded to the clone/build helper.

    Returns:
      The `updates_op` of the combined backend function.
    """
    if not model._compile_distribution:
      distributed_training_utils._build_distributed_network(
          model, current_strategy, mode, inputs)
    else:
      distributed_training_utils.clone_model_on_replicas(
          model, current_strategy, mode, inputs=inputs)

    predict_model = distributed_training_utils.get_distributed_model(
        model, ModeKeys.PREDICT)
    run_on_replicas = current_strategy.extended.call_for_each_replica
    (replica_inputs, replica_outputs, replica_updates,
     replica_session_args) = run_on_replicas(
         _per_device_predict_function, args=(predict_model,))

    (flat_inputs, flat_outputs, flat_updates,
     flat_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, replica_inputs, replica_outputs, replica_updates,
         replica_session_args)

    merged_fn = K.function(
        flat_inputs,
        flat_outputs,
        updates=flat_updates,
        name='distributed_predict_function',
        **flat_session_args)

    # Predict outputs are registered without an explicit reduce op.
    for output_name, tensor in zip(model.output_names, merged_fn.outputs):
      ctx.set_last_step_output(output_name, tensor)

    return merged_fn.updates_op
    def step_fn(ctx, inputs):
        """Sets up the predict graph on every replica and returns its update op.

        The model is cloned onto the replicas when
        `model._compile_distribution` is set, otherwise its distributed
        network is built. `_per_device_predict_function` then runs on each
        replica and the unwrapped results are fused into one backend function.

        Args:
            ctx: Step context; each model output is recorded on it with
                `set_last_step_output` under the matching output name.
            inputs: Step inputs handed to the clone/build helper.

        Returns:
            The `updates_op` of the fused backend function.
        """
        if not model._compile_distribution:
            distributed_training_utils._build_distributed_network(
                model, current_strategy, mode, inputs)
        else:
            distributed_training_utils.clone_model_on_replicas(
                model, current_strategy, mode, inputs=inputs)

        replica_model = distributed_training_utils.get_distributed_model(
            model, ModeKeys.PREDICT)
        for_each_replica = current_strategy.extended.call_for_each_replica
        (per_replica_inputs, per_replica_outputs, per_replica_updates,
         per_replica_session_args) = for_each_replica(
             _per_device_predict_function, args=(replica_model,))

        (merged_inputs, merged_outputs, merged_updates,
         merged_session_args) = distributed_training_utils.unwrap_values(
             current_strategy, per_replica_inputs, per_replica_outputs,
             per_replica_updates, per_replica_session_args)

        fused_fn = K.function(merged_inputs,
                              merged_outputs,
                              updates=merged_updates,
                              name='distributed_predict_function',
                              **merged_session_args)

        # No reduce op is passed for predict outputs.
        for name, value in zip(model.output_names, fused_fn.outputs):
            ctx.set_last_step_output(name, value)

        return fused_fn.updates_op
Beispiel #3
0
    def _step_fn(ctx, inputs):
        """Builds the distributed graph for `mode` and returns its update op.

        Args:
            ctx: Step context that receives every output, together with its
                reduce op, via `set_last_step_output`.
            inputs: An `(inputs, targets)` pair for this step.

        Returns:
            The `updates_op` of the combined backend function.
        """
        features, labels = inputs
        _build_model(strategy, model, mode, features, labels)

        mode_model = distributed_training_utils.get_distributed_model(
            model, mode)
        run_on_replicas = strategy.extended.call_for_each_replica
        (replica_inputs, replica_outputs, replica_updates,
         replica_session_args) = run_on_replicas(
             _per_device_execution_function, args=(mode_model, mode))
        (flat_inputs, flat_outputs, flat_updates,
         flat_session_args) = distributed_training_utils.unwrap_values(
             strategy, replica_inputs, replica_outputs, replica_updates,
             replica_session_args)
        merged_fn = K.function(flat_inputs,
                               flat_outputs,
                               updates=flat_updates,
                               name='distributed_%s_function' % (mode,),
                               **flat_session_args)

        for output_name, tensor in zip(output_labels, merged_fn.outputs):
            # 'loss' is reduced with SUM; every other metric uses MEAN as a
            # temporary workaround until new metrics are in place.
            reduce_op = (ds_reduce_util.ReduceOp.SUM if output_name == 'loss'
                         else ds_reduce_util.ReduceOp.MEAN)
            ctx.set_last_step_output(output_name, tensor, reduce_op)

        # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
        # feed_dict, session kwargs, run options, run_metadata for now. These should
        # be handled appropriately
        return merged_fn.updates_op
Beispiel #4
0
    def testOptimizerWithCallbacks(self, distribution):
        """A LearningRateScheduler callback must set lr on every TRAIN replica model."""
        expected_lr = 0.001

        with self.cached_session():
            model = get_model()
            model.compile(gradient_descent_keras.SGD(0.01), 'mse',
                          distribute=distribution)

            dataset = get_dataset(distribution)

            def constant_schedule(_):
                return expected_lr

            lr_callback = keras.callbacks.LearningRateScheduler(
                constant_schedule)
            model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
                      callbacks=[lr_callback])

            train_model = distributed_training_utils.get_distributed_model(
                model, ModeKeys.TRAIN)
            replica_models = distribution.unwrap(train_model)
            with distribution.scope():
                for replica_model in replica_models:
                    actual_lr = keras.backend.get_value(
                        replica_model.optimizer.lr)
                    self.assertAllClose(expected_lr, actual_lr,
                                        atol=1e-05, rtol=1e-05)
Beispiel #5
0
    def _test_step_fn(inputs):
        """Runs a single test step and returns its outputs.

        The model is first built through the current replica context's
        `merge_call`. The per-device execution function is then invoked for
        `mode`.

        Args:
            inputs: An `(inputs, targets)` pair for this step.

        Returns:
            A list with one identity tensor per output of the per-device
            execution function. Each is created under a control dependency on
            the step's update op.
        """
        inputs, targets = inputs
        (distribution_strategy_context.get_replica_context().merge_call(
            _build_model, args=(model, mode, inputs, targets)))

        (_, outputs, updates, _) = (_per_device_execution_function(
            distributed_training_utils.get_distributed_model(model, mode),
            mode))
        # `control_dependencies` only applies to ops *constructed* inside the
        # context. Returning the pre-existing output tensors directly would
        # attach no dependency, so `updates` might never run. Create fresh
        # identity ops here so the outputs actually depend on the update op.
        with ops.control_dependencies([updates]):
            return [K.identity(output) for output in outputs]
  def step_fn(ctx, inputs):
    """Builds the distributed graph for `mode` and returns its update op.

    In PREDICT mode `inputs` is used as-is and there are no targets. In every
    other mode it is unpacked as an `(inputs, targets)` pair.

    Args:
      ctx: Step context; every output is registered on it via
        `set_last_step_output`. Outside PREDICT mode a reduce op is attached
        (SUM for 'loss', MEAN for everything else).
      inputs: Per-step inputs, shaped as described above.

    Returns:
      The `updates_op` of the combined backend function.
    """
    is_predict = mode == ModeKeys.PREDICT
    if is_predict:
      features, labels = inputs, None
    else:
      features, labels = inputs

    if not model._compile_distribution:
      distributed_training_utils._build_distributed_network(
          model, strategy, mode, features, labels)
    else:
      distributed_training_utils.clone_model_on_replicas(
          model, strategy, mode, inputs=features, targets=labels)

    mode_model = distributed_training_utils.get_distributed_model(model, mode)
    run_on_replicas = strategy.extended.call_for_each_replica
    (replica_inputs, replica_outputs, replica_updates,
     replica_session_args) = run_on_replicas(
         _per_device_execution_function, args=(mode_model,))
    (flat_inputs, flat_outputs, flat_updates,
     flat_session_args) = distributed_training_utils.unwrap_values(
         strategy, replica_inputs, replica_outputs, replica_updates,
         replica_session_args)
    merged_fn = K.function(
        flat_inputs,
        flat_outputs,
        updates=flat_updates,
        name='distributed_%s_function' % (mode,),
        **flat_session_args)

    for output_name, tensor in zip(output_labels, merged_fn.outputs):
      if is_predict:
        ctx.set_last_step_output(output_name, tensor)
        continue
      # 'loss' is reduced with SUM; all other metrics use MEAN as a temporary
      # workaround until new metrics are in place.
      reduce_op = (ds_reduce_util.ReduceOp.SUM if output_name == 'loss'
                   else ds_reduce_util.ReduceOp.MEAN)
      ctx.set_last_step_output(output_name, tensor, reduce_op)

    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
    # feed_dict, session kwargs, run options, run_metadata for now. These should
    # be handled appropriately
    return merged_fn.updates_op
    def step_fn(ctx, inputs):
        """Builds the TEST-mode graph on every replica and returns its update op.

        Args:
            ctx: Step context; each entry of `model.metrics_names` is recorded
                on it with a reduce op (SUM for 'loss', MEAN otherwise).
            inputs: An `(inputs, targets)` pair for this step.

        Returns:
            The `updates_op` of the combined backend function.
        """
        features, labels = inputs
        if not model._compile_distribution:
            distributed_training_utils._build_distributed_network(
                model, current_strategy, mode, features, labels)
        else:
            distributed_training_utils.clone_model_on_replicas(
                model,
                current_strategy,
                mode=mode,
                inputs=features,
                targets=labels)

        eval_model = distributed_training_utils.get_distributed_model(
            model, ModeKeys.TEST)
        for_each_replica = current_strategy.extended.call_for_each_replica
        (per_replica_inputs, per_replica_outputs, per_replica_updates,
         per_replica_session_args) = for_each_replica(
             _per_device_eval_function, args=(eval_model,))

        (merged_inputs, merged_outputs, merged_updates,
         merged_session_args) = distributed_training_utils.unwrap_values(
             current_strategy, per_replica_inputs, per_replica_outputs,
             per_replica_updates, per_replica_session_args)

        fused_fn = K.function(merged_inputs,
                              merged_outputs,
                              updates=merged_updates,
                              name='distributed_test_function',
                              **merged_session_args)

        for metric_name, value in zip(model.metrics_names, fused_fn.outputs):
            # 'loss' uses SUM; other metrics use MEAN as a temporary
            # workaround until new metrics are in place.
            if metric_name == 'loss':
                op = ds_reduce_util.ReduceOp.SUM
            else:
                op = ds_reduce_util.ReduceOp.MEAN
            ctx.set_last_step_output(metric_name, value, op)

        return fused_fn.updates_op
  def testOptimizerWithCallbacks(self, distribution):
    """A LearningRateScheduler callback must set lr on every TRAIN replica."""
    target_lr = 0.001

    with self.cached_session():
      model = get_model()
      sgd = gradient_descent_keras.SGD(0.01)
      model.compile(sgd, 'mse', distribute=distribution)

      dataset = get_dataset(distribution)

      def fixed_lr_schedule(_):
        return target_lr

      scheduler = keras.callbacks.LearningRateScheduler(fixed_lr_schedule)
      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
                callbacks=[scheduler])

      replica_models = distribution.unwrap(
          distributed_training_utils.get_distributed_model(
              model, ModeKeys.TRAIN))
      with distribution.scope():
        for replica_model in replica_models:
          observed_lr = keras.backend.get_value(replica_model.optimizer.lr)
          self.assertAllClose(target_lr, observed_lr, atol=1e-05, rtol=1e-05)
  def step_fn(ctx, inputs):
    """Builds the TRAIN-mode graph on every replica and returns its update op.

    Args:
      ctx: Step context; every entry of `out_labels` is recorded on it with a
        reduce op (SUM for 'loss', MEAN for everything else).
      inputs: An `(inputs, targets)` pair for this step.

    Returns:
      The `updates_op` of the combined backend function.
    """
    features, labels = inputs
    if not model._compile_distribution:
      distributed_training_utils._build_distributed_network(
          model, current_strategy, mode, features, labels)
    else:
      distributed_training_utils.clone_model_on_replicas(
          model, current_strategy, mode, inputs=features, targets=labels)

    train_model = distributed_training_utils.get_distributed_model(
        model, ModeKeys.TRAIN)
    run_on_replicas = current_strategy.extended.call_for_each_replica
    (replica_inputs, replica_outputs, replica_updates,
     replica_session_args) = run_on_replicas(
         _per_device_fit_function, args=(train_model,))
    (flat_inputs, flat_outputs, flat_updates,
     flat_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, replica_inputs, replica_outputs,
         replica_updates, replica_session_args)
    merged_fn = K.function(
        flat_inputs,
        flat_outputs,
        updates=flat_updates,
        name='distributed_fit_function',
        **flat_session_args)

    for output_name, tensor in zip(out_labels, merged_fn.outputs):
      # 'loss' is reduced with SUM; other metrics use MEAN as a temporary
      # workaround until new metrics are in place.
      reduce_op = (ds_reduce_util.ReduceOp.SUM if output_name == 'loss'
                   else ds_reduce_util.ReduceOp.MEAN)
      ctx.set_last_step_output(output_name, tensor, reduce_op)

    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
    # feed_dict, session kwargs, run options, run_metadata for now. These should
    # be handled appropriately
    return merged_fn.updates_op
  def step_fn(ctx, inputs):
    """Builds the TEST-mode graph on every replica and returns its update op.

    Args:
      ctx: Step context; each entry of `model.metrics_names` is recorded on it
        with a reduce op (SUM for 'loss', MEAN otherwise).
      inputs: An `(inputs, targets)` pair for this step.

    Returns:
      The `updates_op` of the combined backend function.
    """
    features, labels = inputs
    if not model._compile_distribution:
      distributed_training_utils._build_distributed_network(
          model, current_strategy, mode, features, labels)
    else:
      distributed_training_utils.clone_model_on_replicas(
          model, current_strategy, mode=mode, inputs=features, targets=labels)

    eval_model = distributed_training_utils.get_distributed_model(
        model, ModeKeys.TEST)
    for_each_replica = current_strategy.extended.call_for_each_replica
    (per_replica_inputs, per_replica_outputs, per_replica_updates,
     per_replica_session_args) = for_each_replica(
         _per_device_eval_function, args=(eval_model,))

    (merged_inputs, merged_outputs, merged_updates,
     merged_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, per_replica_inputs, per_replica_outputs,
         per_replica_updates, per_replica_session_args)

    fused_fn = K.function(
        merged_inputs, merged_outputs,
        updates=merged_updates,
        name='distributed_test_function',
        **merged_session_args)

    for metric_name, value in zip(model.metrics_names, fused_fn.outputs):
      # 'loss' uses SUM; every other metric uses MEAN as a temporary
      # workaround until new metrics are in place.
      if metric_name == 'loss':
        op = ds_reduce_util.ReduceOp.SUM
      else:
        op = ds_reduce_util.ReduceOp.MEAN
      ctx.set_last_step_output(metric_name, value, op)

    return fused_fn.updates_op