def _aggregate_grads(gradients):
  """Combine the gradients that several sources produced for one value.

  Args:
    gradients: Non-empty list whose elements are 'ops.Tensor' or
      'ops.IndexedSlices' gradients.

  Returns:
    The single element, unchanged, when the list has length one. Otherwise,
    the 'gen_math_ops.add_n' sum when every element is an 'ops.Tensor', or
    the result of 'aggregate_indexed_slices_gradients' when at least one
    element is not a plain 'ops.Tensor'.

  Raises:
    AssertionError: If 'gradients' is empty, or (with more than one element)
      if any element is neither an 'ops.Tensor' nor an 'ops.IndexedSlices'.
      Note these are 'assert' checks, so they vanish under 'python -O'.
  """
  assert gradients, "No gradients to aggregate"
  if len(gradients) == 1:
    # Nothing to combine; the lone gradient is returned without type checks.
    return gradients[0]

  only_dense = all(isinstance(grad, ops.Tensor) for grad in gradients)
  if only_dense:
    # Every contribution is dense, so an element-wise sum is enough.
    return gen_math_ops.add_n(gradients)

  supported_types = (ops.Tensor, ops.IndexedSlices)
  assert all(isinstance(grad, supported_types) for grad in gradients)
  # At least one contribution is not a plain Tensor: hand the mixed list to
  # the dedicated IndexedSlices aggregation helper.
  return aggregate_indexed_slices_gradients(gradients)
def _aggregate_grads(gradients):
  """Aggregate gradients from multiple sources.

  Args:
    gradients: A non-empty list of 'Tensor' or 'IndexedSlices' gradients.

  Returns:
    The sole element when 'gradients' has length one. Otherwise, if
    'gradients' only has 'Tensor', an aggregated 'Tensor' (their 'add_n'
    sum); else an aggregated 'IndexedSlices' whose indices and values are the
    concatenation of every gradient's indices (cast to int64) and values.

  Raises:
    AssertionError: If 'gradients' is empty, or (with more than one element)
      contains an element that is neither a 'Tensor' nor an 'IndexedSlices'.
  """
  assert gradients, "No gradients to aggregate"

  if len(gradients) == 1:
    return gradients[0]
  if all(isinstance(g, ops.Tensor) for g in gradients):
    return gen_math_ops.add_n(gradients)
  assert all(
      isinstance(g, (ops.Tensor, ops.IndexedSlices)) for g in gradients)

  indexed_slices_list = []
  for grad in gradients:
    # TODO(xpan): Support nested IndexedSlices and core IndexedSlices
    if isinstance(grad, ops.Tensor):
      # Represent a dense gradient as slices covering rows 0..N-1 of its first
      # dimension, with its full static shape as the dense_shape.
      grad = ops.IndexedSlices(
          grad,
          math_ops.range(grad.shape[0]),
          constant_op.constant(grad.shape.as_list()))
    indexed_slices_list.append(grad)

  # Dense shapes from all gradients should be the same, so use the first one
  # that is known: an 'IndexedSlices' may carry dense_shape=None, and that
  # must not hide a shape that a later gradient does provide.
  dense_shape = next(
      (x.dense_shape for x in indexed_slices_list
       if x.dense_shape is not None),
      None)
  # For simplicity now, always cast to int64.
  indices = array_ops.concat(
      [math_ops.cast(x.indices, dtypes.int64) for x in indexed_slices_list], 0)
  values = array_ops.concat([x.values for x in indexed_slices_list], 0)
  return ops.IndexedSlices(values, indices, dense_shape)
def _aggregate_grads(gradients):
  """Aggregate gradients from multiple sources.

  Args:
    gradients: A non-empty list of 'Tensor' or 'IndexedSlices' gradients.

  Returns:
    The sole element when 'gradients' has length one. Otherwise, if
    'gradients' only has 'Tensor', an aggregated 'Tensor' (their 'add_n'
    sum); else an aggregated 'IndexedSlices' built by concatenating every
    gradient's indices (cast to int64) and values.

  Raises:
    AssertionError: If 'gradients' is empty, or (with more than one element)
      contains an element that is neither a 'Tensor' nor an 'IndexedSlices'.
  """
  assert gradients, "No gradients to aggregate"

  if len(gradients) == 1:
    return gradients[0]
  if all(isinstance(g, ops.Tensor) for g in gradients):
    return gen_math_ops.add_n(gradients)
  assert all(isinstance(g, (ops.Tensor, ops.IndexedSlices)) for g in gradients)

  indexed_slices_list = []
  for grad in gradients:
    # TODO(xpan): Support nested IndexedSlices and core IndexedSlices
    if isinstance(grad, ops.Tensor):
      # A dense gradient becomes slices over every row 0..N-1 of its first
      # dimension; its full static shape serves as the dense_shape.
      indexed_slices_list.append(
          ops.IndexedSlices(
              grad,
              math_ops.range(grad.shape[0]),
              constant_op.constant(grad.shape.as_list())))
    else:
      indexed_slices_list.append(grad)

  # Dense shapes from all gradients should be the same. Take the first known
  # one rather than blindly the first: an 'IndexedSlices' may have
  # dense_shape=None even when another gradient supplies the shape.
  dense_shape = None
  for slices in indexed_slices_list:
    if slices.dense_shape is not None:
      dense_shape = slices.dense_shape
      break
  # For simplicity now, always cast to int64.
  indices = array_ops.concat(
      [math_ops.cast(x.indices, dtypes.int64) for x in indexed_slices_list], 0)
  values = array_ops.concat([x.values for x in indexed_slices_list], 0)
  return ops.IndexedSlices(values, indices, dense_shape)