Exemplo n.º 1
0
def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn,
                   aggregation):
  """Reduces the values held by `per_device_value` under `reduce_to_device`.

  Every `value_lib.MapOutput` entry is first collapsed into a single value
  using `math_ops.add_n`; map outputs whose `get()` is empty are skipped. The
  collected values are then combined with `accumulation_fn` inside an
  `ops.device(reduce_to_device)` scope.

  Args:
    per_device_value: object whose `_index` mapping holds the values to reduce.
    reduce_to_device: device handed to `ops.device` for the final reduction.
    accumulation_fn: function forwarded to
      `cross_tower_utils.aggregate_tensors_or_indexed_slices`.
    aggregation: `vs.VariableAggregation.SUM` or `vs.VariableAggregation.MEAN`.

  Returns:
    The aggregated value; for MEAN it is additionally passed through
    `cross_tower_utils.divide_by_n_tensors_or_indexed_slices` with the number
    of contributing values.

  Raises:
    ValueError: if no value was collected, or if `aggregation` is neither SUM
      nor MEAN (checked only after the aggregate has been computed).
  """
  collected = []
  num_values = 0
  for entry in per_device_value._index.values():  # pylint: disable=protected-access
    if not isinstance(entry, value_lib.MapOutput):
      # Plain value: contributes exactly one element to the mean denominator.
      num_values += 1
      collected.append(entry)
      continue

    mapped = entry.get()
    if mapped:
      num_values += len(mapped)
      # Sum within each device before aggregating across devices.
      # TODO(yuefengz): Check whether it helps to use accumulation_fn here.
      collected.append(
          cross_tower_utils.aggregate_tensors_or_indexed_slices(
              mapped, math_ops.add_n))

  if not collected:
    raise ValueError("`per_device_value` must be non-empty")

  with ops.device(reduce_to_device), context.context().device_policy(
      context.DEVICE_PLACEMENT_SILENT):
    total = cross_tower_utils.aggregate_tensors_or_indexed_slices(
        collected, accumulation_fn)
    if aggregation == vs.VariableAggregation.MEAN:
      return cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
          total, num_values)
    if aggregation != vs.VariableAggregation.SUM:
      raise ValueError("`aggregation` must be VariableAggregation.SUM "
                       "or VariableAggregation.MEAN.")
    return total
Exemplo n.º 2
0
def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
                   aggregation):
  """Aggregates the values of `per_replica_value` under `reduce_to_device`.

  `value_lib.MapOutput` entries are pre-summed with `math_ops.add_n` (empty
  ones are ignored); all other entries are used as they are. The resulting
  list is combined with `accumulation_fn` inside `ops.device(reduce_to_device)`.

  Args:
    per_replica_value: object whose `_index` mapping holds the values to
      reduce.
    reduce_to_device: device handed to `ops.device` for the final reduction.
    accumulation_fn: function forwarded to
      `cross_tower_utils.aggregate_tensors_or_indexed_slices`.
    aggregation: `vs.VariableAggregation.SUM` or `vs.VariableAggregation.MEAN`.

  Returns:
    The aggregated value; for MEAN it is additionally passed through
    `cross_tower_utils.divide_by_n_tensors_or_indexed_slices` with the number
    of contributing values.

  Raises:
    ValueError: if no value was collected, or if `aggregation` is neither SUM
      nor MEAN (checked only after the aggregate has been computed).
  """
  to_reduce = []
  n_contributions = 0
  for replica_value in per_replica_value._index.values():  # pylint: disable=protected-access
    if not isinstance(replica_value, value_lib.MapOutput):
      # A plain value counts once towards the mean denominator.
      n_contributions += 1
      to_reduce.append(replica_value)
      continue

    outputs = replica_value.get()
    if outputs:
      n_contributions += len(outputs)
      # Sum within each device before aggregating across devices.
      # TODO(yuefengz): Check whether it helps to use accumulation_fn here.
      to_reduce.append(
          cross_tower_utils.aggregate_tensors_or_indexed_slices(
              outputs, math_ops.add_n))

  if not to_reduce:
    raise ValueError("`per_replica_value` must be non-empty")

  with ops.device(reduce_to_device), context.context().device_policy(
      context.DEVICE_PLACEMENT_SILENT):
    aggregate = cross_tower_utils.aggregate_tensors_or_indexed_slices(
        to_reduce, accumulation_fn)
    if aggregation == vs.VariableAggregation.MEAN:
      return cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
          aggregate, n_contributions)
    if aggregation != vs.VariableAggregation.SUM:
      raise ValueError("`aggregation` must be VariableAggregation.SUM "
                       "or VariableAggregation.MEAN.")
    return aggregate
 def testAggregateTensors(self):
     """Aggregating two dense tensors should yield their element-wise sum."""
     first = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
     second = constant_op.constant([[0., 0.], [5, 6], [7., 8.]])
     expected_sum = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
     inputs = [first, second]
     aggregated = cross_tower_utils.aggregate_tensors_or_indexed_slices(inputs)
     self._assert_values_equal(expected_sum, aggregated)
 def testAggregateIndexedSlices(self):
   """Aggregating IndexedSlices should sum them and stay an IndexedSlices."""
   dense_first = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
   slices_first = math_ops._as_indexed_slices(dense_first)
   dense_second = constant_op.constant([[0., 0.], [5, 6], [7., 8.]])
   slices_second = math_ops._as_indexed_slices(dense_second)
   expected_sum = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
   aggregated = cross_tower_utils.aggregate_tensors_or_indexed_slices(
       [slices_first, slices_second])
   # The result type must be preserved, not densified.
   self.assertIsInstance(aggregated, ops.IndexedSlices)
   self._assert_values_equal(expected_sum, aggregated)
 def testAggregateIndexedSlices(self):
   """Checks IndexedSlices inputs aggregate to an IndexedSlices sum."""
   lhs_dense = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
   lhs = math_ops._as_indexed_slices(lhs_dense)
   rhs_dense = constant_op.constant([[0., 0.], [5, 6], [7., 8.]])
   rhs = math_ops._as_indexed_slices(rhs_dense)
   want = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
   operands = [lhs, rhs]
   got = cross_tower_utils.aggregate_tensors_or_indexed_slices(operands)
   # Verify the container type first, then the summed contents.
   self.assertIsInstance(got, ops.IndexedSlices)
   self._assert_values_equal(want, got)
Exemplo n.º 6
0
def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
                   reduce_op):
    """Reduces the values of `per_replica_value` under `reduce_to_device`.

    Args:
        per_replica_value: object whose `_index` mapping holds the values to
            reduce.
        reduce_to_device: device handed to `ops.device` for the reduction.
        accumulation_fn: function forwarded to
            `cross_tower_utils.aggregate_tensors_or_indexed_slices`.
        reduce_op: `reduce_util.ReduceOp.SUM` or `reduce_util.ReduceOp.MEAN`.

    Returns:
        The aggregated value; for MEAN it is additionally passed through
        `cross_tower_utils.divide_by_n_tensors_or_indexed_slices` with the
        number of values.

    Raises:
        ValueError: if there are no values, or if `reduce_op` is neither SUM
            nor MEAN (checked only after the aggregate has been computed).
    """
    replica_values = list(per_replica_value._index.values())  # pylint: disable=protected-access
    if not replica_values:
        raise ValueError("`per_replica_value` must be non-empty")
    num_values = len(replica_values)

    with ops.device(reduce_to_device), context.context().device_policy(
            context.DEVICE_PLACEMENT_SILENT):
        total = cross_tower_utils.aggregate_tensors_or_indexed_slices(
            replica_values, accumulation_fn)
        if reduce_op == reduce_util.ReduceOp.MEAN:
            return cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
                total, num_values)
        if reduce_op != reduce_util.ReduceOp.SUM:
            raise ValueError(
                "`reduce_op` must be Reduce.SUM or Reduce.MEAN.")
        return total
def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
                   aggregation):
  """Aggregates the values of `per_replica_value` under `reduce_to_device`.

  Args:
    per_replica_value: object whose `_index` mapping holds the values to
      reduce.
    reduce_to_device: device handed to `ops.device` for the reduction.
    accumulation_fn: function forwarded to
      `cross_tower_utils.aggregate_tensors_or_indexed_slices`.
    aggregation: `vs.VariableAggregation.SUM` or `vs.VariableAggregation.MEAN`.

  Returns:
    The aggregated value; for MEAN it is additionally passed through
    `cross_tower_utils.divide_by_n_tensors_or_indexed_slices` with the number
    of values.

  Raises:
    ValueError: if there are no values, or if `aggregation` is neither SUM nor
      MEAN (checked only after the aggregate has been computed).
  """
  values_to_reduce = list(per_replica_value._index.values())  # pylint: disable=protected-access
  if not values_to_reduce:
    raise ValueError("`per_replica_value` must be non-empty")
  n_values = len(values_to_reduce)

  with ops.device(reduce_to_device), context.context().device_policy(
      context.DEVICE_PLACEMENT_SILENT):
    aggregate = cross_tower_utils.aggregate_tensors_or_indexed_slices(
        values_to_reduce, accumulation_fn)
    if aggregation == vs.VariableAggregation.MEAN:
      return cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
          aggregate, n_values)
    if aggregation != vs.VariableAggregation.SUM:
      raise ValueError("`aggregation` must be VariableAggregation.SUM "
                       "or VariableAggregation.MEAN.")
    return aggregate
 def testAggregateTensors(self):
   """Checks that two dense tensors aggregate to their element-wise sum."""
   lhs = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
   rhs = constant_op.constant([[0., 0.], [5, 6], [7., 8.]])
   want = constant_op.constant([[1., 2.], [5, 6], [10., 12.]])
   operands = [lhs, rhs]
   got = cross_tower_utils.aggregate_tensors_or_indexed_slices(operands)
   self._assert_values_equal(want, got)