def predict_loop(batch): batch_result = strategy.experimental_run_v2( execution_function, batch) batch_result = dist_utils.unwrap_output_dict( strategy, batch_result, ModeKeys.PREDICT) # swap the order of replica 1 and 2, to mimic random order. batch_result[2], batch_result[1] = batch_result[1], batch_result[2] batch_result[5], batch_result[4] = batch_result[4], batch_result[5] return batch_result
def distributed_function(input_iterator): """A single step of the distributed execution across replicas.""" # Call `Model.{train,test,predict}_on_batch` on every replica passing # PerReplicas as arguments. On every replica inside this call, each # PerReplica object will return the value for that replica. The outputs # are PerReplicas too. strategy = distribution_strategy_context.get_strategy() args = _prepare_feed_values(model, input_iterator, mode, strategy) outputs = strategy.experimental_run_v2(per_replica_function, args=args) # Out of PerReplica outputs reduce or pick values to return. all_outputs = dist_utils.unwrap_output_dict(strategy, outputs, mode) return all_outputs