Ejemplo n.º 1
0
def _concatenate_meta(tensors: Sequence[TensorLikeType],
                      dim: int) -> TensorLikeType:
    """Meta function for concatenating ``tensors`` along dimension ``dim``.

    All inputs must be tensors with the same dtype and device, the same rank,
    and the same length in every dimension except ``dim``. The result metadata
    is modeled on ``tensors[0]``. Its length along ``dim`` is the sum of the
    inputs' lengths there, and it has contiguous strides.

    Raises:
        AssertionError: if ``tensors`` is empty, contains a non-tensor, or the
            shapes are incompatible (including differing ranks).
    """
    assert len(tensors) > 0

    for tensor in tensors:
        assert isinstance(tensor, TensorLike)

    utils.check_same_dtype(tensors)
    utils.check_same_device(tensors, allow_scalars=False)

    shape = tensors[0].shape
    utils.validate_idx(shape, dim)

    # Verifies same shape (except in the concat dimension)
    concat_length = 0
    for tensor in tensors:
        # zip() below silently stops at the shorter shape, so without this a
        # tensor of a different rank would pass the check (and, if it is too
        # short, its length along ``dim`` would be dropped from the total).
        assert len(tensor.shape) == len(shape), (
            "concatenate expects tensors of equal rank")
        for idx, (common_length, length) in enumerate(zip(shape,
                                                          tensor.shape)):
            if idx == dim:
                concat_length += length
            else:
                assert length == common_length

    # list() already produces a fresh list, so no extra copy is needed.
    new_shape = list(shape)
    new_shape[dim] = concat_length
    return TensorMeta(
        tensors[0],
        shape=new_shape,
        strides=utils.make_contiguous_strides_for(new_shape),
    )
Ejemplo n.º 2
0
def _concatenate_meta(tensors: Sequence[TensorLikeType],
                      dim: int) -> TensorLikeType:
    """Meta function for concatenating ``tensors`` along dimension ``dim``.

    All inputs must be tensors with the same dtype and device (CPU scalar
    tensors are not exempted), the same rank, and the same length in every
    dimension except ``dim``. The result metadata is modeled on
    ``tensors[0]``. Its length along ``dim`` is the sum of the inputs'
    lengths there, and it has contiguous strides.

    Raises:
        ValueError: if ``tensors`` is empty or the tensors differ in rank.
        AssertionError: if an element is not a tensor or a non-``dim``
            length differs.
    """
    if len(tensors) == 0:
        msg = "concatenate expects at least one tensor, but received zero!"
        raise ValueError(msg)

    for tensor in tensors:
        assert isinstance(tensor, TensorLike)

    utils.check_same_dtype(*tensors)
    utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)

    shape = tensors[0].shape
    utils.validate_idx(tensors[0].ndim, dim)

    # Verifies same shape (except in the concat dimension)
    concat_length = 0
    for tensor in tensors:
        # zip() below silently stops at the shorter shape, so without this a
        # tensor of a different rank would pass the check (and, if it is too
        # short, its length along ``dim`` would be dropped from the total).
        if len(tensor.shape) != len(shape):
            msg = ("concatenate expects tensors of equal rank, but got "
                   "shapes {0} and {1}!".format(tuple(shape),
                                                tuple(tensor.shape)))
            raise ValueError(msg)
        for idx, (common_length, length) in enumerate(zip(shape,
                                                          tensor.shape)):
            if idx == dim:
                concat_length += length
            else:
                assert length == common_length

    # list() already produces a fresh list, so no extra copy is needed.
    new_shape = list(shape)
    new_shape[dim] = concat_length
    return TensorMeta(
        tensors[0],
        shape=new_shape,
        strides=utils.make_contiguous_strides_for(new_shape),
    )
Ejemplo n.º 3
0
def _select_meta(pred: TensorLikeType, a: TensorLikeType,
                 b: TensorLikeType) -> TensorLikeType:
    """Meta function for ``select``.

    Requires ``pred``, ``a`` and ``b`` to share a device (scalars allowed) and
    a shape, and ``pred`` to have dtype ``torch.bool``. The output metadata is
    whatever ``_elementwise_meta`` derives from ``a`` and ``b``.
    """
    operands = (pred, a, b)
    utils.check_same_device(*operands, allow_scalars=True)
    utils.check_same_shape(*operands)

    # Only a boolean predicate is accepted.
    assert pred.dtype is torch.bool

    return _elementwise_meta(a, b)
Ejemplo n.º 4
0
def _select_meta(pred: TensorLikeType, a: TensorLikeType,
                 b: TensorLikeType) -> TensorLikeType:
    """Meta function for ``select``.

    ``pred``, ``a`` and ``b`` must share a device and a shape; CPU scalar
    tensors are exempt from both checks. ``pred`` must have dtype
    ``torch.bool``. The output metadata comes from ``_elementwise_meta``
    applied to ``a`` and ``b`` with the DEFAULT type-promotion kind.
    """
    operands = (pred, a, b)
    utils.check_same_device(*operands, allow_cpu_scalar_tensors=True)
    utils.check_same_shape(*operands, allow_cpu_scalar_tensors=True)

    # Only a boolean predicate is accepted.
    assert pred.dtype is torch.bool

    promotion_kind = ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
    return _elementwise_meta(a, b, type_promotion=promotion_kind)
Ejemplo n.º 5
0
def _elementwise_meta(*args, type_promotion):
    """
    Meta function for elementwise operations that produce outputs in the same dtype
    as their inputs.

    All arguments must share dtype; tensor arguments must also share device
    and shape, except CPU scalar tensors (0-dim tensors on the CPU), which are
    exempt from the device and shape checks.

    The output metadata is copied from a reference tensor argument. CPU scalar
    tensors serve as the reference only when every tensor argument is one.
    Because they skip the shape check, picking one while a full-size tensor is
    present would give the output a 0-dim shape and the CPU device. If no
    tensor is passed, the metadata is built from the first Number argument
    (or ``None`` if there is none).

    ``type_promotion`` is accepted for interface compatibility but is not read
    by this function.

    Stride logic is currently incorrect.
    """

    assert len(args) > 0

    utils.check_same_device(*args, allow_cpu_scalar_tensors=True)
    utils.check_same_shape(*args, allow_cpu_scalar_tensors=True)
    utils.check_same_dtype(*args)

    tensors = [arg for arg in args if isinstance(arg, TensorLike)]
    # Prefer tensors that actually took part in the shape check; fall back to
    # CPU scalar tensors only when they are the only tensors available.
    candidates = [
        t for t in tensors
        if not (t.ndim == 0 and t.device.type == "cpu")
    ] or tensors

    if candidates:
        tensor = candidates[0]
        strides = tensor.stride()
        for other in candidates[1:]:
            # Inputs disagree on layout: fall back to contiguous strides.
            if other.stride() != strides:
                return TensorMeta(other,
                                  strides=utils.make_contiguous_strides_for(
                                      other.shape))

        # TODO: fix strides
        # A zero stride on a non-empty tensor (e.g. an expanded tensor) maps
        # several indices to one element; don't propagate that to the output.
        if 0 in strides and tensor.numel() > 0:
            return TensorMeta(tensor,
                              strides=utils.make_contiguous_strides_for(
                                  tensor.shape))
        return TensorMeta(tensor)

    number = next((arg for arg in args if isinstance(arg, Number)), None)
    return TensorMeta(number)