Ejemplo n.º 1
0
    def __call__(self, x, index):
        """Gather entries of ``x`` along axis ``self.dim`` using ``index``.

        Axis ``dim`` of both tensors is swapped to the front. Every entry of
        ``index`` is then turned into a flat offset into ``x``
        (``index * stride + position-within-slice``). The values are picked
        with ``paddle.index_select`` and the axis swap is undone. For
        ``dim == 0`` this yields ``out[i, j, ...] == x[index[i, j, ...], j, ...]``.

        Args:
            x: Tensor to gather from.
            index: Integer tensor of positions along ``self.dim``. It must
                have the same rank as ``x`` and, on every other axis, an
                extent no larger than ``x``'s. Positions are taken from the
                leading corner of ``x`` that ``index`` covers.

        Returns:
            Tensor with the shape of ``index`` holding the gathered values.

        ``self.dim`` may be negative. It is normalised per call and never
        written back to the instance.
        """
        # Normalise into a local. Writing it back to ``self.dim`` would make
        # later calls with a different-rank input use a stale axis.
        dim = self.dim
        if dim < 0:
            dim += len(x.shape)

        # Swapping axes 0 and ``dim`` is its own inverse, so the same
        # permutation moves ``dim`` to the front and later back again.
        perm = list(range(len(x.shape)))
        perm[0], perm[dim] = dim, 0
        x_swapped = paddle.transpose(x, perm=perm)

        index_perm = list(range(len(index.shape)))
        index_perm[0], index_perm[dim] = dim, 0
        index_swapped = paddle.transpose(index, perm=index_perm)
        dtype = index.dtype

        x_shape = paddle.shape(x_swapped)
        index_shape = paddle.shape(index_swapped)

        # Flat stride of the leading (gather) axis, i.e. the element count of
        # one slice. Casting before the product forms it in the index dtype
        # instead of the shape dtype. ``//`` on same-dtype operands keeps it
        # integral; ``/`` on integer tensors is true division in Paddle 2.x
        # and would route the stride through float32, which is inexact for
        # large slices.
        numel = paddle.prod(paddle.cast(x_shape, dtype=dtype))
        stride = numel // paddle.cast(x_shape[0], dtype=dtype)

        x_flat = paddle.flatten(x_swapped)
        # Out-of-place arithmetic: ``flatten`` may share storage with its
        # input, and the caller's ``index`` must not be modified.
        index_flat = paddle.flatten(index_swapped) * stride

        # Offset of every element inside one leading-axis slice, cropped to
        # the corner that ``index`` covers and repeated once per row of it.
        bias = paddle.arange(start=0, end=stride, dtype=dtype)
        bias = paddle.reshape(bias, x_shape[1:])
        bias = paddle.crop(bias, index_shape[1:])
        bias = paddle.flatten(bias)
        bias = paddle.tile(bias, [index_shape[0]])
        index_flat = index_flat + bias

        gathered = paddle.index_select(x_flat, index_flat)
        gathered = paddle.reshape(gathered, index_swapped.shape)

        # Undo the axis swap (the permutation is an involution).
        return paddle.transpose(gathered, perm=perm)
Ejemplo n.º 2
0
def gather_op(x, dim, index):
    """Gather entries of ``x`` along axis ``dim`` using ``index``.

    Axis ``dim`` of both tensors is swapped to the front. Every entry of
    ``index`` is then turned into a flat offset into ``x``
    (``index * stride + position-within-slice``). The values are picked with
    ``paddle.index_select`` and the axis swap is undone. For ``dim == 0``
    this yields ``out[i, j, ...] == x[index[i, j, ...], j, ...]``.

    Args:
        x: Tensor to gather from.
        dim: Axis to gather along; negative values count from the end.
        index: int32 or int64 tensor of positions along ``dim``. It must
            have the same rank as ``x`` and, on every other axis, an extent
            no larger than ``x``'s. Positions are taken from the leading
            corner of ``x`` that ``index`` covers.

    Returns:
        Tensor with the shape of ``index`` holding the gathered values.

    Raises:
        KeyError: if ``str(index.dtype)`` is not one of the int32/int64
            spellings listed in ``dtype_mapping``.
    """
    # Maps the printed form of ``index.dtype`` (legacy ``VarType.*`` or
    # ``paddle.*`` spelling) to the plain dtype name used by the calls below.
    dtype_mapping = {
        "VarType.INT32": "int32",
        "VarType.INT64": "int64",
        "paddle.int32": "int32",
        "paddle.int64": "int64"
    }
    if dim < 0:
        dim += len(x.shape)

    # Swapping axes 0 and ``dim`` is its own inverse, so the same
    # permutation moves ``dim`` to the front and later back again.
    perm = list(range(len(x.shape)))
    perm[0], perm[dim] = dim, 0
    x_swapped = paddle.transpose(x, perm=perm)

    index_perm = list(range(len(index.shape)))
    index_perm[0], index_perm[dim] = dim, 0
    index_swapped = paddle.transpose(index, perm=index_perm)

    dtype = dtype_mapping[str(index.dtype)]
    x_shape = paddle.shape(x_swapped)
    index_shape = paddle.shape(index_swapped)

    # Flat stride of the leading (gather) axis, i.e. the element count of
    # one slice. ``//`` on same-dtype operands keeps it integral; ``/`` on
    # integer tensors is true division in Paddle 2.x and would route the
    # stride through float32, which is inexact for large slices.
    numel = paddle.prod(x_shape, dtype=dtype)
    stride = numel // paddle.cast(x_shape[0], dtype=dtype)

    x_flat = paddle.flatten(x_swapped)
    # Out-of-place arithmetic: ``flatten`` may share storage with its input,
    # and the caller's ``index`` must not be modified.
    index_flat = paddle.flatten(index_swapped) * stride

    # Offset of every element inside one leading-axis slice, cropped to the
    # corner that ``index`` covers and repeated once per row of it.
    bias = paddle.arange(start=0, end=stride, dtype=dtype)
    bias = paddle.reshape(bias, x_shape[1:])
    bias = paddle.crop(bias, index_shape[1:])
    bias = paddle.flatten(bias)
    bias = paddle.tile(bias, [index_shape[0]])

    index_flat = index_flat + bias

    gathered = paddle.index_select(x_flat, index_flat)
    gathered = paddle.reshape(gathered, index_swapped.shape)

    # Undo the axis swap (the permutation is an involution).
    return paddle.transpose(gathered, perm=perm)
Ejemplo n.º 3
0
 def attr_offsets_value():
     # Error-case callable. NOTE(review): presumably handed to an
     # assertRaises-style check by the enclosing test, which also supplies
     # `input1` and `offset` (both outside this view) — confirm.
     # The -1 entry is the offset value under test.
     bad_offsets = [0, -1, offset, 0]
     paddle.crop(input1, shape=[2, 2, 3, 3], offsets=bad_offsets)
Ejemplo n.º 4
0
 def input_dtype():
     # Error-case callable: crops `input2`, presumably a tensor whose dtype
     # paddle.crop does not accept. `input2` comes from the enclosing scope
     # outside this view — confirm.
     target_shape = [2, 2, 3, 3]
     paddle.crop(input2, shape=target_shape)
Ejemplo n.º 5
0
 def attr_offsets_dtype():
     # Error-case callable: one offset is the float 1.0 rather than an int,
     # presumably to check that paddle.crop rejects non-int offsets.
     # `input1` comes from the enclosing scope outside this view.
     float_offsets = [0, 1.0, 0, 0]
     paddle.crop(input1, shape=[2, 2, 3, 3], offsets=float_offsets)
Ejemplo n.º 6
0
 def attr_offsets_type():
     # Error-case callable: `offsets` is a bare int instead of a sequence,
     # presumably to check that paddle.crop rejects it. `input1` comes from
     # the enclosing scope outside this view.
     scalar_offsets = 0
     paddle.crop(input1, shape=[2, 2, 3, 3], offsets=scalar_offsets)
Ejemplo n.º 7
0
 def attr_shape_value2():
     # Error-case callable: the requested shape contains a zero extent,
     # presumably the invalid value under test. `input1` and `dim` come from
     # the enclosing scope outside this view.
     shape_with_zero = [2, 0, dim, 3]
     paddle.crop(input1, shape=shape_with_zero)
Ejemplo n.º 8
0
 def attr_shape_dtype():
     # Error-case callable: one shape entry is the float 2.0 rather than an
     # int, presumably to check that paddle.crop rejects it. `input1` comes
     # from the enclosing scope outside this view.
     float_shape = [2, 2.0, 3, 3]
     paddle.crop(input1, shape=float_shape)
Ejemplo n.º 9
0
 def attr_shape_type():
     # Error-case callable: `shape` is a bare int instead of a sequence or
     # tensor, presumably to check that paddle.crop rejects it. `input1`
     # comes from the enclosing scope outside this view.
     scalar_shape = 3
     paddle.crop(input1, shape=scalar_shape)