def _unblockify_then_matricize(self, vec):
    """Flatten the block dimensions then reshape to a batch matrix.

    Worked example, with `vec` of shape [v0, v1, v2, v3] and
    `self.block_depth == 2`:
      * leading shape is [v0, v1], block shape is [v2, v3];
      * un-blockify: reshape [v0, v1, v2, v3] --> [v0, v1, v2*v3];
      * matricize: rotate [v0, v1, v2*v3] --> [v1, v2*v3, v0], which
        represents a shape [v1] batch of [v2*v3, v0] matrices.

    Args:
      vec: Tensor whose trailing `self.block_depth` dimensions are the block
        dimensions to be flattened.

    Returns:
      `vec` with its block dimensions flattened into one axis and its first
      axis rotated to the end (`rotate_transpose(..., shift=-1)`).
    """
    depth = self.block_depth
    static_shape = tensor_shape.TensorShape(vec.shape)

    if not static_shape.is_fully_defined():
        # Shape only known at run time: assemble the flat shape with ops.
        dynamic_shape = array_ops.shape(vec)
        leading = dynamic_shape[:-depth]
        block = dynamic_shape[-depth:]
        flat_shape = array_ops.concat(
            (leading, [math_ops.reduce_prod(block)]), 0)
    else:
        # Shape fully known: assemble the flat shape as a Python list.
        dims = static_shape.as_list()
        flat_shape = dims[:-depth] + [np.prod(dims[-depth:])]

    # Un-blockify, then move the first axis to the last position.
    flattened = array_ops.reshape(vec, flat_shape)
    return distribution_util.rotate_transpose(flattened, shift=-1)
def _shape_tensor(self):
    """Return the operator shape computed from `self._spectrum` at run time.

    The shape of `self._spectrum` is split at `-self.block_depth`:
    the leading dimensions become the batch shape, and the trailing
    `self.block_depth` dimensions are multiplied together to give `n`.
    The result is `batch_shape` concatenated with `[n, n]`.
    """
    spectrum_shape = array_ops.shape(self._spectrum)
    depth = self.block_depth
    batch_shape = spectrum_shape[:-depth]
    n = math_ops.reduce_prod(spectrum_shape[-depth:])
    return array_ops.concat((batch_shape, [n, n]), 0)
def _shape_tensor(self, spectrum=None):
    """Return the operator shape implied by a spectrum.

    Args:
      spectrum: Optional spectrum to derive the shape from; when `None`,
        `self.spectrum` is used.

    Returns:
      `batch_shape` concatenated with `[n, n]`. Here `batch_shape` is every
      dimension of the spectrum's shape except the last `self.block_depth`,
      and `n` is the product of those last `self.block_depth` dimensions.
      `prefer_static` is used for the shape and the concat, so the result may
      be static when the input shape is known (NOTE(review): depends on
      `prefer_static` semantics — confirm).
    """
    if spectrum is None:
        spectrum = self.spectrum
    spectrum_shape = prefer_static.shape(spectrum)
    depth = self.block_depth
    batch_shape = spectrum_shape[:-depth]
    n = math_ops.reduce_prod(spectrum_shape[-depth:])
    return prefer_static.concat((batch_shape, [n, n]), 0)
def _determinant(self):
    """Return the product of `self.spectrum` over its block axes.

    The product is taken over the last `self.block_depth` axes (-1, -2, ...)
    of `self.spectrum` and is then cast to `self.dtype`.
    """
    # Axes -1, -2, ..., -block_depth.
    block_axes = list(range(-1, -self.block_depth - 1, -1))
    det = math_ops.reduce_prod(self.spectrum, axis=block_axes)
    # NOTE(review): if self.dtype is real while the spectrum is complex, this
    # cast presumably keeps only the real part — confirm against `_ops.cast`.
    return _ops.cast(det, self.dtype)
def _determinant(self):
    """Return the product of the entries of `self._get_diag()` along its last axis."""
    diagonal = self._get_diag()
    return math_ops.reduce_prod(diagonal, axis=[-1])