def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn,
                   aggregation):
  """Reduce the values in `per_device_value` onto a single device.

  Gathers one tensor per device (pre-summing `MapOutput` lists on their own
  device first), then combines them on `reduce_to_device` with
  `accumulation_fn`, dividing by the element count for MEAN aggregation.

  Args:
    per_device_value: container whose `_index` holds per-device values;
      entries may be tensors/IndexedSlices or `value_lib.MapOutput` wrappers.
    reduce_to_device: device string on which the final value is placed.
    accumulation_fn: callable used to combine the gathered values.
    aggregation: `vs.VariableAggregation.SUM` or
      `vs.VariableAggregation.MEAN`.

  Returns:
    The reduced tensor or IndexedSlices on `reduce_to_device`.

  Raises:
    ValueError: if there is nothing to reduce, or `aggregation` is neither
      SUM nor MEAN.
  """
  values_to_reduce = []
  num_inputs = 0
  for value in per_device_value._index.values():  # pylint: disable=protected-access
    if not isinstance(value, value_lib.MapOutput):
      num_inputs += 1
      values_to_reduce.append(value)
      continue
    inner_list = value.get()
    if not inner_list:
      continue
    num_inputs += len(inner_list)
    # Sum within each device before aggregating across devices.
    # TODO(yuefengz): Check whether it helps to use accumulation_fn here.
    values_to_reduce.append(
        cross_tower_utils.aggregate_tensors_or_indexed_slices(
            inner_list, math_ops.add_n))

  if not values_to_reduce:
    raise ValueError("`per_device_value` must be non-empty")

  with ops.device(reduce_to_device):
    with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
      reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices(
          values_to_reduce, accumulation_fn)
      if aggregation == vs.VariableAggregation.MEAN:
        reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
            reduced, num_inputs)
      elif aggregation != vs.VariableAggregation.SUM:
        raise ValueError("`aggregation` must be VariableAggregation.SUM "
                         "or VariableAggregation.MEAN.")
  return reduced
def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
                   aggregation):
  """Reduce the values in `per_replica_value` onto a single device.

  Collects one value per replica — pre-summing any `MapOutput` list on its
  own device via `math_ops.add_n` — then aggregates everything on
  `reduce_to_device` with `accumulation_fn`. For MEAN aggregation the result
  is divided by the total number of contributing values.

  Args:
    per_replica_value: container whose `_index` holds per-replica values;
      entries may be tensors/IndexedSlices or `value_lib.MapOutput` wrappers.
    reduce_to_device: device string where the final reduction happens.
    accumulation_fn: callable used to combine the gathered values.
    aggregation: `vs.VariableAggregation.SUM` or
      `vs.VariableAggregation.MEAN`.

  Returns:
    The reduced tensor or IndexedSlices on `reduce_to_device`.

  Raises:
    ValueError: on an empty input or an unsupported `aggregation`.
  """
  gathered = []
  total_count = 0
  for item in per_replica_value._index.values():  # pylint: disable=protected-access
    if isinstance(item, value_lib.MapOutput):
      inner = item.get()
      if not inner:
        continue
      total_count += len(inner)
      # Sum within each device before aggregating across devices.
      # TODO(yuefengz): Check whether it helps to use accumulation_fn here.
      item = cross_tower_utils.aggregate_tensors_or_indexed_slices(
          inner, math_ops.add_n)
    else:
      total_count += 1
    gathered.append(item)

  if not gathered:
    raise ValueError("`per_replica_value` must be non-empty")

  with ops.device(reduce_to_device):
    with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
      reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices(
          gathered, accumulation_fn)
      if aggregation == vs.VariableAggregation.MEAN:
        reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
            reduced, total_count)
      elif aggregation != vs.VariableAggregation.SUM:
        raise ValueError("`aggregation` must be VariableAggregation.SUM "
                         "or VariableAggregation.MEAN.")
  return reduced
def testDivideIndexedSlices(self):
  """Dividing IndexedSlices by n scales values and preserves the type."""
  divisor = 2
  dense = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
  slices = math_ops._as_indexed_slices(dense)  # pylint: disable=protected-access
  expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]])
  quotient = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
      slices, divisor)
  self.assertIsInstance(quotient, ops.IndexedSlices)
  self._assert_values_equal(expected, quotient)
def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
                   reduce_op):
  """Reduce the values in `per_replica_value` onto a single device.

  Gathers every per-replica value, aggregates them on `reduce_to_device`
  with `accumulation_fn`, and divides by the value count when `reduce_op`
  is MEAN.

  Args:
    per_replica_value: container whose `_index` holds per-replica values
      (tensors or IndexedSlices).
    reduce_to_device: device string where the final reduction happens.
    accumulation_fn: callable used to combine the gathered values.
    reduce_op: `reduce_util.ReduceOp.SUM` or `reduce_util.ReduceOp.MEAN`.

  Returns:
    The reduced tensor or IndexedSlices on `reduce_to_device`.

  Raises:
    ValueError: if there is nothing to reduce, or `reduce_op` is neither
      SUM nor MEAN.
  """
  all_values = []
  count = 0
  for v in per_replica_value._index.values():  # pylint: disable=protected-access
    count += 1
    all_values.append(v)
  if not all_values:
    raise ValueError("`per_replica_value` must be non-empty")

  with ops.device(reduce_to_device):
    with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
      reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices(
          all_values, accumulation_fn)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
            reduced, count)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        # Fixed message: the enum being checked is ReduceOp, not "Reduce" —
        # the old text named a nonexistent enum.
        raise ValueError(
            "`reduce_op` must be ReduceOp.SUM or ReduceOp.MEAN.")
  return reduced
def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
                   aggregation):
  """Reduce the values in `per_replica_value` onto a single device.

  Aggregates every per-replica value on `reduce_to_device` with
  `accumulation_fn`; for MEAN aggregation the result is divided by the
  number of values.

  Args:
    per_replica_value: container whose `_index` holds per-replica values
      (tensors or IndexedSlices).
    reduce_to_device: device string where the final reduction happens.
    accumulation_fn: callable used to combine the gathered values.
    aggregation: `vs.VariableAggregation.SUM` or
      `vs.VariableAggregation.MEAN`.

  Returns:
    The reduced tensor or IndexedSlices on `reduce_to_device`.

  Raises:
    ValueError: on an empty input or an unsupported `aggregation`.
  """
  all_values = list(per_replica_value._index.values())  # pylint: disable=protected-access
  count = len(all_values)
  if not all_values:
    raise ValueError("`per_replica_value` must be non-empty")

  with ops.device(reduce_to_device):
    with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
      reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices(
          all_values, accumulation_fn)
      if aggregation == vs.VariableAggregation.MEAN:
        reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
            reduced, count)
      elif aggregation != vs.VariableAggregation.SUM:
        raise ValueError("`aggregation` must be VariableAggregation.SUM "
                         "or VariableAggregation.MEAN.")
  return reduced
def testDivideTensor(self):
  """Dividing a dense tensor by n scales every element by 1/n."""
  divisor = 2
  source = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
  expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]])
  quotient = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
      source, divisor)
  self._assert_values_equal(expected, quotient)