def _concatenate_meta(tensors: Sequence[TensorLikeType], dim: int) -> TensorLikeType:
    """Computes the output metadata for concatenating ``tensors`` along ``dim``.

    The inputs must all have the same dtype and device, and identical shapes
    except (possibly) in the concatenation dimension. The result is a
    contiguous TensorMeta whose ``dim`` length is the sum of the inputs'.
    """
    if not tensors:
        msg = "concatenate expects at least one tensor, but received zero!"
        raise ValueError(msg)

    for t in tensors:
        assert isinstance(t, TensorLike)

    utils.check_same_dtype(*tensors)
    utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)

    reference_shape = tensors[0].shape
    utils.validate_idx(tensors[0].ndim, dim)

    # Accumulate the total length along dim while verifying every other
    # dimension agrees with the first tensor's shape.
    total_length = 0
    for t in tensors:
        for axis, (expected, actual) in enumerate(zip(reference_shape, t.shape)):
            if axis == dim:
                total_length += actual
            else:
                assert actual == expected

    out_shape = list(reference_shape)
    out_shape[dim] = total_length
    return TensorMeta(
        tensors[0],
        shape=out_shape,
        strides=utils.make_contiguous_strides_for(out_shape),
    )
def _split_dim_meta(a: TensorLikeType, dim: int, outer_length: int) -> TensorLikeType:
    """Computes the output metadata for splitting dim ``dim`` of ``a`` in two.

    Dimension ``dim`` (of length N) is replaced by two dimensions of lengths
    ``outer_length`` and ``N // outer_length``; strides are derived so the
    result is a view of ``a``. Raises ValueError when ``outer_length`` does
    not evenly divide the dimension's length.
    """
    assert isinstance(a, TensorLike)
    utils.validate_idx(a.ndim, dim)
    utils.validate_dim_length(outer_length)

    # The split is only valid when outer_length evenly divides the dim.
    quotient = a.shape[dim] / outer_length
    inner_length: int = int(quotient)
    if inner_length != quotient:
        msg = "Attempting to split dimension of length {0}, but outer length of {1} divides it with a remainder!".format(
            a.shape[dim], outer_length)
        raise ValueError(msg)

    strides = a.stride()
    new_shape: List[int] = []
    new_strides: List[int] = []
    for axis, (axis_len, axis_stride) in enumerate(zip(a.shape, strides)):
        if axis == dim:
            # The outer dim strides over whole inner blocks.
            new_shape += [outer_length, inner_length]
            new_strides += [axis_stride * inner_length, axis_stride]
        else:
            new_shape.append(axis_len)
            new_strides.append(axis_stride)

    return TensorMeta(a, shape=new_shape, strides=new_strides)
def _concatenate_meta(tensors: Sequence[TensorLikeType], dim: int) -> TensorLikeType:
    """Computes the output metadata for concatenating ``tensors`` along ``dim``.

    The inputs must all have the same dtype and device, and identical shapes
    except (possibly) in the concatenation dimension. Returns a contiguous
    TensorMeta whose ``dim`` length is the sum of the inputs'.

    Raises:
        ValueError: if ``tensors`` is empty.
    """
    # Fix: raise (not a bare assert, which is stripped under -O) for empty
    # input, matching the sibling implementation of this function.
    if len(tensors) == 0:
        msg = "concatenate expects at least one tensor, but received zero!"
        raise ValueError(msg)

    for tensor in tensors:
        assert isinstance(tensor, TensorLike)

    # Fix: these helpers take the tensors unpacked, and the device check's
    # keyword is allow_cpu_scalar_tensors (see the other call sites in this
    # file); the previous call passed the sequence itself and a wrong kwarg.
    utils.check_same_dtype(*tensors)
    utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)

    shape = tensors[0].shape
    # Fix: validate_idx expects the tensor's rank (an int), not its shape.
    utils.validate_idx(tensors[0].ndim, dim)

    # Verifies same shape (except in the concat dimension)
    concat_length = 0
    for tensor in tensors:
        # Robustness: zip silently truncates, so require matching ranks.
        assert tensor.ndim == tensors[0].ndim
        for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
            if idx == dim:
                concat_length = concat_length + length
            else:
                assert length == common_length

    new_shape = list(tensors[0].shape)
    new_shape[dim] = concat_length
    return TensorMeta(
        tensors[0],
        shape=new_shape,
        strides=utils.make_contiguous_strides_for(new_shape),
    )
def _collapse_view_meta(a: TensorLikeType, start: int, end: int) -> TensorLikeType:
    """Computes the output metadata for collapsing dims [start, end) of ``a``
    into a single dimension, as a view.

    The dims being collapsed must be contiguous with each other (each dim's
    stride equals the next dim's stride times its length).
    """
    assert isinstance(a, TensorLike)

    shape = a.shape
    strides = a.stride()

    # Fix: validate_idx / validate_exclusive_idx expect the tensor's rank
    # (an int), not its shape — see their other call sites in this file.
    utils.validate_idx(a.ndim, start)
    utils.validate_exclusive_idx(a.ndim, end)

    # Verifies end is strictly greater than start
    # (Collapse requires a non-empty interval)
    assert end > start

    # Verifies the collapsed dims are contiguous with each other while
    # accumulating the collapsed length.
    length = 1
    for idx in range(start, end):
        if idx != (end - 1):
            assert strides[idx] == strides[idx + 1] * shape[idx + 1]
        length = length * shape[idx]

    # Fix: the collapsed dim takes the innermost stride of the collapsed
    # range. The old product-of-strides was wrong — collapsing a contiguous
    # (2, 3, 4) tensor (strides (12, 4, 1)) produced stride 48 instead of 1.
    stride = strides[end - 1]

    new_shape = shape[:start] + (length, ) + shape[end:]
    # Fix: the trailing entries must come from strides, not shape.
    new_strides = strides[:start] + (stride, ) + strides[end:]

    return TensorMeta(a, shape=new_shape, strides=new_strides)
def _squeeze_meta(a: TensorLikeType, dimensions: Sequence) -> TensorLikeType:
    """Computes the output metadata for removing the given length-1 dims of ``a``.

    Every index in ``dimensions`` must refer to a dimension of length one;
    the remaining dims keep their original lengths and strides (a view).
    """
    assert isinstance(a, TensorLike)

    # Each squeezed dim must be a valid index and have length one.
    for d in dimensions:
        utils.validate_idx(a.ndim, d)
        assert a.shape[d] == 1

    strides = a.stride()
    kept = [i for i in range(len(a.shape)) if i not in dimensions]
    new_shape = [a.shape[i] for i in kept]
    new_strides = [strides[i] for i in kept]

    return TensorMeta(a, shape=new_shape, strides=new_strides)
def _split_dim_meta(a: TensorLikeType, dim: int, outer_length: int) -> TensorLikeType:
    """Computes the output metadata for splitting dim ``dim`` of ``a`` in two.

    Dimension ``dim`` (of length N) becomes two dimensions of lengths
    ``outer_length`` and ``N // outer_length``, with strides chosen so the
    result is a view of ``a``.

    Raises:
        ValueError: if ``outer_length`` does not evenly divide the dim.
    """
    assert isinstance(a, TensorLike)
    # Fix: validate_idx expects the tensor's rank (an int), not its shape.
    utils.validate_idx(a.ndim, dim)
    utils.validate_dim_length(outer_length)

    # Verifies the dim can be split with the specified lhs_length
    _inner_length = a.shape[dim] / outer_length
    inner_length: int = int(_inner_length)
    # Fix: raise a descriptive ValueError (matching the sibling
    # implementation) instead of a bare assert, which -O strips.
    if inner_length != _inner_length:
        msg = "Attempting to split dimension of length {0}, but outer length of {1} divides it with a remainder!".format(
            a.shape[dim], outer_length)
        raise ValueError(msg)

    new_shape: List[int] = []
    new_strides: List[int] = []
    # Fix: iterate dimension *indices*; the old `for idx in a.shape` iterated
    # the shape's *values*, indexing stride/shape with lengths instead of dims.
    for idx in range(a.ndim):
        if idx == dim:
            new_shape.extend((outer_length, inner_length))
            new_strides.extend(
                (a.stride()[idx] * inner_length, a.stride()[idx]))
        else:
            new_shape.append(a.shape[idx])
            new_strides.append(a.stride()[idx])

    return TensorMeta(a, shape=new_shape, strides=new_strides)
def _collapse_view_helper(
        a: TensorLikeType, start: int,
        end: int) -> Tuple[Optional[ShapeType], Optional[StrideType]]:
    """Attempts to compute the (shape, strides) of viewing ``a`` with dims
    [start, end) collapsed into one dimension.

    Returns (None, None) when no such view exists, i.e. when the dims to
    collapse are not contiguous with each other.

    Raises:
        ValueError: if ``end <= start`` (the interval must be non-empty).
    """
    assert isinstance(a, TensorLike)

    # Special-case for zero dimensional tensors
    if a.ndim == 0:
        shape = (1, )
        strides = (1, )
    else:
        shape = a.shape  # type: ignore[assignment]
        strides = a.stride()

    utils.validate_idx(len(shape), start)
    utils.validate_exclusive_idx(len(shape), end)

    # Verifies end is strictly greater than start
    # (Collapse requires a non-empty interval)
    if end <= start:
        msg = "Attempting to collapse but end, {0}, is less than or equal to start, {1}!".format(
            end, start)
        raise ValueError(msg)

    # A view exists only if each collapsed dim's stride equals the next dim's
    # stride times its length; otherwise signal failure with (None, None).
    length = 1
    for idx in range(start, end):
        if idx != (end - 1):
            if not (strides[idx] == strides[idx + 1] * shape[idx + 1]):
                return None, None
        length = length * shape[idx]

    # Fix: the collapsed dim takes the innermost stride of the collapsed
    # range. The old product-of-strides was wrong — collapsing a contiguous
    # (2, 3, 4) tensor (strides (12, 4, 1)) produced stride 48 instead of 1.
    stride = strides[end - 1]

    new_shape = shape[:start] + (length, ) + shape[end:]
    # Fix: the trailing entries must come from strides, not shape.
    new_strides = strides[:start] + (stride, ) + strides[end:]

    return new_shape, new_strides