Exemplo n.º 1
0
def _aggregate_predict_results(strategy, batch_outs, model):
    """Merge the per-replica prediction outputs into one value per model output.

    ``batch_outs`` is a flat sequence laid out as::

        [replica_1_batch_index, ..., replica_x_batch_index,   # optional
         replica_1_output_1,    ..., replica_x_output_1,
         ...
         replica_1_output_y,    ..., replica_x_output_y]

    The leading batch-index group is present only when
    ``_should_add_batch_index_to_element(strategy, ModeKeys.PREDICT)`` is
    truthy. When it is present and the concatenated index is not already in
    ascending order, each merged output whose batch size matches the index
    length is re-gathered into index order.

    Args:
      strategy: Distribution strategy. ``num_replicas_in_sync`` is read, and
        the strategy is passed to ``_should_add_batch_index_to_element``.
      batch_outs: A list in the layout above, or a single non-list value,
        which is treated as a one-element list.
      model: Model whose ``outputs`` length determines how many output groups
        are read.

    Returns:
      A list with one entry per model output. Each entry is the
      ``dist_utils.concat_along_batch_dimension`` of that output's flattened
      per-replica values, re-gathered when reordering is needed.
    """
    replica_count = strategy.num_replicas_in_sync
    output_count = len(model.outputs)

    flat_outs = batch_outs if isinstance(batch_outs, list) else [batch_outs]

    has_index_prefix = _should_add_batch_index_to_element(
        strategy, ModeKeys.PREDICT)

    # ``gather_order`` stays None unless the rows actually need reordering.
    gather_order = None
    if has_index_prefix:
        index_parts = flat_outs[:replica_count]
        flat_outs = flat_outs[replica_count:]
        merged_index = dist_utils.concat_along_batch_dimension(index_parts)
        # argsort yields the permutation that puts the merged index in
        # ascending order, which is the order a gather needs. For example,
        # the index [0, 2, 4, 6, 1, 3, 5, 7] yields [0, 4, 1, 5, 2, 6, 3, 7].
        order = np.argsort(merged_index)
        # A sorted index gives the identity permutation. Any descent in
        # ``order`` means the rows are out of place and a gather is required.
        if np.any(np.diff(order) < 0):
            gather_order = order

    def _merge_output(position):
        # Entries are grouped by output first, then by replica.
        start = position * replica_count
        per_replica = flat_outs[start:start + replica_count]
        merged = dist_utils.concat_along_batch_dimension(
            nest.flatten(per_replica))
        # Gather only when the row count matches the index length. A
        # mismatched output is returned unreordered.
        # NOTE(review): the caller is presumably expected to detect and report
        # that mismatch — confirm.
        if gather_order is not None and (
                _get_batch_size(merged).numpy() == len(gather_order)):
            merged = _gather_result_by_index(merged, gather_order)
        return merged

    return [_merge_output(position) for position in range(output_count)]
Exemplo n.º 2
0
def _aggregate_predict_results(strategy, batch_outs, model):
  """Aggregate the prediction result from each replica.

  ``batch_outs`` is read as consecutive groups of
  ``strategy.num_replicas_in_sync`` entries, one group per model output:
  all replicas for output 1, then all replicas for output 2, and so on.

  Args:
    strategy: Distribution strategy; only ``num_replicas_in_sync`` is read.
    batch_outs: A list in the grouped layout above, or a single non-list
      value, which is treated as a one-element list.
    model: Model whose ``outputs`` length determines how many groups are read.

  Returns:
    A list with one entry per model output. Each entry is that output's
    per-replica values, flattened with ``nest.flatten`` and joined with
    ``dist_utils.concat_along_batch_dimension``.
  """
  if not isinstance(batch_outs, list):
    batch_outs = [batch_outs]
  # The replica count is the same for every output group, so read it once
  # rather than on every iteration.
  num_replicas = strategy.num_replicas_in_sync
  return [
      dist_utils.concat_along_batch_dimension(
          nest.flatten(batch_outs[i * num_replicas:(i + 1) * num_replicas]))
      for i in range(len(model.outputs))
  ]