Example #1
    def step_fn(ctx, inputs):
        """Clones the model and calls make_predict_function."""
        if model._compile_distribution:
            distributed_training_utils.clone_model_on_replicas(
                model, current_strategy, ModeKeys.PREDICT, inputs=inputs)
        else:
            distributed_training_utils._build_distributed_network(
                model, current_strategy, ModeKeys.PREDICT, inputs)

        (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args
         ) = current_strategy.extended.call_for_each_replica(
             _per_device_predict_function,
             args=(model._distributed_model_predict, ))

        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy, grouped_inputs, grouped_outputs,
             grouped_updates, grouped_session_args)

        combined_fn = K.function(all_inputs,
                                 all_outputs,
                                 updates=all_updates,
                                 name='distributed_predict_function',
                                 **all_session_args)

        for label, output in zip(model.output_names, combined_fn.outputs):
            ctx.set_last_step_output(label, output)

        return combined_fn.updates_op
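
For context, a step_fn with this signature is not called directly: in the TF 1.x DistributionStrategy API these snippets are written against, it is handed to strategy.extended.experimental_run_steps_on_iterator, which supplies the MultiStepContext (ctx) and collects everything registered through ctx.set_last_step_output. A minimal sketch of that driving loop follows; the names iterator and steps_per_run are illustrative placeholders, and K is the Keras backend as in the snippets above.

    # Sketch only: assumes the TF 1.x tf.distribute internals used above;
    # `iterator` and `steps_per_run` are illustrative placeholders.
    ctx = current_strategy.extended.experimental_run_steps_on_iterator(
        step_fn, iterator, iterations=steps_per_run)
    # ctx.run_op executes the step loop; ctx.last_step_outputs holds the
    # tensors registered via ctx.set_last_step_output on the final step.
    _, predictions = K.get_session().run([ctx.run_op, ctx.last_step_outputs])
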
Example #2
  def step_fn(ctx, inputs):
    """Clones the model and calls make_predict_function."""
    if model._compile_distribution:
      distributed_training_utils.clone_model_on_replicas(
          model, current_strategy, mode, inputs=inputs)
    else:
      distributed_training_utils._build_distributed_network(
          model, current_strategy, mode, inputs)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.extended.call_for_each_replica(
         _per_device_predict_function,
         args=(distributed_training_utils.get_distributed_model(
             model, ModeKeys.PREDICT),))

    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args)

    combined_fn = K.function(
        all_inputs, all_outputs,
        updates=all_updates,
        name='distributed_predict_function',
        **all_session_args)

    for label, output in zip(model.output_names, combined_fn.outputs):
      ctx.set_last_step_output(label, output)

    return combined_fn.updates_op
Example #3
  def step_fn(ctx, inputs):
    """Clones the model and calls make_eval_function."""
    inputs, targets = inputs
    if model._compile_distribution:
      distributed_training_utils.clone_model_on_replicas(
          model, current_strategy, mode=mode, inputs=inputs, targets=targets)
    else:
      distributed_training_utils._build_distributed_network(
          model, current_strategy, mode, inputs, targets)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.extended.call_for_each_replica(
         _per_device_eval_function, args=(model._distributed_model_test,))

    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args)

    combined_fn = K.function(
        all_inputs, all_outputs,
        updates=all_updates,
        name='distributed_test_function',
        **all_session_args)

    for label, output in zip(model.metrics_names, combined_fn.outputs):
      if label == 'loss':
        reduce_op = distribute_lib.get_loss_reduction()
      else:
        # We reduce all other metrics using mean for now. This is a temporary
        # workaround until new metrics are in place.
        reduce_op = ds_reduce_util.ReduceOp.MEAN
      ctx.set_last_step_output(label, output, reduce_op)

    return combined_fn.updates_op
Example #4
def _build_model(strategy, model, mode, inputs, targets=None):
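    """Builds or clones `model` on the replicas of `strategy` for `mode`."""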
    if model._compile_distribution:
        distributed_training_utils.clone_model_on_replicas(model,
                                                           strategy,
                                                           mode,
                                                           inputs=inputs,
                                                           targets=targets)
    else:
        distributed_training_utils._build_distributed_network(
            model, strategy, mode, inputs, targets)
Example #5
    def step_fn(ctx, inputs):
        """A step fn that returns update ops."""
        if mode == ModeKeys.PREDICT:
            targets = None
        else:
            inputs, targets = inputs

        if model._compile_distribution:
            distributed_training_utils.clone_model_on_replicas(model,
                                                               strategy,
                                                               mode,
                                                               inputs=inputs,
                                                               targets=targets)
        else:
            distributed_training_utils._build_distributed_network(
                model, strategy, mode, inputs, targets)

        (grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args) = strategy.extended.call_for_each_replica(
             _per_device_execution_function,
             args=(distributed_training_utils.get_distributed_model(
                 model, mode), ))
        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             strategy, grouped_inputs, grouped_outputs, grouped_updates,
             grouped_session_args)
        combined_fn = K.function(all_inputs,
                                 all_outputs,
                                 updates=all_updates,
                                 name='distributed_' + str(mode) + '_function',
                                 **all_session_args)

        for label, output in zip(output_labels, combined_fn.outputs):
            if mode == ModeKeys.PREDICT:
                ctx.set_last_step_output(label, output)
            else:
                if label == 'loss':
                    reduce_op = ds_reduce_util.ReduceOp.SUM
                else:
                    # We reduce all other metrics using mean for now. This is a temporary
                    # workaround until new metrics are in place.
                    reduce_op = ds_reduce_util.ReduceOp.MEAN
                ctx.set_last_step_output(label, output, reduce_op)

        # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
        # feed_dict, session kwargs, run options, run_metadata for now. These should
        # be handled appropriately.
        return combined_fn.updates_op
Example #6
    def step_fn(ctx, inputs):
        """Clones the model and calls make_fit_function."""
        inputs, targets = inputs
        if model._compile_distribution:
            distributed_training_utils.clone_model_on_replicas(
                model,
                current_strategy,
                make_callback_model=True,
                inputs=inputs,
                targets=targets,
                mode=distributed_training_utils.ModeKeys.TRAIN)
        else:
            distributed_training_utils._build_distributed_network(
                model,
                current_strategy,
                inputs,
                targets,
                mode=distributed_training_utils.ModeKeys.TRAIN)

        (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args
         ) = current_strategy.extended.call_for_each_replica(
             _per_device_fit_function, args=(model._distributed_model_train, ))
        (all_inputs, all_outputs, all_updates,
         all_session_args) = distributed_training_utils.unwrap_values(
             current_strategy, grouped_inputs, grouped_outputs,
             grouped_updates, grouped_session_args)
        combined_fn = K.function(all_inputs,
                                 all_outputs,
                                 updates=all_updates,
                                 name='distributed_fit_function',
                                 **all_session_args)

        for label, output in zip(out_labels, combined_fn.outputs):
            if label == 'loss':
                reduce_op = distribute_lib.get_loss_reduction()
            else:
                # We reduce all other metrics using mean for now. This is a temporary
                # workaround until new metrics are in place.
                reduce_op = ds_reduce_util.ReduceOp.MEAN
            ctx.set_last_step_output(label, output, reduce_op)

        # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
        # feed_dict, session kwargs, run options, run_metadata for now. These should
        # be handled appropriately.
        return combined_fn.updates_op
Example #7
  def step_fn(ctx, inputs):
    """A step fn that returns update ops."""
    if mode == ModeKeys.PREDICT:
      targets = None
    else:
      inputs, targets = inputs

    if model._compile_distribution:
      distributed_training_utils.clone_model_on_replicas(
          model, strategy, mode, inputs=inputs, targets=targets)
    else:
      distributed_training_utils._build_distributed_network(
          model, strategy, mode, inputs, targets)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = strategy.extended.call_for_each_replica(
         _per_device_execution_function,
         args=(distributed_training_utils.get_distributed_model(model, mode),))
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         strategy, grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args)
    combined_fn = K.function(
        all_inputs,
        all_outputs,
        updates=all_updates,
        name='distributed_' + str(mode) + '_function',
        **all_session_args)

    for label, output in zip(output_labels, combined_fn.outputs):
      if mode == ModeKeys.PREDICT:
        ctx.set_last_step_output(label, output)
      else:
        if label == 'loss':
          reduce_op = ds_reduce_util.ReduceOp.SUM
        else:
          # We reduce all other metrics using mean for now. This is a temporary
          # workaround until new metrics are in place.
          reduce_op = ds_reduce_util.ReduceOp.MEAN
        ctx.set_last_step_output(label, output, reduce_op)

    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
    # feed_dict, session kwargs, run options, run_metadata for now. These should
    # be handled appropriately.
    return combined_fn.updates_op
Example #8
  def step_fn(ctx, inputs):
    """Clones the model and calls make_fit_function."""
    inputs, targets = inputs
    if model._compile_distribution:
      distributed_training_utils.clone_model_on_replicas(
          model, current_strategy,
          make_callback_model=True, inputs=inputs,
          targets=targets, mode=distributed_training_utils.ModeKeys.TRAIN)
    else:
      distributed_training_utils._build_distributed_network(
          model, current_strategy, inputs,
          targets, mode=distributed_training_utils.ModeKeys.TRAIN)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.extended.call_for_each_replica(
         _per_device_fit_function, args=(model._distributed_model_train,))
    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs,
         grouped_updates, grouped_session_args)
    combined_fn = K.function(
        all_inputs,
        all_outputs,
        updates=all_updates,
        name='distributed_fit_function',
        **all_session_args)

    for label, output in zip(out_labels, combined_fn.outputs):
      if label == 'loss':
        reduce_op = distribute_lib.get_loss_reduction()
      else:
        # We reduce all other metrics using mean for now. This is a temporary
        # workaround until new metrics are in place.
        reduce_op = ds_reduce_util.ReduceOp.MEAN
      ctx.set_last_step_output(label, output, reduce_op)

    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
    # feed_dict, session kwargs, run options, run_metadata for now. These should
    # be handled appropriately.
    return combined_fn.updates_op
Example #9
  def step_fn(ctx, inputs):
    """Clones the model and calls make_eval_function."""
    inputs, targets = inputs
    if model._compile_distribution:
      distributed_training_utils.clone_model_on_replicas(
          model, current_strategy,
          make_callback_model=False, inputs=inputs,
          targets=targets, mode=distributed_training_utils.ModeKeys.TEST)
    else:
      distributed_training_utils._build_distributed_network(
          model, current_strategy, inputs, targets,
          mode=distributed_training_utils.ModeKeys.TEST)

    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = current_strategy.extended.call_for_each_replica(
         _per_device_eval_function, args=(model._distributed_model_test,))

    (all_inputs, all_outputs, all_updates,
     all_session_args) = distributed_training_utils.unwrap_values(
         current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
         grouped_session_args)

    combined_fn = K.function(
        all_inputs, all_outputs,
        updates=all_updates,
        name='distributed_test_function',
        **all_session_args)

    for label, output in zip(model.metrics_names, combined_fn.outputs):
      if label == 'loss':
        reduce_op = distribute_lib.get_loss_reduction()
      else:
        # We reduce all other metrics using mean for now. This is a temporary
        # workaround until new metrics are in place.
        reduce_op = ds_reduce_util.ReduceOp.MEAN
      ctx.set_last_step_output(label, output, reduce_op)

    return combined_fn.updates_op