Example #1
0
def test_slice_default_steps(x, init):
    """Slice axis 1 over [1, 9) without passing ``steps``.

    The expected result is plain tensor indexing with an implicit unit step.
    When ``init`` is true the slice spec goes to the ``Slice`` constructor;
    otherwise it is passed at call time.
    """
    expected = x[:, 1:9]
    starts, ends, axes = (
        torch.tensor([bound], dtype=torch.int64) for bound in (1, 9, 1)
    )

    if not init:
        # Spec supplied per call; the module itself is built with no arguments.
        result = Slice()(x, starts, ends, axes)
    else:
        result = Slice(axes, starts, ends)(x)
    assert torch.equal(result, expected)
Example #2
0
def test_slice_default_axes(x, init):
    """Slice without passing ``axes``.

    The expected result indexes dims 0 and 1, i.e. the test expects the
    starts/ends/steps entries to apply to the leading axes in order.
    """
    expected = x[1:9, 2:5:2]
    starts, ends, steps = (
        torch.tensor(values, dtype=torch.int64)
        for values in ([1, 2], [9, 5], [1, 2])
    )

    if init:
        result = Slice(starts=starts, ends=ends, steps=steps)(x)
    else:
        # ``axes`` is skipped positionally, so ``steps`` must go by keyword.
        result = Slice()(x, starts, ends, steps=steps)
    assert torch.equal(result, expected)
Example #3
0
def test_slice_1(x, init):
    """Slice axes 0 and 1 explicitly with unit steps: ``x[0:3, 0:10]``."""
    starts, ends, axes, steps = (
        torch.tensor(values, dtype=torch.int64)
        for values in ([0, 0], [3, 10], [0, 1], [1, 1])
    )
    expected = x[0:3, 0:10]

    if not init:
        actual = Slice()(x, starts, ends, axes, steps)
    else:
        # Constructor order is (axes, starts, ends, steps).
        actual = Slice(axes, starts, ends, steps)(x)
    assert torch.equal(actual, expected)
Example #4
0
def test_slice_neg_axes(x, init):
    """Slice with axis -1, which should address the last dimension."""

    def as_int64(*values):
        # Build a 1-D int64 tensor from the given scalars.
        return torch.tensor(list(values), dtype=torch.int64)

    starts, ends = as_int64(1), as_int64(4)
    axes, steps = as_int64(-1), as_int64(2)
    # Reference indexes dim 2 — assumes the ``x`` fixture is 3-D (TODO confirm).
    expected = x[:, :, 1:4:2]

    if init:
        assert torch.equal(Slice(axes, starts, ends, steps)(x), expected)
        return
    assert torch.equal(Slice()(x, starts, ends, axes, steps), expected)
Example #5
0
def test_slice_neg_steps(x, init):
    """Slice three axes with negative steps (reverse traversal).

    The reference result is computed in NumPy because PyTorch tensor
    indexing does not accept negative slice steps. ``np.copy`` materializes
    the reversed (negative-stride) view into a fresh array before it is
    handed to ``torch.tensor``.
    """
    starts = torch.tensor([20, 10, 4], dtype=torch.int64)
    ends = torch.tensor([0, 0, 1], dtype=torch.int64)
    axes = torch.tensor([0, 1, 2], dtype=torch.int64)
    steps = torch.tensor([-1, -3, -2], dtype=torch.int64)
    y = torch.tensor(np.copy(x.numpy()[20:0:-1, 10:0:-3, 4:1:-2]))

    if init:
        # Leftover debug ``print(op, flush=True)`` removed: it only polluted
        # test output and asserted nothing.
        op = Slice(axes, starts=starts, ends=ends, steps=steps)
        assert torch.equal(op(x), y)
    else:
        op = Slice()
        assert torch.equal(op(x, starts, ends, axes, steps), y)