def _aggregate_predict_results(strategy, batch_outs, model):
    """Combine the per-replica prediction outputs into one result per output.

    Args:
        strategy: Distribution strategy; `num_replicas_in_sync` determines how
            many consecutive entries of `batch_outs` belong to each group.
        batch_outs: A list of per-replica values, or a single value (which is
            wrapped into a one-element list).
        model: Model whose `outputs` length gives the number of output groups.

    Returns:
        A list with one entry per model output, each produced by
        `dist_utils.concat_along_batch_dimension` over that output's replica
        values and, when required, reordered by batch index.
    """
    replica_count = strategy.num_replicas_in_sync
    output_count = len(model.outputs)

    if not isinstance(batch_outs, list):
        batch_outs = [batch_outs]

    has_batch_index = _should_add_batch_index_to_element(
        strategy, ModeKeys.PREDICT)

    # Layout of batch_outs, one group of `replica_count` entries per row:
    #   [
    #     replica_1_batch_index, replica_2_batch_index, ..., replica_x_batch_index,
    #     replica_1_output_1,    replica_2_output_1,    ..., replica_x_output_1,
    #     ...
    #     replica_1_output_y,    replica_2_output_y,    ..., replica_x_output_y,
    #   ]
    # The leading batch-index group is optional and depends on the strategy
    # type.
    gather_order = None
    must_reorder = False
    if has_batch_index:
        index_parts = batch_outs[:replica_count]
        batch_outs = batch_outs[replica_count:]
        combined_index = dist_utils.concat_along_batch_dimension(index_parts)
        # Turn the observed indices into the permutation used for the gather.
        # E.g. observed [0, 2, 4, 6, 1, 3, 5, 7] yields the gather order
        # [0, 4, 1, 5, 2, 6, 3, 7].
        gather_order = np.argsort(combined_index)
        # A gather is only required when that permutation is not already
        # ascending.
        must_reorder = np.any(np.diff(gather_order) < 0)

    aggregated = []
    for output_idx in range(output_count):
        start = output_idx * replica_count
        replica_parts = batch_outs[start:start + replica_count]
        merged = dist_utils.concat_along_batch_dimension(
            nest.flatten(replica_parts))

        # Outputs whose batch size differs from the index length are left
        # ungathered; error handling for that case happens in an upper layer.
        if must_reorder and (
                _get_batch_size(merged).numpy() == len(gather_order)):
            merged = _gather_result_by_index(merged, gather_order)
        aggregated.append(merged)
    return aggregated
def _aggregate_predict_results(strategy, batch_outs, model):
    """Combine the per-replica prediction outputs into one result per output.

    Args:
        strategy: Distribution strategy; `num_replicas_in_sync` determines how
            many consecutive entries of `batch_outs` belong to each output.
        batch_outs: A list of per-replica values, or a single value (which is
            wrapped into a one-element list).
        model: Model whose `outputs` length gives the number of outputs.

    Returns:
        A list with one entry per model output, each produced by
        `dist_utils.concat_along_batch_dimension` over the flattened replica
        values of that output.
    """
    if not isinstance(batch_outs, list):
        batch_outs = [batch_outs]

    aggregated = []
    for output_idx in range(len(model.outputs)):
        replica_count = strategy.num_replicas_in_sync
        # Each output owns one contiguous run of `replica_count` entries.
        start = output_idx * replica_count
        replica_parts = batch_outs[start:start + replica_count]
        flat_parts = nest.flatten(replica_parts)
        aggregated.append(
            dist_utils.concat_along_batch_dimension(flat_parts))
    return aggregated