def diagonal(a, offset=0, axis1=0, axis2=1):  # pylint: disable=missing-docstring
    """Return the `offset`-th diagonal of the planes spanned by two axes of `a`.

    When the static rank is known, `offset` is 0 and (`axis1`, `axis2`)
    already name the last two axes, `tf.linalg.diag_part` is applied
    directly. Otherwise the two axes are first moved to the end, and an
    out-of-range `offset` is handled by substituting a zero-width matrix
    together with offset 0.

    Args:
      a: Input; converted with `asarray`.
      offset: Index of the diagonal, forwarded as `k` to
        `tf.linalg.diag_part`.
      axis1: First axis of the planes the diagonal is taken from.
      axis2: Second axis of the planes the diagonal is taken from.

    Returns:
      The extracted diagonal, wrapped with `utils.tensor_to_ndarray`.
    """
    tensor = asarray(a).data
    static_rank = tensor.shape.rank

    # Fast path: main diagonal of the two innermost axes, no transpose needed.
    is_innermost_main_diagonal = (
        static_rank is not None
        and offset == 0
        and (axis1 == static_rank - 2 or axis1 == -2)
        and (axis2 == static_rank - 1 or axis2 == -1)
    )
    if is_innermost_main_diagonal:
        return utils.tensor_to_ndarray(tf.linalg.diag_part(tensor))

    moved = moveaxis(
        utils.tensor_to_ndarray(tensor), (axis1, axis2), (-2, -1)).data
    moved_shape = tf.shape(moved)

    def _zero_width():
        # Same leading dimensions as `moved`, but a last axis of length 0,
        # paired with offset 0.
        empty_shape = tf.concat([moved_shape[:-1], [0]], 0)
        return (tf.zeros(empty_shape, dtype=moved.dtype), 0)

    def _as_given():
        return (moved, offset)

    # All zeros since diag_part doesn't handle all possible k (aka offset).
    # Written this way since cond will run shape inference on both branches,
    # and diag_part shape inference will fail when offset is out of bounds.
    below_range = utils.less_equal(
        offset, -1 * utils.getitem(moved_shape, -2))
    above_range = utils.greater_equal(offset, utils.getitem(moved_shape, -1))
    matrix, diag_index = utils.cond(
        utils.logical_or(below_range, above_range), _zero_width, _as_given)

    return utils.tensor_to_ndarray(tf.linalg.diag_part(matrix, k=diag_index))
def diag(v, k=0):  # pylint: disable=missing-docstring
    """Raises an error if input is not 1- or 2-d.

    A rank-1 input is placed on the `k`-th diagonal of a new matrix via
    `tf.linalg.diag`; otherwise the `k`-th diagonal of the input is
    extracted via `tf.linalg.diag_part`. The result is wrapped with
    `utils.tensor_to_ndarray`.
    """
    tensor = asarray(v).data
    dynamic_rank = tf.rank(tensor)

    # Static check on the known shape (rank must be at most 2).
    tensor.shape.with_rank_at_most(2)

    # TODO(nareshmodi): Consider a utils.Assert version that will fail during
    # tracing time if the shape is known.
    rank_is_one_or_two = utils.logical_or(
        tf.equal(dynamic_rank, 1), tf.equal(dynamic_rank, 2))
    tf.debugging.Assert(rank_is_one_or_two, [dynamic_rank])

    def _vector_to_matrix():
        # An empty vector yields an all-zero |k| x |k| matrix.
        def _zero_square():
            return tf.zeros([abs(k), abs(k)], dtype=tensor.dtype)

        def _place_on_diagonal():
            return tf.linalg.diag(tensor, k=k)

        return utils.cond(
            tf.equal(tf.size(tensor), 0), _zero_square, _place_on_diagonal)

    def _matrix_to_vector():
        dims = tf.shape(tensor)
        below_range = utils.less_equal(k, -1 * utils.getitem(dims, 0))
        above_range = utils.greater_equal(k, utils.getitem(dims, 1))

        def _empty_matrix():
            # Out-of-range k: use a 0x0 matrix with offset 0 instead.
            return (tf.zeros([0, 0], dtype=tensor.dtype), 0)

        def _as_given():
            return (tensor, k)

        matrix, diag_index = utils.cond(
            utils.logical_or(below_range, above_range),
            _empty_matrix,
            _as_given)
        return tf.linalg.diag_part(matrix, k=diag_index)

    result = utils.cond(
        tf.equal(dynamic_rank, 1), _vector_to_matrix, _matrix_to_vector)
    return utils.tensor_to_ndarray(result)
def f(a, b):  # pylint: disable=missing-docstring
    """Multiply or contract `a` and `b` depending on their ranks.

    * If either operand has rank 0, returns `a * b`.
    * Else if `b` has rank 1, contracts the last axis of `a` with the last
      axis of `b`.
    * Otherwise contracts the last axis of `a` with the second-to-last axis
      of `b`.
    """

    def _scalar_product():
        return a * b

    def _contract_with_last_axis():
        return tf.tensordot(a, b, axes=[[-1], [-1]])

    def _contract_with_second_last_axis():
        return tf.tensordot(a, b, axes=[[-1], [-2]])

    def _tensor_product():
        return utils.cond(
            tf.rank(b) == 1,
            _contract_with_last_axis,
            _contract_with_second_last_axis)

    has_scalar_operand = utils.logical_or(tf.rank(a) == 0, tf.rank(b) == 0)
    return utils.cond(has_scalar_operand, _scalar_product, _tensor_product)
def _diag_part(v, k):
    """Extract the `k`-th diagonal of `v`, tolerating an out-of-range `k`.

    When `k` is at or beyond either edge of the matrix (`k <= -shape[0]` or
    `k >= shape[1]`), a 0x0 matrix with offset 0 is passed to
    `tf.linalg.diag_part` instead of `v` and `k`.

    Args:
      v: Tensor whose shape is read at indices 0 and 1 (presumably 2-d;
        confirm against the caller).
      k: Diagonal offset.

    Returns:
      The result of `tf.linalg.diag_part`.
    """
    dims = tf.shape(v)
    below_range = utils.less_equal(k, -1 * utils.getitem(dims, 0))
    above_range = utils.greater_equal(k, utils.getitem(dims, 1))

    def _empty_matrix():
        return (tf.zeros([0, 0], dtype=v.dtype), 0)

    def _as_given():
        return (v, k)

    matrix, diag_index = utils.cond(
        utils.logical_or(below_range, above_range), _empty_matrix, _as_given)
    return tf.linalg.diag_part(matrix, k=diag_index)
def f(a, b):
    """Multiply `a` and `b`, or contract them over their last axes.

    If either operand has rank 0 the result is `a * b`; otherwise the last
    axis of `a` is contracted with the last axis of `b` via `tf.tensordot`.
    """

    def _scalar_product():
        return a * b

    def _contract_last_axes():
        return tf.tensordot(a, b, axes=[[-1], [-1]])

    has_scalar_operand = utils.logical_or(tf.rank(a) == 0, tf.rank(b) == 0)
    return utils.cond(has_scalar_operand, _scalar_product, _contract_last_axes)