def _concatenate_meta(tensors: Sequence[TensorLikeType], dim: int) -> TensorLikeType:
    if len(tensors) == 0:
        msg = "concatenate expects at least one tensor, but received zero!"
        raise ValueError(msg)

    for tensor in tensors:
        assert isinstance(tensor, TensorLike)

    utils.check_same_dtype(*tensors)
    utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)

    shape = tensors[0].shape
    utils.validate_idx(tensors[0].ndim, dim)

    # Verifies same shape (except in the concat dimension)
    concat_length = 0
    for tensor in tensors:
        for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
            if idx == dim:
                concat_length = concat_length + length
            else:
                assert length == common_length

    new_shape = list(tensors[0].shape)
    new_shape[dim] = concat_length
    return TensorMeta(
        tensors[0],
        shape=new_shape,
        strides=utils.make_contiguous_strides_for(new_shape),
    )
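
# Example (an illustrative sketch, not part of the module; it assumes the
# surrounding imports and that TensorMeta only carries shape/stride/dtype
# metadata):
#
#   a = torch.empty(2, 3)
#   b = torch.empty(2, 5)
#   meta = _concatenate_meta([a, b], dim=1)
#   # All dimensions except dim must match; lengths along dim are summed,
#   # so meta.shape == (2, 8) and meta gets the contiguous strides (8, 1).
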
def _select_meta(pred: TensorLikeType, a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True)
    utils.check_same_shape(pred, a, b, allow_cpu_scalar_tensors=True)
    assert pred.dtype is torch.bool

    return _elementwise_meta(
        a, b, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
    )
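
# Example (an illustrative sketch, assuming the surrounding imports):
#
#   pred = torch.empty(3, 4, dtype=torch.bool)
#   x = torch.empty(3, 4)
#   y = torch.empty(3, 4)
#   meta = _select_meta(pred, x, y)
#   # pred must be boolean and all three shapes must match (modulo CPU
#   # scalar tensors); meta.shape == (3, 4) and the dtype follows x and y
#   # under the DEFAULT elementwise type promotion kind.
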
def _elementwise_meta(*args, type_promotion):
    """
    Meta function for elementwise operations that produce outputs in the
    same dtype as their inputs.

    Stride logic is currently incorrect.
    """

    assert len(args) > 0

    utils.check_same_device(*args, allow_cpu_scalar_tensors=True)
    utils.check_same_shape(*args, allow_cpu_scalar_tensors=True)
    utils.check_same_dtype(*args)

    strides = None
    tensor = None
    number = None
    for arg in args:
        if isinstance(arg, TensorLike):
            if strides is None:
                strides = arg.stride()

            if tensor is None:
                tensor = arg

            if arg.stride() != strides:
                return TensorMeta(
                    arg, strides=utils.make_contiguous_strides_for(arg.shape)
                )
        elif isinstance(arg, Number):
            if number is None:
                number = arg

    # TODO: fix strides
    if tensor is not None:
        if 0 in tensor.stride() and tensor.numel() > 0:
            return TensorMeta(
                tensor, strides=utils.make_contiguous_strides_for(tensor.shape)
            )
        else:
            return TensorMeta(tensor)

    return TensorMeta(number)
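
# Example (an illustrative sketch, assuming the surrounding imports):
#
#   x = torch.empty(4, 4)      # contiguous, strides (4, 1)
#   y = torch.empty(4, 4).t()  # transposed view, strides (1, 4)
#   meta = _elementwise_meta(
#       x, y, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
#   )
#   # The inputs disagree on strides, so the result falls back to contiguous
#   # strides, (4, 1); per the docstring, this stride logic is a known
#   # placeholder.
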