Example #1
0
def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn,
                   aggregation):
  """Reduces the values in `per_device_value` to one value on a device.

  Args:
    per_device_value: an object whose `_index` maps devices to values;
      each value is either a tensor/IndexedSlices or a `value_lib.MapOutput`
      wrapping a list of them.
    reduce_to_device: the device on which the reduced result is placed.
    accumulation_fn: function used to combine the collected values.
    aggregation: a `vs.VariableAggregation` member; MEAN divides the
      accumulated result by the number of contributing values, SUM keeps
      the accumulated result as-is.

  Returns:
    The aggregated value, placed on `reduce_to_device`.

  Raises:
    ValueError: if no values are found or `aggregation` is neither SUM
      nor MEAN.
  """
  collected = []
  num_values = 0
  for entry in per_device_value._index.values():  # pylint: disable=protected-access
    if not isinstance(entry, value_lib.MapOutput):
      num_values += 1
      collected.append(entry)
      continue
    parts = entry.get()
    if not parts:
      continue
    num_values += len(parts)
    # Sum within each device before aggregating across devices.
    # TODO(yuefengz): Check whether it helps to use accumulation_fn here.
    collected.append(cross_tower_utils.aggregate_tensors_or_indexed_slices(
        parts, math_ops.add_n))
  if not collected:
    raise ValueError("`per_device_value` must be non-empty")

  with ops.device(reduce_to_device):
    with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
      combined = cross_tower_utils.aggregate_tensors_or_indexed_slices(
          collected, accumulation_fn)
      if aggregation == vs.VariableAggregation.MEAN:
        combined = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
            combined, num_values)
      elif aggregation != vs.VariableAggregation.SUM:
        raise ValueError("`aggregation` must be VariableAggregation.SUM "
                         "or VariableAggregation.MEAN.")
  return combined
Example #2
0
def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
                   aggregation):
  """Reduces `per_replica_value` to a single value on `reduce_to_device`.

  Values wrapped in `value_lib.MapOutput` are first summed within their
  own device, then everything is combined across devices with
  `accumulation_fn`; a MEAN aggregation additionally divides by the total
  number of contributing values.

  Raises:
    ValueError: if `per_replica_value` holds no values, or `aggregation`
      is not VariableAggregation.SUM or VariableAggregation.MEAN.
  """
  gathered = []
  total = 0
  for item in per_replica_value._index.values():  # pylint: disable=protected-access
    if isinstance(item, value_lib.MapOutput):
      inner = item.get()
      if not inner:
        continue
      total += len(inner)
      # Sum within each device before aggregating across devices.
      # TODO(yuefengz): Check whether it helps to use accumulation_fn here.
      item = cross_tower_utils.aggregate_tensors_or_indexed_slices(
          inner, math_ops.add_n)
    else:
      total += 1
    gathered.append(item)
  if not gathered:
    raise ValueError("`per_replica_value` must be non-empty")

  with ops.device(reduce_to_device):
    with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
      result = cross_tower_utils.aggregate_tensors_or_indexed_slices(
          gathered, accumulation_fn)
      if aggregation == vs.VariableAggregation.MEAN:
        result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
            result, total)
      elif aggregation != vs.VariableAggregation.SUM:
        raise ValueError("`aggregation` must be VariableAggregation.SUM "
                         "or VariableAggregation.MEAN.")
  return result
 def testDivideIndexedSlices(self):
   """Dividing IndexedSlices by 2 halves the values and keeps the type."""
   dividend = math_ops._as_indexed_slices(
       constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
   divisor = 2
   want = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]])
   got = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
       dividend, divisor)
   self.assertIsInstance(got, ops.IndexedSlices)
   self._assert_values_equal(want, got)
 def testDivideIndexedSlices(self):
     """Checks element-wise division of an IndexedSlices by an integer."""
     numerator = math_ops._as_indexed_slices(
         constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
     halved = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
         numerator, 2)
     # The division must not densify the input.
     self.assertIsInstance(halved, ops.IndexedSlices)
     self._assert_values_equal(
         constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]), halved)
Example #5
0
def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
                   reduce_op):
    """Reduces `per_replica_value` to a single value on `reduce_to_device`.

    Args:
      per_replica_value: an object whose `_index` maps replica devices to
        their values (tensors or IndexedSlices).
      reduce_to_device: the device on which the reduced result is placed.
      accumulation_fn: function used to combine the collected values.
      reduce_op: a `reduce_util.ReduceOp` member; MEAN divides the
        accumulated result by the number of values, SUM keeps it as-is.

    Returns:
      The aggregated value, placed on `reduce_to_device`.

    Raises:
      ValueError: if `per_replica_value` holds no values, or `reduce_op`
        is not ReduceOp.SUM or ReduceOp.MEAN.
    """
    # Collect once; the length doubles as the MEAN divisor.
    all_values = list(
        per_replica_value._index.values())  # pylint: disable=protected-access
    if not all_values:
        raise ValueError("`per_replica_value` must be non-empty")

    with ops.device(reduce_to_device):
        with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
            reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices(
                all_values, accumulation_fn)
            if reduce_op == reduce_util.ReduceOp.MEAN:
                reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
                    reduced, len(all_values))
            elif reduce_op != reduce_util.ReduceOp.SUM:
                # Name the actual enum (`ReduceOp`) so the message matches
                # what callers must pass; the old text said "Reduce.SUM".
                raise ValueError(
                    "`reduce_op` must be ReduceOp.SUM or ReduceOp.MEAN.")
    return reduced
def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
                   aggregation):
  """Aggregates the values of `per_replica_value` onto `reduce_to_device`.

  Raises:
    ValueError: if there are no values, or `aggregation` is not
      VariableAggregation.SUM or VariableAggregation.MEAN.
  """
  per_device_tensors = list(
      per_replica_value._index.values())  # pylint: disable=protected-access
  if not per_device_tensors:
    raise ValueError("`per_replica_value` must be non-empty")

  with ops.device(reduce_to_device):
    with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
      reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices(
          per_device_tensors, accumulation_fn)
      if aggregation == vs.VariableAggregation.MEAN:
        # MEAN: divide the accumulated result by how many values went in.
        reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
            reduced, len(per_device_tensors))
      elif aggregation != vs.VariableAggregation.SUM:
        raise ValueError("`aggregation` must be VariableAggregation.SUM "
                         "or VariableAggregation.MEAN.")
  return reduced
 def testDivideTensor(self):
   """Dividing a dense tensor by 2 halves every element."""
   dividend = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
   divisor = 2
   want = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]])
   got = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
       dividend, divisor)
   self._assert_values_equal(want, got)
 def testDivideTensor(self):
     """Checks element-wise division of a dense tensor by an integer."""
     numerator = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
     halved = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
         numerator, 2)
     self._assert_values_equal(
         constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]), halved)