Esempio n. 1
0
    def __getitem__(self, key) -> 'Tensor':
        """
        Supports slicing and integer indexing.
        If a single index is selected during slicing the dimension will be squeezed - this matches numpy slicing rules.

        Examples:
        ```
        # Slicing
        x[0]        # Select all elements where i==0 for axis 0
        x[-1]       # Select all elements where i is the last index for axis 0
        x[0,1]      # Select all elements where i==0, j==1 for axis 0 and 1
        x[0:2]      # Slice axis 0 between index 0 and 2
        x[:2,3:]    # Slice axis 0 up to 2 and axis 1 from index 3
        x[:,::-1]   # Select all elements for axis 0 and reverse axis 1

        # Integer indexing
        indices = Tensor([[0,1], [1,0]], dtype='int32')
        x[indices]  # Select elements [0,1] and [1,0] from `x`
        ```

        Args:
            key: An int, a slice, a tuple of ints/slices, or a Tensor whose
                dtype is an integer type.

        Returns:
            Tensor: For basic slicing, the result of `ops.slice` (with every
            int-indexed axis squeezed via `ops.squeeze`). For a Tensor key,
            the result of `ops.gather`.

        Raises:
            IndexError: If `key` is not one of the supported index types.
        """

        import popart.ir.ops as ops

        if isinstance(key, (slice, int)) or (isinstance(key, tuple) and all(
                isinstance(e, (slice, int)) for e in key)):
            # Basic slicing (integer or slices)
            key = (key, ) if isinstance(key, (slice, int)) else key

            start = []
            stop = []
            step = []
            int_slices = []

            for i, key_i in enumerate(key):
                if isinstance(key_i, int):
                    # An integer index is a one-element slice on this axis.
                    start.append(key_i)
                    # `key_i + 1` is the exclusive end of that slice, except
                    # for -1 where it would become 0 and yield an empty slice;
                    # `None` means "up to the end", selecting the last element.
                    stop.append(None if key_i == -1 else key_i + 1)
                    step.append(1)
                    # Slicing keeps the axis, so remember it for squeezing.
                    int_slices.append(i)
                else:
                    # Already validated above: every remaining entry is a slice.
                    # Unspecified fields stay `None` and are left to ops.slice.
                    start.append(key_i.start)
                    stop.append(key_i.stop)
                    step.append(key_i.step)

            out = ops.slice(self, start, stop, step)

            if int_slices:
                out = ops.squeeze(out, axes=int_slices)

            return out

        elif (isinstance(key, Tensor) and key.dtype.is_int):
            # Integer indexing
            return ops.gather(self, key)

        else:
            raise IndexError(
                "Only integers, slices (`:`) and integer arrays are valid indices."
            )
Esempio n. 2
0
 def test_error_lengths(self, inplace):
     """`start` and `stop` of different lengths must raise ValueError."""
     ir = pir.Ir()
     with ir.main_graph():
         t = pir.variable(data)
         with pytest.raises(ValueError):
             slice_fn = ops.slice_ if inplace else ops.slice
             # One start value vs. two stop/axis values.
             slice_fn(t, start=[2], stop=[3, 4], axis=[2, 1])
Esempio n. 3
0
    def test_identity_numerically(self, inplace):
        """Slicing with only `axis` given yields the input data unchanged."""
        ir = pir.Ir()
        with ir.main_graph():
            t = pir.variable(data)
            slice_fn = ops.slice_ if inplace else ops.slice
            # `axis=0` is redundant: no start/stop/step means nothing is cut.
            y = slice_fn(t, axis=0)
            y_host = run_ir(ir, y)

        assert_array_equal(y_host, data)
Esempio n. 4
0
    def test_identity_fn(self, inplace):
        """An identity slice leaves only the source variable in the graph."""
        ir = pir.Ir()
        g = ir.main_graph()
        with g:
            t = pir.variable(data)
            slice_fn = ops.slice_ if inplace else ops.slice
            # `axis=0` is redundant
            y = slice_fn(t, axis=0)

        assert len(g.get_tensors()) == 1
        assert len(g.get_variables()) == 1
Esempio n. 5
0
    def test_start_only(self, inplace):
        """A scalar `start` alone behaves like numpy's `data[1:]`."""
        ir = pir.Ir()
        with ir.main_graph():
            t = pir.variable(data)
            slice_fn = ops.slice_ if inplace else ops.slice
            y_host = run_ir(ir, slice_fn(t, start=1))

        expected = data[1:]
        assert_array_equal(y_host, expected)
Esempio n. 6
0
    def test_axis(self, inplace):
        """`start[k]`/`stop[k]` apply to axis `axis[k]`, in any axis order."""
        ir = pir.Ir()
        with ir.main_graph():
            t = pir.variable(data)
            slice_fn = ops.slice_ if inplace else ops.slice
            # Axis 2 gets 1:3 and axis 1 gets 2:4.
            y_host = run_ir(ir, slice_fn(t, start=[1, 2], stop=[3, 4], axis=[2, 1]))

        expected = data[:, 2:4, 1:3]
        assert_array_equal(y_host, expected)
Esempio n. 7
0
    def test_negative_start(self, inplace):
        """A negative `start` with a negative `step` matches `data[-2::-1]`."""
        ir = pir.Ir()
        with ir.main_graph():
            t = pir.variable(data)
            slice_fn = ops.slice_ if inplace else ops.slice
            y_host = run_ir(ir, slice_fn(t, start=-2, step=-1))

        expected = data[-2::-1]
        assert_array_equal(y_host, expected)
Esempio n. 8
0
    def test_step(self, inplace):
        """Per-axis steps, including a reversing step, match numpy slicing."""
        ir = pir.Ir()
        with ir.main_graph():
            t = pir.variable(data)
            slice_fn = ops.slice_ if inplace else ops.slice
            y_host = run_ir(ir, slice_fn(t, start=[1, 3], stop=[3, 1], step=[1, -1]))

        expected = data[1:3, 3:1:-1]
        assert_array_equal(y_host, expected)
Esempio n. 9
0
    def test_stop_only_multidim(self, inplace):
        """A list-valued `stop` alone slices the leading axes like `data[:2, :3]`."""
        ir = pir.Ir()
        with ir.main_graph():
            t = pir.variable(data)
            slice_fn = ops.slice_ if inplace else ops.slice
            y_host = run_ir(ir, slice_fn(t, stop=[2, 3]))

        expected = data[:2, :3]
        assert_array_equal(y_host, expected)
Esempio n. 10
0
    def test_fn_numerically(self, inplace):
        """Explicit start/stop/step/axis on axis 0 matches `data[1:3]`."""
        ir = pir.Ir()
        g = ir.main_graph()
        with g:
            t = pir.variable(data)
            slice_fn = ops.slice_ if inplace else ops.slice
            y_host = run_ir(ir, slice_fn(t, start=1, stop=3, step=1, axis=0))

        expected = data[1:3]
        assert_array_equal(y_host, expected)
Esempio n. 11
0
    def test_fn(self, inplace):
        """Slicing adds the matching slice op and exactly one output tensor."""
        ir = pir.Ir()
        g = ir.main_graph()
        with g:
            t = pir.variable(data)
            slice_fn = ops.slice_ if inplace else ops.slice
            y = slice_fn(t, start=1, stop=3, step=1, axis=0)

        # The inplace variant must produce SliceInplaceOp, the other SliceOp.
        if inplace:
            assert contains_op_of_type("SliceInplace", _ir.op.SliceInplaceOp,
                                       g)
        else:
            assert contains_op_of_type("Slice", _ir.op.SliceOp, g)
        # The variable plus the slice output.
        assert len(g.get_tensors()) == 2
        assert len(g.get_variables()) == 1