def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorLikeType:
    """Collapse dimensions start_dim through end_dim (inclusive) of a into one.

    Args:
        a: the tensor to flatten.
        start_dim: first dimension to collapse; negative values are
            canonicalized against a.ndim.
        end_dim: last dimension to collapse (inclusive); negative values are
            canonicalized against a.ndim.

    Returns:
        a itself when only a single dimension of a non-scalar tensor is
        selected (nothing to flatten), otherwise the result of
        prims.collapse_view when the view helper reports a shape, otherwise
        the result of prims.collapse.
    """
    start_dim = utils.canonicalize_dim(a.ndim, start_dim)
    end_dim = utils.canonicalize_dim(a.ndim, end_dim)

    # Short-circuits on no-op: when a single dimension is selected nothing is
    # flattened, so the input object itself is returned. Rank-zero tensors are
    # excluded so they still go through the collapse prims below.
    if start_dim == end_dim and a.ndim != 0:
        return a

    # The collapse prims are handed end_dim + 1 (presumably an exclusive bound)
    end = end_dim + 1

    # Tries to take a view
    # TODO: we could look at directing collapse_view to skip its meta function here
    new_shape, _ = prims._collapse_view_helper(a, start_dim, end)
    if new_shape is not None:
        return prims.collapse_view(a, start_dim, end)

    # Makes a copy if it can't make a view
    return prims.collapse(a, start_dim, end)
def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorLikeType:
    """Merge the dimensions from start_dim to end_dim (inclusive) into one.

    Both dimension arguments are canonicalized against a.ndim, so negative
    values are accepted. When the two resolve to the same dimension of a
    tensor with at least one dimension, a is returned unchanged. Otherwise a
    view is produced when prims._collapse_view_helper reports a shape, and
    prims.collapse is used when it does not.
    """
    rank = a.ndim
    first = utils.canonicalize_dim(rank, start_dim)
    last = utils.canonicalize_dim(rank, end_dim)

    # Nothing to merge: hand the input back untouched
    if rank != 0 and first == last:
        return a

    # The collapse prims receive last + 1 as their end argument
    stop = last + 1

    # TODO: we could look at directing collapse_view to skip its meta function here (unsafe_collapse_view)
    view_shape, _ = prims._collapse_view_helper(a, first, stop)
    if view_shape is None:
        # A view is not possible, so fall back to the copying prim
        return prims.collapse(a, first, stop)
    return prims.collapse_view(a, first, stop)
def stack(tensors: List[Tensor], dim: int = 0) -> Tensor:
    """Join tensors along a new dimension inserted at dim.

    dim is canonicalized against the rank of the first tensor plus one, so
    it may refer to the new trailing position or be negative.

    Two strategies are used:
      * If the new dimension is not the trailing one and the first tensor is
        not sparse, the inputs are concatenated along the wrapped dimension
        and the result is viewed with the extra dimension inserted.
      * Otherwise the inputs are passed through get_stack_inputs (defined
        elsewhere; presumably unsqueezes each input at the wrapped
        dimension - verify) and concatenated.

    Raises:
        AssertionError: if tensors is empty.
    """
    assert len(tensors) > 0, "stack expects a non-empty TensorList"
    first = tensors[0]
    wrapped_dim = utils.canonicalize_dim(first.dim() + 1, dim)

    # Fallback path: trailing insertion point or sparse input
    if wrapped_dim >= first.dim() or first.is_sparse:
        return torch.cat(get_stack_inputs(tensors, wrapped_dim), dim)

    # check_stack_inputs is defined elsewhere; presumably it validates that
    # the inputs can be stacked - verify
    check_stack_inputs(tensors)

    # Target shape: the first tensor's shape with len(tensors) inserted at
    # the wrapped dimension
    base_sizes = list(first.shape)
    stacked_sizes = base_sizes[:wrapped_dim] + [len(tensors)] + base_sizes[wrapped_dim:]
    return torch.cat(tensors, wrapped_dim).view(stacked_sizes)
def glu_backward(grad_output: Tensor, self: Tensor, dim: int) -> Tensor:
    """Compute the gradient of a gated linear unit with respect to its input.

    self is split into two equal halves along dim. With s = sigmoid(second
    half), the returned tensor is the concatenation along dim of
    s * grad_output (for the first half) and
    (1 - s) * s * first_half * grad_output (for the second half), which is
    the gradient of first_half * sigmoid(second_half).

    Args:
        grad_output: incoming gradient; it is multiplied elementwise with
            the halves of self.
        self: the input that was given to the forward op.
        dim: the dimension that is halved; may be negative.

    Raises:
        AssertionError: if self is 0-dimensional or its size along dim is odd.
    """
    assert self.dim() > 0, "glu does not support 0-dimensional tensors"
    wrap_dim = utils.canonicalize_dim(self.dim(), dim)
    full_size = self.size(wrap_dim)
    assert full_size % 2 == 0, f"Halving dimension must be even, but dimension {wrap_dim} is size {full_size}"
    half = full_size // 2

    value_half = self.narrow(wrap_dim, 0, half)
    gate = torch.sigmoid(self.narrow(wrap_dim, half, half))

    # Gradient flowing into the gated (second) half, then the value (first) half
    grad_gate_half = (1.0 - gate) * gate * value_half * grad_output
    grad_value_half = gate * grad_output
    return torch.cat([grad_value_half, grad_gate_half], dim=wrap_dim)
def squeeze(a: TensorLikeType, dim: Optional[int] = None) -> TensorLikeType:
    """Remove dimensions of length one from a.

    Args:
        a: the tensor to squeeze.
        dim: when given, only this dimension is considered (negative values
            are canonicalized against a.ndim); when None, every dimension of
            length one is removed.

    Returns:
        The result of prims.squeeze over the selected dimensions, or
        prims.view_of(a) when dim is given but a has no dimensions or the
        selected dimension does not have length one.
    """
    shape = a.shape

    if dim is None:
        unit_dims = tuple(idx for idx, extent in enumerate(shape) if extent == 1)
        return prims.squeeze(a, unit_dims)

    dim = utils.canonicalize_dim(a.ndim, dim)

    # A tensor without dimensions has nothing to squeeze
    if len(shape) == 0:
        assert dim == 0
        return prims.view_of(a)

    # Note: the tensor is left unmodified when the given dim is not a
    # dimension of length 1
    if shape[dim] == 1:
        return prims.squeeze(a, (dim,))
    return prims.view_of(a)
def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> Tuple[TensorLikeType, ...]:
    """Split a along dim into at most chunks pieces.

    Every piece has length ceil(length / chunks) along dim, except possibly
    the last one, which holds the remainder. Fewer than chunks pieces are
    returned when the dimension cannot be divided into that many non-empty
    pieces of that size.

    Args:
        a: the tensor to split.
        chunks: requested number of pieces; must be positive.
        dim: the dimension to split along; may be negative.

    Returns:
        A tuple of slices of a taken with narrow. When a has length zero
        along dim, exactly chunks zero-length slices are returned.

    Raises:
        ValueError: if chunks is not positive.
    """
    if chunks <= 0:
        msg = "Expected at least one chunk, but got {0}!".format(chunks)
        raise ValueError(msg)

    dim = utils.canonicalize_dim(a.ndim, dim)
    length = a.shape[dim]

    # A zero-length dimension would make chunk_size zero and the division
    # below fail. Any number of empty pieces adds up to zero, so the
    # requested count is honoured.
    if length == 0:
        return tuple(narrow(a, dim, 0, 0) for _ in range(chunks))

    # Integer ceiling division; avoids float rounding for very large lengths
    chunk_size = (length + chunks - 1) // chunks
    full_chunks, tail_chunk_size = divmod(length, chunk_size)

    result = [
        narrow(a, dim, i * chunk_size, chunk_size) for i in range(full_chunks)
    ]
    if tail_chunk_size != 0:
        result.append(narrow(a, dim, full_chunks * chunk_size, tail_chunk_size))

    return tuple(result)
def tensor_split(
    a: TensorLikeType,
    indices_or_sections: Union[Tensor, DimsType],
    dim: int = 0,
) -> Tuple[TensorLikeType, ...]:
    """Split a into slices along dim, by section count or by split indices.

    Args:
        a: the tensor to split; must have at least one dimension.
        indices_or_sections: either
            * an int or a zero-dimensional tensor n: a is split into n
              sections whose lengths differ by at most one, the longer
              sections coming first; or
            * a sequence of ints or a one-dimensional tensor: consecutive
              entries delimit the slices, with a final slice running to the
              end of the dimension.
            A tensor argument must live on the CPU and have long dtype.
        dim: the dimension to split along; may be negative.

    Returns:
        A tuple of slices of a produced by prims.slice_in_dim.

    Raises:
        ValueError: if a has rank zero, if a tensor indices_or_sections is
            not a CPU long tensor or has more than one dimension, or if the
            number of sections is not positive.
    """
    _dim = utils.canonicalize_dim(a.ndim, dim)
    if a.ndim == 0:
        msg = "tensor_split: received a rank zero tensor, but expected a tensor of rank one or greater!"
        raise ValueError(msg)

    # If indices_or_sections is a tensor, it must be a CPU Long tensor
    if isinstance(indices_or_sections, TensorLike):
        if indices_or_sections.device != torch.device("cpu"):
            msg = "tensor_split: if indices_or_sections is a tensor it must be on the CPU, but received one on {0}".format(
                indices_or_sections.device
            )
            raise ValueError(msg)
        if indices_or_sections.dtype != torch.long:
            # Parenthesized so that both literals are joined before
            # .format() is applied to the whole message
            msg = (
                "tensor_split: if indices_or_sections is a tensor it must have long dtype, "
                "but received one with dtype {0}"
            ).format(indices_or_sections.dtype)
            raise ValueError(msg)

    # Case 0 -- indices_or_sections is an integer or a scalar tensor n and a
    # is split along dim into n parts of equal-ish length
    if isinstance(indices_or_sections, int) or (
        isinstance(indices_or_sections, TensorLike) and indices_or_sections.ndim == 0
    ):
        sections: int = (
            indices_or_sections  # type: ignore[assignment]
            if isinstance(indices_or_sections, int)
            else indices_or_sections.item()
        )

        if sections <= 0:
            msg = "tensor_split: number of sections must be greater than 0, but was {0}".format(
                sections
            )
            raise ValueError(msg)

        dim_size = a.shape[_dim]
        # Exact integer arithmetic: the first num_splits_one_extra sections
        # receive one extra element each
        min_split_size, num_splits_one_extra = divmod(dim_size, sections)

        splits = []
        start_idx = 0
        for split_idx in range(sections):
            split_size = (
                min_split_size + 1
                if split_idx < num_splits_one_extra
                else min_split_size
            )
            splits.append(
                prims.slice_in_dim(a, start_idx, start_idx + split_size, axis=_dim)
            )
            start_idx += split_size

        return tuple(splits)

    # Case 1 -- indices_or_sections is a sequence of integers or a 1D tensor
    # describing the splits
    indices = indices_or_sections
    if isinstance(indices_or_sections, TensorLike):
        if indices_or_sections.ndim != 1:
            msg = (
                "tensor_split: non-scalar indices_or_sections tensors must have only one dimension, "
                "but received a tensor with {0} dimensions"
            ).format(indices_or_sections.ndim)
            raise ValueError(msg)

        indices = indices_or_sections.tolist()

    splits = []
    start_idx = 0
    for x in indices:
        splits.append(prims.slice_in_dim(a, start_idx, x, axis=_dim))
        start_idx = x
    # The remainder after the last index forms the final slice
    splits.append(prims.slice_in_dim(a, start_idx, a.shape[_dim], axis=_dim))
    return tuple(splits)
def narrow(a: TensorLikeType, dim: int, start: int, length: int) -> TensorLikeType:
    """Return the slice of a covering [start, start + length) along dim.

    dim is canonicalized against a.ndim, so negative values are accepted.
    The slicing itself is delegated to prims.slice_in_dim.
    """
    axis = utils.canonicalize_dim(a.ndim, dim)
    stop = start + length
    return prims.slice_in_dim(a, start, stop, axis=axis)
def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType:
    """Insert a dimension of length one into a at position dim.

    The work is delegated to prims.expand_dims with the canonicalized
    position as its only dimension.
    """
    # dim is resolved against rank + 1 rather than rank, because the new
    # dimension may also be placed after the current innermost one
    result_rank = a.ndim + 1
    position = utils.canonicalize_dim(result_rank, dim)
    return prims.expand_dims(a, (position,))