def _MulGrad(op, grad):
  """Returns the gradients of a two-operand product w.r.t. each operand.

  Args:
    op: The multiplication op being differentiated. Its two operands are read
      from `op.inputs[0]` and `op.inputs[1]`.
    grad: The incoming gradient with respect to the output of `op`.

  Returns:
    A pair `(grad_wrt_first, grad_wrt_second)`, one entry per operand of `op`.

  Raises:
    AssertionError: On the general path, if the two operands do not share the
      same base dtype.
  """
  lhs, rhs = op.inputs[0], op.inputs[1]

  # pylint: disable=protected-access
  # Fast path: the shapes of both operands and of `grad` are reported as fully
  # specified and equal, so no reduction or reshape is applied. It is limited
  # to int32/float32, which are real-valued, so no conjugation is done here.
  if (isinstance(grad, ops.Tensor)
      and _ShapesFullySpecifiedAndEqual(lhs, rhs, grad)
      and grad.dtype in (dtypes.int32, dtypes.float32)):
    return gen_math_ops._mul(grad, rhs), gen_math_ops._mul(grad, lhs)

  assert lhs.dtype.base_dtype == rhs.dtype.base_dtype, (
      lhs.dtype, " vs. ", rhs.dtype)

  lhs_shape = array_ops.shape(lhs)
  rhs_shape = array_ops.shape(rhs)
  # NOTE(review): presumably these are the axes along which each operand was
  # broadcast to produce the output -- confirm against the op definition.
  lhs_axes, rhs_axes = gen_array_ops._broadcast_gradient_args(
      lhs_shape, rhs_shape)
  # pylint: enable=protected-access

  lhs_conj = math_ops.conj(lhs)
  rhs_conj = math_ops.conj(rhs)

  # Each gradient is the incoming gradient times the conjugate of the *other*
  # operand, summed over that operand's reduction axes and reshaped back to
  # the operand's own runtime shape.
  lhs_contrib = grad * rhs_conj
  lhs_grad = array_ops.reshape(
      math_ops.reduce_sum(lhs_contrib, lhs_axes), lhs_shape)

  rhs_contrib = lhs_conj * grad
  rhs_grad = array_ops.reshape(
      math_ops.reduce_sum(rhs_contrib, rhs_axes), rhs_shape)

  return lhs_grad, rhs_grad