def _shape(self):
    """Static `TensorShape`: optional batch dims followed by [num_rows, num_columns]."""
    matrix_shape = tensor_shape.TensorShape(
        (self._num_rows_static, self._num_columns_static))
    # Without an explicit batch-shape argument the operator is a single matrix.
    if self._batch_shape_arg is None:
        return matrix_shape
    leading = tensor_shape.TensorShape(self._batch_shape_static)
    return leading.concatenate(matrix_shape)
def _shape(self):
    """Static shape: spectrum's batch dims plus a square [N, N] trailer.

    With spectrum shape [a, b, c, d] and block_depth = 2 the batch shape is
    [a, b], N = c*d, and the result is [a, b, c*d, c*d].
    """
    spectrum_shape = _ops.TensorShape(self._spectrum.shape)
    batch_shape = spectrum_shape[:-self.block_depth]
    # The trailing `block_depth` dims (e.g. [c, d]) flatten into one side of
    # the square matrix; unknown dims leave the side unknown (None).
    block_dims = spectrum_shape[-self.block_depth:]
    side = np.prod(block_dims.as_list()) if block_dims.is_fully_defined() else None
    return batch_shape.concatenate(tensor_shape.TensorShape([side, side]))
def _shape(self):
    """Static shape of the combined operator.

    Sums the constituent operators' domain and range dimensions and
    broadcasts their batch shapes (broadcast_shape also validates
    compatibility).
    """
    # Get final matrix shape.
    domain_dimension = self.operators[0].domain_dimension
    range_dimension = self.operators[0].range_dimension
    for operator in self.operators[1:]:
        domain_dimension += operator.domain_dimension
        range_dimension += operator.range_dimension
    # BUG FIX: a LinearOperator's matrix shape is [rows, cols] =
    # [range_dimension, domain_dimension] (it maps R^domain -> R^range).
    # The original [domain, range] order was transposed — harmless only
    # when the operator happens to be square.
    matrix_shape = tensor_shape.TensorShape(
        [range_dimension, domain_dimension])
    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape
    for operator in self.operators[1:]:
        batch_shape = common_shapes.broadcast_shape(
            batch_shape, operator.batch_shape)
    return batch_shape.concatenate(matrix_shape)
def _shape(self):
    """Static shape: multiplier's batch dims followed by a square [N, N] block."""
    n = self._num_rows_static
    square = tensor_shape.TensorShape((n, n))
    # The multiplier's full shape acts as the batch shape of this operator.
    leading = _ops.TensorShape(self.multiplier.shape)
    return leading.concatenate(square)