def __getitem__(self, key) -> 'Tensor':
    """
    Supports slicing and integer indexing.
    If a single index is selected during slicing the dimension will be
    squeezed - this matches numpy slicing rules.

    Examples:
    ```
    # Slicing
    x[0]        # Select all elements where i==0 for axis 0
    x[0,1]      # Select all elements where i==0, j==1 for axis 0 and 1
    x[0:2]      # Slice axis 0 between index 0 and 2
    x[:2,3:]    # Slice axis 0 up to 2 and axis 1 from index 3
    x[:,::-1]   # Select all elements for axis 0 and reverse axis 1
    x[-1]       # Select the last element of axis 0

    # Integer indexing
    indices = Tensor([[0,1], [1,0]], dtype='int32')
    x[indices]  # Select elements [0,1] and [1,0] from `x`
    ```

    Args:
        key: An int, a slice, a tuple made only of ints and slices, or a
            Tensor with an integer dtype.

    Returns:
        Tensor: For int/slice keys, the output of `ops.slice`, passed
            through `ops.squeeze` on every axis that was indexed with an
            int. For an integer Tensor key, the output of `ops.gather`.

    Raises:
        IndexError: If `key` is none of the supported kinds.
    """
    import popart.ir.ops as ops

    if isinstance(key, (slice, int)) or (isinstance(key, tuple) and all(
            isinstance(e, (slice, int)) for e in key)):
        # Basic slicing (integer or slices)
        key = (key, ) if isinstance(key, (slice, int)) else key

        start = []
        stop = []
        step = []
        int_slices = []  # Axes indexed by an int; squeezed afterwards.

        for i, key_i in enumerate(key):
            if isinstance(key_i, int):
                start.append(key_i)
                # `key_i + 1` is 0 when `key_i == -1`, and the range
                # [-1:0] with step 1 is empty. Leave the stop open instead,
                # exactly as an open-ended slice such as `3:` forwards its
                # `None` stop below.
                stop.append(None if key_i == -1 else key_i + 1)
                step.append(1)
                int_slices.append(i)
            elif isinstance(key_i, slice):
                # Open bounds stay `None` and are forwarded as-is.
                start.append(key_i.start)
                stop.append(key_i.stop)
                step.append(key_i.step)

        out = ops.slice(self, start, stop, step)

        if int_slices:
            out = ops.squeeze(out, axes=int_slices)

        return out

    elif (isinstance(key, Tensor) and key.dtype.is_int):
        # Integer indexing
        return ops.gather(self, key)

    else:
        raise IndexError(
            "Only integers, slices (`:`) and integer arrays are valid indices."
        )
def test_error_lengths(self, inplace):
    """Check that mismatched `start`/`stop`/`axis` lengths raise `ValueError`."""
    ir = pir.Ir()
    with ir.main_graph():
        source = pir.variable(data)
        slice_fn = ops.slice_ if inplace else ops.slice
        # `start` has one entry while `stop` and `axis` each have two.
        with pytest.raises(ValueError):
            slice_fn(source, start=[2], stop=[3, 4], axis=[2, 1])
def test_identity_numerically(self, inplace):
    """Check that slicing with only `axis` given yields the input data."""
    ir = pir.Ir()
    with ir.main_graph():
        source = pir.variable(data)
        slice_fn = ops.slice_ if inplace else ops.slice
        sliced = slice_fn(source, axis=0)  # `axis=0` is redundant
        host_result = run_ir(ir, sliced)

    assert_array_equal(host_result, data)
def test_identity_fn(self, inplace):
    """Check the main graph holds one tensor and one variable after an identity slice."""
    ir = pir.Ir()
    with ir.main_graph():
        source = pir.variable(data)
        slice_fn = ops.slice_ if inplace else ops.slice
        sliced = slice_fn(source, axis=0)  # `axis=0` is redundant

    tensor_count = len(ir.main_graph().get_tensors())
    assert tensor_count == 1
    variable_count = len(ir.main_graph().get_variables())
    assert variable_count == 1
def test_start_only(self, inplace):
    """Check that passing only `start=1` matches `data[1:]`."""
    ir = pir.Ir()
    with ir.main_graph():
        source = pir.variable(data)
        slice_fn = ops.slice_ if inplace else ops.slice
        sliced = slice_fn(source, start=1)
        host_result = run_ir(ir, sliced)

    expected = data[1:]
    assert_array_equal(host_result, expected)
def test_axis(self, inplace):
    """Check that `axis=[2, 1]` applies each start/stop pair to the named axis."""
    ir = pir.Ir()
    with ir.main_graph():
        source = pir.variable(data)
        slice_fn = ops.slice_ if inplace else ops.slice
        sliced = slice_fn(source, start=[1, 2], stop=[3, 4], axis=[2, 1])
        host_result = run_ir(ir, sliced)

    # Axis 2 gets 1:3 and axis 1 gets 2:4; axis 0 is untouched.
    expected = data[:, 2:4, 1:3]
    assert_array_equal(host_result, expected)
def test_negative_start(self, inplace):
    """Check that `start=-2, step=-1` matches `data[-2::-1]`."""
    ir = pir.Ir()
    with ir.main_graph():
        source = pir.variable(data)
        slice_fn = ops.slice_ if inplace else ops.slice
        sliced = slice_fn(source, start=-2, step=-1)
        host_result = run_ir(ir, sliced)

    expected = data[-2::-1]
    assert_array_equal(host_result, expected)
def test_step(self, inplace):
    """Check a forward step on axis 0 combined with a reverse step on axis 1."""
    ir = pir.Ir()
    with ir.main_graph():
        source = pir.variable(data)
        slice_fn = ops.slice_ if inplace else ops.slice
        sliced = slice_fn(source, start=[1, 3], stop=[3, 1], step=[1, -1])
        host_result = run_ir(ir, sliced)

    expected = data[1:3, 3:1:-1]
    assert_array_equal(host_result, expected)
def test_stop_only_multidim(self, inplace):
    """Check that passing only `stop=[2, 3]` matches `data[:2, :3]`."""
    ir = pir.Ir()
    with ir.main_graph():
        source = pir.variable(data)
        slice_fn = ops.slice_ if inplace else ops.slice
        sliced = slice_fn(source, stop=[2, 3])
        host_result = run_ir(ir, sliced)

    expected = data[:2, :3]
    assert_array_equal(host_result, expected)
def test_fn_numerically(self, inplace):
    """Check that a fully specified slice on axis 0 matches `data[1:3]`."""
    ir = pir.Ir()
    graph = ir.main_graph()

    with graph:
        source = pir.variable(data)
        slice_fn = ops.slice_ if inplace else ops.slice
        sliced = slice_fn(source, start=1, stop=3, step=1, axis=0)
        host_result = run_ir(ir, sliced)

    expected = data[1:3]
    assert_array_equal(host_result, expected)
def test_fn(self, inplace):
    """Check the op type added to the graph and the resulting tensor counts."""
    ir = pir.Ir()
    graph = ir.main_graph()

    with graph:
        source = pir.variable(data)
        slice_fn = ops.slice_ if inplace else ops.slice
        sliced = slice_fn(source, start=1, stop=3, step=1, axis=0)

    # Only the selected branch is evaluated, as in an if/else statement.
    op_name, op_type = (("SliceInplace", _ir.op.SliceInplaceOp)
                        if inplace else ("Slice", _ir.op.SliceOp))
    assert contains_op_of_type(op_name, op_type, graph)

    tensor_count = len(graph.get_tensors())
    assert tensor_count == 2
    variable_count = len(graph.get_variables())
    assert variable_count == 1