def predict_loop(batch):
    """Run one PREDICT step on every replica, then shuffle some outputs.

    The step runs ``execution_function`` through
    ``strategy.experimental_run_v2`` with ``batch`` as its arguments. The
    per-replica result is unwrapped with ``dist_utils.unwrap_output_dict``
    in ``ModeKeys.PREDICT`` mode. Entries 1<->2 and 4<->5 of the unwrapped
    result are then exchanged in place, to mimic random replica ordering.

    Args:
      batch: Forwarded as the ``args`` of ``experimental_run_v2``.
        Presumably this is a tuple of positional inputs for
        ``execution_function`` (TODO confirm against the caller).

    Returns:
      The unwrapped result with the two swaps applied. It must be a mutable
      sequence with at least six entries, because indices up to 5 are
      written.
    """
    per_replica = strategy.experimental_run_v2(execution_function, args=batch)
    unwrapped = dist_utils.unwrap_output_dict(
        strategy, per_replica, ModeKeys.PREDICT)
    # Exchange replica outputs pairwise (first 1<->2, then 4<->5) so that
    # consumers see them out of order, mimicking a random replica ordering.
    for first, second in ((1, 2), (4, 5)):
        unwrapped[second], unwrapped[first] = (
            unwrapped[first], unwrapped[second])
    return unwrapped
Ejemplo n.º 2
0
 def distributed_function(input_iterator):
     """Execute a single distributed step on all replicas and unwrap it.

     The active strategy is fetched from ``distribution_strategy_context``.
     Feed values are built from ``input_iterator`` with
     ``_prepare_feed_values``, and ``per_replica_function`` is run on every
     replica via ``experimental_run_v2``. The resulting per-replica outputs
     are then reduced or selected with ``dist_utils.unwrap_output_dict``
     according to ``mode``.

     Args:
       input_iterator: Source of the inputs for this step. It is handed to
         ``_prepare_feed_values`` together with ``model``, ``mode`` and the
         current strategy.

     Returns:
       Whatever ``dist_utils.unwrap_output_dict`` produces for ``mode``.
     """
     current_strategy = distribution_strategy_context.get_strategy()
     feed_values = _prepare_feed_values(
         model, input_iterator, mode, current_strategy)
     # `per_replica_function` runs once per replica (per the original note,
     # it calls `Model.{train,test,predict}_on_batch`). Inside that call each
     # PerReplica argument resolves to the value for that replica, and the
     # outputs come back as PerReplica values too.
     per_replica_outputs = current_strategy.experimental_run_v2(
         per_replica_function, args=feed_values)
     # Collapse the PerReplica outputs into the values handed back to callers.
     return dist_utils.unwrap_output_dict(
         current_strategy, per_replica_outputs, mode)