def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n):
  """Combine `values` into a single aggregate.

  Args:
    values: The collection of values to combine. It is scanned once to look
      for `ops.IndexedSlices` entries and then handed, unchanged, to the
      chosen aggregation routine.
    accumulation_fn: Callable applied to `values` when none of them is an
      `ops.IndexedSlices`. Defaults to `math_ops.add_n`.

  Returns:
    The result of `backprop.aggregate_indexed_slices_gradients(values)` when
    at least one entry is an `ops.IndexedSlices`, otherwise the result of
    `accumulation_fn(values)`.
  """
  contains_indexed_slices = any(
      isinstance(value, ops.IndexedSlices) for value in values)
  if contains_indexed_slices:
    # A single IndexedSlices entry is enough to route the whole collection
    # through the IndexedSlices aggregation path.
    return backprop.aggregate_indexed_slices_gradients(values)
  return accumulation_fn(values)
def _AggregatedGrads(grads,
                     op,
                     gradient_uid,
                     loop_state,
                     aggregation_method=None):
  """Collapse the gradient terms recorded for each output of `op`.

  Args:
    grads: The map of memoized gradients.
    op: The op whose output gradients are requested.
    gradient_uid: A unique identifier within the graph indicating which
      invocation of gradients is being executed. Used to cluster ops for
      compilation.
    loop_state: An object of type ControlFlowState that maintains the state
      of the while loops in the graph, or None if the graph contains no
      while loops.
    aggregation_method: The method used to combine gradient terms. Accepted
      values are constants defined in the class `AggregationMethod`; `None`
      selects `AggregationMethod.DEFAULT`.

  Returns:
    A list of gradients, one per output of `op`. An entry that was recorded
    as a list of terms is replaced by its aggregate, and an empty entry is
    replaced by None.

  Raises:
    TypeError: if the incoming grads are not Tensors or IndexedSlices.
    ValueError: if `aggregation_method` is not an accepted value.
  """
  if aggregation_method is None:
    aggregation_method = AggregationMethod.DEFAULT

  # Methods that are implemented below as a chain of pairwise sums.
  pairwise_methods = (
      AggregationMethod.EXPERIMENTAL_TREE,
      AggregationMethod.EXPERIMENTAL_ACCUMULATE_N,
  )
  accepted_methods = (AggregationMethod.ADD_N,) + pairwise_methods
  if aggregation_method not in accepted_methods:
    raise ValueError("Invalid aggregation_method specified %s." %
                     aggregation_method)

  gradient_types = (ops.Tensor, ops.IndexedSlices)
  out_grads = _GetGrads(grads, op)
  for index, terms in enumerate(out_grads):
    # With loop state present, an entry that is already a single gradient
    # (rather than a list of terms) is left untouched; `op` is expected to
    # be a loop switch in that case.
    if loop_state and isinstance(terms, gradient_types):
      assert control_flow_util.IsLoopSwitch(op)
      continue

    # Every term that is not None must be a Tensor or an IndexedSlices.
    if isinstance(terms, collections_abc.Sequence) and any(
        term is not None and not isinstance(term, gradient_types)
        for term in terms):
      raise TypeError("gradients have to be either all Tensors "
                      "or all IndexedSlices")

    if not terms:
      # Nothing was recorded for this output, so its aggregate is None.
      out_grads[index] = None
      continue

    if len(terms) < 2:
      # A lone term needs no aggregation.
      strategy = "nop"
      out_grads[index] = terms[0]
      continue

    only_dense = all(
        isinstance(term, ops.Tensor) for term in terms if term is not None)
    if not only_dense:
      out_grads[index] = backprop.aggregate_indexed_slices_gradients(terms)
      continue

    accumulator_shape = _AccumulatorShape(terms)
    if aggregation_method in pairwise_methods:
      # Summing pairwise may reduce performance, but it can improve memory
      # use because the gradients can be released earlier.
      #
      # TODO(vrv): Consider replacing this with a version of tf.AddN() that
      # eagerly frees its inputs as soon as they are ready, so the order of
      # this tree does not become a problem.
      strategy = "tree"
      with ops.name_scope(op.name + "_gradient_sum"):
        partial_sum = terms[0]
        for term in terms[1:]:
          partial_sum = math_ops.add_n([partial_sum, term])
      out_grads[index] = partial_sum
    else:
      strategy = "add_n"
      out_grads[index] = _MultiDeviceAddN(terms, gradient_uid)
    logging.vlog(2, " _AggregatedGrads %d x %s using %s", len(terms),
                 accumulator_shape, strategy)
  return out_grads