def _SumGrad(op, grad):
  """Gradient for Sum.

  The incoming gradient is reshaped to the keep-dims form of the output and
  then tiled back up to the shape of the summed input.

  Args:
    op: The Sum op. `op.inputs[0]` is the tensor that was reduced and
      `op.inputs[1]` holds the reduction axes.
    grad: Gradient with respect to the output of `op`.

  Returns:
    A two-element list: the gradient for the reduced input, and `None` for
    the axes input.
  """
  summed = op.inputs[0]
  axes = op.inputs[1]
  full_shape = array_ops.shape(summed)
  kept_dims_shape = math_ops.reduced_shape(full_shape, axes)
  # Per-dimension repeat counts needed to grow the keep-dims shape back to
  # the full input shape.
  multiples = _safe_shape_div(full_shape, kept_dims_shape)
  reshaped_grad = array_ops.reshape(grad, kept_dims_shape)
  return [array_ops.tile(reshaped_grad, multiples), None]
def _SumGrad(op, grad):
  """Gradient for Sum.

  Args:
    op: The Sum op. `op.inputs[0]` is the tensor that was reduced and
      `op.inputs[1]` holds the reduction axes.
    grad: Gradient with respect to the output of `op`.

  Returns:
    A two-element list: the gradient for the reduced input, and `None` for
    the axes input.
  """
  # Fast path for a full reduction to a scalar when the rank is statically
  # known: only Reshape and Tile ops (and possibly a Shape) are added.
  static_shape = op.inputs[0]._shape_tuple()  # pylint: disable=protected-access
  static_axes = None
  if static_shape is not None:
    static_axes = tensor_util.constant_value(op.inputs[1])
  if static_axes is not None:
    rank = len(static_shape)
    if np.array_equal(static_axes, np.arange(rank)):  # Reduce all dims.
      grad = array_ops.reshape(grad, [1] * rank)
      # When the rank is known but some dimension is not, fall back to a
      # Shape op for the tile multiples.
      if None in static_shape:
        multiples = array_ops.shape(op.inputs[0])
      else:
        multiples = static_shape
      return [array_ops.tile(grad, multiples), None]

  # General path.
  dynamic_shape = array_ops.shape(op.inputs[0])
  # TODO(apassos) remove this once device placement for eager ops makes more
  # sense.
  with ops.colocate_with(dynamic_shape):
    kept_dims_shape = math_ops.reduced_shape(dynamic_shape, op.inputs[1])
    multiples = _safe_shape_div(dynamic_shape, kept_dims_shape)
  reshaped_grad = array_ops.reshape(grad, kept_dims_shape)
  return [array_ops.tile(reshaped_grad, multiples), None]
def _ProdGrad(op, grad):
  """Gradient for Prod.

  Scales the incoming gradient by the product output, broadcasts it back to
  the input shape by tiling, and divides elementwise by the input.

  Args:
    op: The Prod op. `op.inputs[0]` is the reduced tensor, `op.inputs[1]` the
      reduction axes, and `op.outputs[0]` the computed product.
    grad: Gradient with respect to the output of `op`.

  Returns:
    A tuple `(input_gradient, None)`; the axes input gets no gradient.
  """
  # TODO(kearnes): this gives NaNs for 0s in the input tensor
  factors = op.inputs[0]
  full_shape = array_ops.shape(factors)
  kept_dims_shape = math_ops.reduced_shape(full_shape, op.inputs[1])
  multiples = _safe_shape_div(full_shape, kept_dims_shape)
  scaled_grad = array_ops.reshape(grad * op.outputs[0], kept_dims_shape)
  tiled_grad = array_ops.tile(scaled_grad, multiples)
  return math_ops.div(tiled_grad, factors), None
def _SparseReduceSumGrad(op, out_grad):
  """Similar to gradient for the Sum Op (i.e. tf.reduce_sum()).

  Args:
    op: The sparse reduce-sum op, whose inputs are
      (sparse_indices, sparse_values, sparse_shape, reduction_axes).
    out_grad: Gradient with respect to the output of `op`.

  Returns:
    A 4-tuple aligned with the op inputs; only the sparse values (position 1)
    receive a gradient, the other entries are `None`.
  """
  indices = op.inputs[0]
  dense_shape = op.inputs[2]
  axes = op.inputs[3]
  kept_dims_shape = math_ops.reduced_shape(dense_shape, axes)
  grad_kept_dims = array_ops.reshape(out_grad, kept_dims_shape)
  # Integer-dividing an input index by `scale` maps it onto the matching
  # position in the keep-dims gradient.
  scale = dense_shape // math_ops.to_int64(kept_dims_shape)
  values_grad = array_ops.gather_nd(grad_kept_dims, indices // scale)
  # (sparse_indices, sparse_values, sparse_shape, reduction_axes)
  return (None, values_grad, None, None)
def _MinOrMaxGrad(op, grad):
  """Gradient for Min or Max. Amazingly it's precisely the same code.

  Args:
    op: The Min or Max op. `op.inputs[0]` is the reduced tensor,
      `op.inputs[1]` the reduction axes, and `op.outputs[0]` the extremum.
    grad: Gradient with respect to the output of `op`.

  Returns:
    A two-element list: the gradient for the reduced input, and `None` for
    the axes input.
  """
  values = op.inputs[0]
  axes = op.inputs[1]
  kept_dims_shape = math_ops.reduced_shape(array_ops.shape(values), axes)
  extremum = array_ops.reshape(op.outputs[0], kept_dims_shape)
  grad = array_ops.reshape(grad, kept_dims_shape)

  # Mark every element equal to the selected extremum. When several elements
  # tie within a reduction slice, the gradient is split evenly among them.
  is_selected = math_ops.cast(math_ops.equal(extremum, values), grad.dtype)
  tie_count = array_ops.reshape(
      math_ops.reduce_sum(is_selected, axes), kept_dims_shape)

  return [math_ops.div(is_selected, tie_count) * grad, None]
def _SecureMinOrMaxGrad(op, grad):
  """Gradient for SecureMin or SecureMax.

  Follows the same recipe as the plain Min/Max gradient, but builds the
  indicator, count, division and product with the `Secure*` helpers.

  Args:
    op: The reduction op. `op.inputs[0]` is the reduced tensor,
      `op.inputs[1]` the reduction axes, and `op.outputs[0]` the extremum.
    grad: Gradient with respect to the output of `op`.

  Returns:
    A two-element list: the gradient for the reduced input, and `None` for
    the axes input.
  """
  values = op.inputs[0]
  axes = op.inputs[1]
  kept_dims_shape = math_ops.reduced_shape(array_ops.shape(values), axes)
  extremum = array_ops.reshape(op.outputs[0], kept_dims_shape)
  grad = array_ops.reshape(grad, kept_dims_shape)

  # Mark every element equal to the selected extremum. When several elements
  # tie within a reduction slice, the gradient is split evenly among them.
  is_selected = SecureEqual(extremum, values)
  tie_count = array_ops.reshape(SecureSum(is_selected, axes), kept_dims_shape)
  share = SecureTruediv(is_selected, tie_count)

  return [SecureMul(share, grad), None]
def _MinOrMaxGrad(op, grad):
  """Gradient for Min or Max. Amazingly it's precisely the same code.

  Args:
    op: The Min or Max op. `op.inputs[0]` is the reduced tensor,
      `op.inputs[1]` the reduction axes, and `op.outputs[0]` the extremum.
    grad: Gradient with respect to the output of `op`.

  Returns:
    A two-element list: the gradient for the reduced input, and `None` for
    the axes input.
  """
  reduced_input, reduction_axes = op.inputs[0], op.inputs[1]
  shape_with_ones = math_ops.reduced_shape(
      array_ops.shape(reduced_input), reduction_axes)
  result = array_ops.reshape(op.outputs[0], shape_with_ones)
  upstream = array_ops.reshape(grad, shape_with_ones)

  # Count how many elements achieve the min/max in each reduction slice;
  # ties share the gradient equally.
  hits = math_ops.cast(math_ops.equal(result, reduced_input), upstream.dtype)
  hits_per_slice = array_ops.reshape(
      math_ops.reduce_sum(hits, reduction_axes), shape_with_ones)

  input_grad = math_ops.div(hits, hits_per_slice) * upstream
  return [input_grad, None]
def _ProdGrad(op, grad):
  """Gradient for Prod.

  For every input element the gradient is the product of all *other* elements
  in its reduction slice, times the incoming gradient. The "all other
  elements" product is built from two exclusive cumulative products so that
  zeros in the input do not produce NaNs.

  Args:
    op: The Prod op. `op.inputs[0]` is the reduced tensor and `op.inputs[1]`
      the reduction axes (scalar or vector).
    grad: Gradient with respect to the output of `op`.

  Returns:
    A tuple `(input_gradient, None)`; the axes input gets no gradient.
  """
  # The gradient can be expressed by dividing the product by each entry of the
  # input tensor, but this approach can't deal with zeros in the input.
  # Here, we avoid this problem by composing the output as a product of two
  # cumprod operations.

  input_shape = array_ops.shape(op.inputs[0])
  # Reshape reduction indices for the case where the parameter is a scalar
  reduction_indices = array_ops.reshape(op.inputs[1], [-1])

  # Expand grad to full input shape
  output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
  tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
  grad = array_ops.reshape(grad, output_shape_kept_dims)
  grad = array_ops.tile(grad, tile_scaling)

  # Pack all reduced dimensions into a single one, so we can perform the
  # cumprod ops. We put all the shape-related ops on CPU to avoid copying back
  # and forth, and since listdiff is CPU only.
  with ops.device("/cpu:0"):
    rank = array_ops.rank(op.inputs[0])
    # Cast *before* combining with `rank`: the reduction indices may arrive
    # with a dtype other than int32 (the empty-list case defaults to float32),
    # and adding them to `rank` uncast would mix dtypes.
    reduced = math_ops.cast(reduction_indices, dtypes.int32)
    # Wrap negative axes into [0, rank).
    reduced = (reduced + rank) % rank
    idx = math_ops.range(0, rank)
    other, _ = array_ops.setdiff1d(idx, reduced)
    perm = array_ops.concat([reduced, other], 0)
    reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced))
    other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other))
  # Move the reduced axes to the front and flatten to
  # (reduced_num, other_num) so a single cumprod along axis 0 covers every
  # reduction slice.
  permuted = array_ops.transpose(op.inputs[0], perm)
  permuted_shape = array_ops.shape(permuted)
  reshaped = array_ops.reshape(permuted, (reduced_num, other_num))

  # Calculate product, leaving out the current entry
  left = math_ops.cumprod(reshaped, axis=0, exclusive=True)
  right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True)
  # For complex inputs, the gradient is in the conjugate direction.
  y = array_ops.reshape(
      math_ops.conj(left) * math_ops.conj(right), permuted_shape)

  # Invert the transpose and reshape operations.
  # Make sure to set the statically known shape information through a reshape.
  out = grad * array_ops.transpose(y, array_ops.invert_permutation(perm))
  return array_ops.reshape(out, input_shape), None
def _ProdGrad(op, grad):
  """Gradient for Prod.

  The gradient of a product with respect to one factor is the product of the
  remaining factors in the same reduction slice. That quantity is assembled
  from a forward and a reverse exclusive cumprod, which stays finite when the
  input contains zeros.

  Args:
    op: The Prod op. `op.inputs[0]` is the reduced tensor and `op.inputs[1]`
      the reduction axes (scalar or vector).
    grad: Gradient with respect to the output of `op`.

  Returns:
    A tuple `(input_gradient, None)`; the axes input gets no gradient.
  """
  # The gradient can be expressed by dividing the product by each entry of the
  # input tensor, but this approach can't deal with zeros in the input.
  # Here, we avoid this problem by composing the output as a product of two
  # cumprod operations.

  input_shape = array_ops.shape(op.inputs[0])
  # Reshape reduction indices for the case where the parameter is a scalar
  reduction_indices = array_ops.reshape(op.inputs[1], [-1])

  # Expand grad to full input shape
  output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
  tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
  grad = array_ops.reshape(grad, output_shape_kept_dims)
  grad = array_ops.tile(grad, tile_scaling)

  # Pack all reduced dimensions into a single one, so we can perform the
  # cumprod ops. We put all the shape-related ops on CPU to avoid copying back
  # and forth, and since listdiff is CPU only.
  with ops.device("/cpu:0"):
    rank = array_ops.rank(op.inputs[0])
    # If the reduction dims list is empty, it defaults to float32, and in
    # general the indices need not share `rank`'s dtype, so cast first and
    # only then normalize negative axes into [0, rank).
    reduced = math_ops.cast(reduction_indices, dtypes.int32)
    reduced = (reduced + rank) % rank
    idx = math_ops.range(0, rank)
    # `other` holds the axes that are not reduced.
    other, _ = array_ops.setdiff1d(idx, reduced)
    perm = array_ops.concat([reduced, other], 0)
    reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced))
    other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other))
  # Reduced axes first, then everything else, flattened to two dimensions.
  permuted = array_ops.transpose(op.inputs[0], perm)
  permuted_shape = array_ops.shape(permuted)
  reshaped = array_ops.reshape(permuted, (reduced_num, other_num))

  # Calculate product, leaving out the current entry
  left = math_ops.cumprod(reshaped, axis=0, exclusive=True)
  right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True)
  # For complex inputs, the gradient is in the conjugate direction.
  y = array_ops.reshape(
      math_ops.conj(left) * math_ops.conj(right), permuted_shape)

  # Invert the transpose and reshape operations.
  # Make sure to set the statically known shape information through a reshape.
  out = grad * array_ops.transpose(y, array_ops.invert_permutation(perm))
  return array_ops.reshape(out, input_shape), None
def _SumGrad(op, grad):
  """Gradient for Sum.

  Args:
    op: The Sum op. `op.inputs[0]` is the tensor that was reduced and
      `op.inputs[1]` holds the reduction axes.
    grad: Gradient with respect to the output of `op`.

  Returns:
    A two-element list: the gradient for the reduced input, and `None` for
    the axes input.
  """
  summed = op.inputs[0]
  axes_tensor = op.inputs[1]
  static_shape = summed.get_shape()

  # Fast path for when reducing to a scalar and ndims is known: adds only
  # Reshape and Tile ops (and possibly a Shape).
  if static_shape.ndims is not None and axes_tensor.op.type == "Const":
    rank = static_shape.ndims
    axes = tensor_util.MakeNdarray(axes_tensor.op.get_attr("value"))
    if np.array_equal(axes, np.arange(rank)):  # Reduce all dims.
      grad = array_ops.reshape(grad, [1] * rank)
      # If shape is not fully defined (but rank is), we use Shape.
      if static_shape.is_fully_defined():
        multiples = static_shape.as_list()
      else:
        multiples = array_ops.shape(summed)
      return [array_ops.tile(grad, multiples), None]

  # General path: reshape to the keep-dims shape, then tile up to the input.
  full_shape = array_ops.shape(summed)
  kept_dims_shape = math_ops.reduced_shape(full_shape, axes_tensor)
  multiples = _safe_shape_div(full_shape, kept_dims_shape)
  reshaped_grad = array_ops.reshape(grad, kept_dims_shape)
  return [array_ops.tile(reshaped_grad, multiples), None]
def _SumGrad(op, grad):
  """Gradient for Sum.

  Args:
    op: The Sum op. `op.inputs[0]` is the tensor that was reduced and
      `op.inputs[1]` holds the reduction axes.
    grad: Gradient with respect to the output of `op`.

  Returns:
    A two-element list: the gradient for the reduced input, and `None` for
    the axes input.
  """
  data, reduction_axes = op.inputs[0], op.inputs[1]
  known_shape = data.get_shape()
  rank_is_known = known_shape.ndims is not None

  # Fast path for a reduction over every axis with statically known rank and
  # constant axes: only Reshape and Tile (and possibly Shape) ops are needed.
  if rank_is_known and reduction_axes.op.type == "Const":
    ndims = known_shape.ndims
    const_axes = tensor_util.MakeNdarray(reduction_axes.op.get_attr("value"))
    reduces_all_dims = np.array_equal(const_axes, np.arange(ndims))
    if reduces_all_dims:
      grad = array_ops.reshape(grad, [1] * ndims)
      # With a fully defined static shape the multiples are plain Python
      # ints; otherwise a Shape op supplies them at run time.
      tile_to = (
          known_shape.as_list() if known_shape.is_fully_defined()
          else array_ops.shape(data))
      return [array_ops.tile(grad, tile_to), None]

  runtime_shape = array_ops.shape(data)
  shape_with_ones = math_ops.reduced_shape(runtime_shape, reduction_axes)
  repeats = _safe_shape_div(runtime_shape, shape_with_ones)
  grad = array_ops.reshape(grad, shape_with_ones)
  return [array_ops.tile(grad, repeats), None]
def _SparseReduceMinOrMaxGrad(op, out_grad):
  """Gradient for a sparse min/max reduction.

  Each stored value that equals the reduced result of its slice receives an
  equal share of that slice's incoming gradient; all other values get zero.

  Args:
    op: The reduction op, whose inputs are
      (input_indices, input_values, input_shape, reduction_axes) and whose
      first output is the reduced result.
    out_grad: Gradient with respect to the output of `op`.

  Returns:
    A four-element list aligned with the op inputs; only the values
    (position 1) receive a gradient, the other entries are `None`.
  """
  indices = op.inputs[0]
  values = op.inputs[1]
  dense_shape = op.inputs[2]
  axes = op.inputs[3]
  reduced = op.outputs[0]

  # Handle keepdims
  kept_dims_shape = math_ops.reduced_shape(dense_shape, axes)
  out_grad = array_ops.reshape(out_grad, kept_dims_shape)
  reduced = array_ops.reshape(reduced, kept_dims_shape)

  # Map input and output coefficients: integer-dividing an input index by
  # `scale` yields the matching index into the keep-dims output.
  scale = dense_shape // math_ops.to_int64(kept_dims_shape)
  output_indices = indices // scale

  # Map pooled values with corresponding max/min values
  selected_value = array_ops.gather_nd(reduced, output_indices)
  is_selected = math_ops.cast(
      math_ops.equal(values, selected_value), out_grad.dtype)
  grad_per_value = array_ops.gather_nd(out_grad, output_indices)

  # Compute the number of selected (maximum or minimum) elements in each
  # reduction dimension. If there are multiple minimum or maximum elements
  # then the gradient will be divided between them.
  # (same as for MaxGrad)
  sparse_selected = sparse_tensor.SparseTensor(
      indices, is_selected, dense_shape)
  selected_per_slice = sparse_ops.sparse_reduce_sum(
      sparse_selected, axis=axes, keep_dims=True)
  tie_count = array_ops.gather_nd(selected_per_slice, output_indices)

  # Clamp the divisor to at least 1 before dividing.
  share = math_ops.div(is_selected, math_ops.maximum(tie_count, 1))

  # (input_indices, input_values, input_shape, reduction_axes)
  return [None, share * grad_per_value, None, None]
def _ProdGrad(op, grad):
  """Gradient for Prod.

  For every input element the gradient is the product of all other elements
  in its reduction slice, times the incoming gradient. The product of the
  other elements is built from two exclusive cumulative products so that
  zeros in the input do not produce NaNs.

  Args:
    op: The Prod op. `op.inputs[0]` is the reduced tensor and `op.inputs[1]`
      the reduction axes (scalar or vector; negative axes are wrapped).
    grad: Gradient with respect to the output of `op`.

  Returns:
    A tuple `(input_gradient, None)`; the axes input gets no gradient.
  """
  # The gradient can be expressed by dividing the product by each entry of the
  # input tensor, but this approach can't deal with zeros in the input.
  # Here, we avoid this problem by composing the output as a product of two
  # cumprod operations.

  input_shape = array_ops.shape(op.inputs[0])

  # Expand grad to full input shape
  output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
  tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
  grad = array_ops.reshape(grad, output_shape_kept_dims)
  grad = array_ops.tile(grad, tile_scaling)

  # Pack all reduced dimensions into a single one, so we can perform the
  # cumprod ops. If the reduction dims list is empty, it defaults to float32,
  # so we need to cast here.
  rank = array_ops.rank(op.inputs[0])
  reduced = math_ops.cast(op.inputs[1], dtypes.int32)
  # Flatten so a scalar reduction index can be concatenated with `other`.
  reduced = array_ops.reshape(reduced, [-1])
  # Wrap negative axes into [0, rank); otherwise listdiff would not remove
  # them from `idx` and `perm` would contain out-of-range entries.
  reduced = (reduced + rank) % rank
  idx = math_ops.range(0, rank)
  other, _ = array_ops.listdiff(idx, reduced)
  perm = array_ops.concat(0, [reduced, other])
  reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced))
  other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other))
  # Reduced axes first, then the rest, flattened to
  # (reduced_num, other_num).
  permuted = array_ops.transpose(op.inputs[0], perm)
  permuted_shape = array_ops.shape(permuted)
  reshaped = array_ops.reshape(permuted, (reduced_num, other_num))

  # Calculate product, leaving out the current entry
  left = math_ops.cumprod(reshaped, axis=0, exclusive=True)
  right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True)
  y = array_ops.reshape(left * right, permuted_shape)

  # Invert the transpose and reshape operations.
  # Make sure to set the statically known shape information through a reshape.
  out = grad * array_ops.transpose(y, array_ops.invert_permutation(perm))
  return array_ops.reshape(out, input_shape), None
def _SumGrad(op, grad):
  """Gradient for Sum.

  Args:
    op: The Sum op. `op.inputs[0]` is the tensor that was reduced and
      `op.inputs[1]` holds the reduction axes.
    grad: Gradient with respect to the output of `op`.

  Returns:
    A two-element list: the gradient for the reduced input, and `None` for
    the axes input.
  """
  # Fast path for a full reduction to a scalar when the rank is statically
  # known: only Reshape and Tile ops (and possibly a Shape) are added.
  static_shape = op.inputs[0]._shape_tuple()  # pylint: disable=protected-access
  static_axes = None
  if static_shape is not None:
    static_axes = tensor_util.constant_value(op.inputs[1])
  if static_axes is not None:
    rank = len(static_shape)
    if np.array_equal(static_axes, np.arange(rank)):  # Reduce all dims.
      if not context.executing_eagerly():
        ones_shape = [1] * rank
      else:
        # In eager mode the all-ones shape constant is looked up in (and, on
        # a miss, stored into) the context's per-rank cache.
        ctx = context.context()
        ones_shape = ctx.ones_rank_cache().get(rank)
        if ones_shape is None:
          ones_shape = constant_op.constant([1] * rank, dtype=dtypes.int32)
          ctx.ones_rank_cache().put(rank, ones_shape)
      grad = array_ops.reshape(grad, ones_shape)
      # When the rank is known but some dimension is not, fall back to a
      # Shape op for the tile multiples.
      if None in static_shape:
        multiples = array_ops.shape(op.inputs[0])
      else:
        multiples = constant_op.constant(static_shape, dtype=dtypes.int32)
      return [array_ops.tile(grad, multiples), None]

  # General path.
  dynamic_shape = array_ops.shape(op.inputs[0])
  # TODO(apassos) remove this once device placement for eager ops makes more
  # sense.
  with ops.colocate_with(dynamic_shape):
    kept_dims_shape = math_ops.reduced_shape(dynamic_shape, op.inputs[1])
    multiples = _safe_shape_div(dynamic_shape, kept_dims_shape)
  reshaped_grad = array_ops.reshape(grad, kept_dims_shape)
  return [array_ops.tile(reshaped_grad, multiples), None]
def _check(self, shape, axes, result):
  """Asserts that `math_ops.reduced_shape(shape, axes)` evaluates to `result`.

  Args:
    shape: Shape passed to `reduced_shape`.
    axes: Reduction axes passed to `reduced_shape`.
    result: Expected evaluated value.
  """
  reduced = math_ops.reduced_shape(shape, axes=axes)
  actual = reduced.eval()
  self.assertAllEqual(actual, result)
def _GetGradReduced(output_grad, output_subs, input_subs, input_shape,
                    reduced_label_set):
  """Returns the gradient wrt input for a unary einsum with reductions.

  Args:
    output_grad: The gradient wrt the output of a unary einsum operation.
    output_subs: The output subscript. (E.g. `ac` for equation `abc->ac`).
    input_subs: The input subscript. (E.g. `abc` for equation `abc->ac`).
    input_shape: A `Tensor` representing the shape of the input operand.
    reduced_label_set: The set of axis labels appearing in `input_subs` but
      not in `output_subs`.

  Returns:
    The gradient wrt the input operand: either `output_grad` broadcast to
    `input_shape`, or the result of a follow-up einsum when the equation also
    involves repeated labels or a transpose.
  """
  # Take the einsum "aabbcd->ca" with input_shape [2,2,5,5,3,4] as the running
  # example: labels 'b' and 'd' are reduced, so the reduced subscripts are
  # "bd", their dimensions [5,4] and their axes [2,5].
  reduced_subs, reduced_dims, reduced_axes = _GetReducedSubscripts(
      reduced_label_set, input_shape, input_subs)

  # True when either subscript string repeats a label, e.g. "aabbcd->ca" or
  # "abd->cca", and False for "abcd->ca".
  input_repeats = len(set(input_subs)) < len(input_subs)
  output_repeats = len(set(output_subs)) < len(output_subs)
  has_repeated_labels = input_repeats or output_repeats

  # The input subscripts with reduced labels dropped, e.g. "aac" for
  # "aabbcd->ca".
  kept_input_subs = "".join(
      label for label in input_subs if label not in reduced_label_set)

  # The gradient wrt the input for "abc->ac" (equivalently,
  # reduce_sum(..., axis=1)) is the output gradient tiled N times along
  # axis 1, where label 'b' has size N.
  #
  # Without repeated labels and without a transpose of the surviving labels,
  # tiling alone suffices (e.g. "abcd->ac"). Equations such as "aabbcd->ac"
  # (generalized traces) or "abc->ca" (transpose) need a second einsum after
  # the tiling step.
  needs_only_tiling = (
      not has_repeated_labels and kept_input_subs == output_subs)
  if needs_only_tiling:
    # Shape of the output as if keepdims=True had been used, e.g. [2,1,3,1]
    # for "abcd->ac" with input shape [2,5,3,4].
    keepdims_shape = math_ops.reduced_shape(
        input_shape, ops.convert_to_tensor(reduced_axes))
    # Reshape the gradient (wrt "ac") to [2,1,3,1] and broadcast it to
    # [2,5,3,4] to obtain the gradient wrt "abcd".
    reshaped_grad = array_ops.reshape(output_grad, keepdims_shape)
    return array_ops.broadcast_to(reshaped_grad, input_shape)

  # With traces or transposes, move the reduced dimensions to the front.
  # For "aabbcd->ca" we first form the VJP for "bdca->ca" and then the VJP
  # for "aabbcd->bdca".
  #
  # Shape of the intermediate "bdca": reduced dimensions prepended to the
  # output-gradient shape, viz. [5,4,3,2].
  intermediate_shape = array_ops.concat(
      [reduced_dims, array_ops.shape(output_grad)], axis=0)
  # keepdims=True output shape of the reduction-only equation "bdca->ca",
  # viz. [1,1,3,2]: one leading 1 per reduced label.
  leading_ones = array_ops.ones(len(reduced_label_set), dtype=dtypes.int32)
  keepdims_shape = array_ops.concat(
      [leading_ones, array_ops.shape(output_grad)], axis=0)
  # VJP for the intermediate step ("bdca->ca"); broadcasting is sufficient.
  reshaped_grad = array_ops.reshape(output_grad, keepdims_shape)
  broadcasted_grad = array_ops.broadcast_to(reshaped_grad, intermediate_shape)
  # VJP for the final step ("aabbcd->bdca"): an einsum with the subscripts
  # reversed ("bdca->aabbcd") works because every output label now appears in
  # the input subscripts.
  reverse_equation = "{}->{}".format(reduced_subs + output_subs, input_subs)
  return gen_linalg_ops.einsum([broadcasted_grad], reverse_equation)