def range_dimension(self):
  """Return the range dimension of this operator.

  "Dimension" here is meant in the vector-space sense. For an operator that
  behaves like a batch matrix `A` whose static shape is
  `[B1, ..., Bb, M, N]`, the range dimension is `M`, the second-to-last
  entry of that shape.

  Returns:
    A `Dimension` object. `Dimension(None)` is returned when
    `TensorShape(self.shape).dims` is falsy (e.g. `None` or an empty list).
  """
  # Subclasses get this for free once they implement `.shape`.
  static_dims = tensor_shape.TensorShape(self.shape).dims
  if not static_dims:
    return tensor_shape.Dimension(None)
  return static_dims[-2]
def _check_shapes(self):
  """Check, using static shape information only, that component shapes agree.

  Checks performed, in this order:
    1. The shapes of `u` and `v` broadcast together. This also validates
       that `u` and `v` are compatible with each other.
    2. The base operator's `batch_shape` broadcasts with the `u`/`v` shape
       minus its last two entries.
    3. The base operator's `domain_dimension` is compatible with entry -2
       of the broadcast `u`/`v` shape.
    4. Only when `self._diag_update` is not None: its last shape entry is
       compatible with the last entry of the `u`/`v` shape, and its leading
       entries broadcast with the batch shape from step 2.

  Returns None. Any incompatibility surfaces as the exception raised by the
  underlying `broadcast_static_shape` / `assert_is_compatible_with` call
  (presumably `ValueError`; confirm against `_ops` / `tensor_shape`).
  """
  u_static = tensor_shape.TensorShape(self.u.shape)
  v_static = tensor_shape.TensorShape(self.v.shape)
  # Broadcasting u against v doubles as the u/v compatibility check.
  uv_static = _ops.broadcast_static_shape(u_static, v_static)

  base_op = self.base_operator
  joint_batch = _ops.broadcast_static_shape(
      base_op.batch_shape, uv_static[:-2])

  domain_dim = tensor_shape.Dimension(base_op.domain_dimension)
  domain_dim.assert_is_compatible_with(uv_static[-2])

  diag = self._diag_update
  if diag is None:
    return

  uv_last = tensor_shape.dimension_at_index(uv_static, -1)
  diag_static = tensor_shape.TensorShape(diag.shape)
  uv_last.assert_is_compatible_with(diag_static[-1])
  # Result intentionally discarded: the call is made only for its check.
  _ops.broadcast_static_shape(joint_batch, diag_static[:-1])