Example #1
def index_copy(g, self, dim, index, source):
    dim_value = sym_help._parse_arg(dim, "i")
    if sym_help.is_caffe2_aten_fallback():
        return g.at("index_copy", self, index, source, dim_i=dim_value)
    expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(
        g, self, dim, index)
    return scatter(g, self, dim, expanded_index, source)
Example #2
def scatter_add(g, self, dim, index, src):
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("scatter", self, dim, index, src, overload_name="src")

    src_type = src.type().scalarType()
    src_sizes = symbolic_helper._get_tensor_sizes(src)
    index_sizes = symbolic_helper._get_tensor_sizes(index)

    if src_sizes != index_sizes:
        return symbolic_helper._unimplemented(
            "scatter_add",
            f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})",
        )

    src = symbolic_helper._maybe_get_scalar(src)
    if symbolic_helper._is_value(src):
        return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add")
    else:
        # Check whether the scalar `src` has the same dtype as `self`. PyTorch allows
        # a different dtype for a scalar `src` (but not for a tensor `src`); if the
        # dtypes differ, insert a Cast node.
        if self.type().scalarType() != src_type:
            src = g.op(
                "Cast",
                src,
                to_i=symbolic_helper.cast_pytorch_to_onnx[self.type().scalarType()],
            )

        return g.op(
            "ScatterElements",
            self,
            index,
            src,
            axis_i=dim,
            reduction_s="add",
        )
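
For reference, a minimal eager-mode sketch of what `scatter_add` computes (shapes chosen only for illustration); the symbolic above lowers it to ONNX `ScatterElements` with `reduction="add"`, which presumably targets the newer ONNX opsets where that attribute exists:

import torch

x = torch.zeros(3, 4)
index = torch.tensor([[0, 1, 2, 0]])  # same number of dimensions as src
src = torch.ones(1, 4)

# Each src element is added into x at the row given by index along dim 0.
print(x.scatter_add(0, index, src))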
Example #3
def _dim_arange(g, like, dim):
    like_shape = g.op("Shape", like)
    stop = g.op(
        "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0
    )
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.op("_caffe2::Range", stop)
    return arange(g, stop, 4, None, None, None)
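
In eager terms this builds the equivalent of `torch.arange(like.shape[dim])`; the literal `4` passed to `arange` selects int64 as the dtype. A tiny illustration with an assumed concrete input:

import torch

like = torch.randn(2, 5, 3)
dim = 1

# Equivalent eager computation: an int64 range over the size of `dim`.
print(torch.arange(like.shape[dim]))  # tensor([0, 1, 2, 3, 4])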
Example #4
def index_fill(g, self, dim, index, value):
    dim_value = sym_help._parse_arg(dim, "i")
    if sym_help.is_caffe2_aten_fallback():
        return g.at("index_fill",
                    self,
                    index,
                    value,
                    overload_name="int_Scalar",
                    dim_i=dim_value)

    expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(
        g, self, dim, index)
    value = sym_help._maybe_get_scalar(value)
    value = sym_help._if_scalar_type_as(g, value, self)
    expanded_value = expand(g, value, expanded_index_shape, None)
    return scatter(g, self, dim, expanded_index, expanded_value)
Example #5
def index(g, self, index):
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("index", self, index, overload_name="Tensor")

    if symbolic_helper._is_packed_list(index):
        indices = symbolic_helper._unpack_list(index)
    else:
        indices = [index]

    # Handle single mask index.
    if len(indices) == 1:
        index = indices[0]
        if not symbolic_helper._is_none(index) and (
            index.type().scalarType() == "Bool" or index.type().scalarType() == "Byte"
        ):
            index = opset9.nonzero(g, index)
            return g.op("GatherND", self, index)
    return opset9.index(g, self, index)
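
When the only index is a Bool/Byte mask, the symbolic rewrites `self[mask]` as `NonZero` followed by `GatherND`; every other indexing pattern falls through to the opset 9 handler. A rough eager-mode picture of that equivalence, with made-up inputs:

import torch

x = torch.arange(12.).reshape(3, 4)
mask = x > 5

# Boolean-mask indexing ...
masked = x[mask]

# ... matches gathering at the nonzero coordinates of the mask, which is
# what the NonZero + GatherND pair expresses in ONNX.
coords = mask.nonzero()                  # shape (num_true, x.dim())
gathered = x[coords[:, 0], coords[:, 1]]
print(torch.equal(masked, gathered))     # True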
Example #6
def scatter(g, self, dim, index, src):
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("scatter", self, dim, index, src, overload_name="src")
    src_type = src.type().scalarType()
    src = symbolic_helper._maybe_get_scalar(src)
    if symbolic_helper._is_value(src):
        return g.op("ScatterElements", self, index, src, axis_i=dim)
    else:
        # Check whether the scalar `src` has the same dtype as `self`. PyTorch allows
        # a different dtype for a scalar `src` (but not for a tensor `src`); if the
        # dtypes differ, insert a Cast node.
        if self.type().scalarType() != src_type:
            src = g.op(
                "Cast",
                src,
                to_i=symbolic_helper.cast_pytorch_to_onnx[self.type().scalarType()],
            )
        return g.op(
            "ScatterElements", self, index, opset9.expand_as(g, src, index), axis_i=dim
        )
Example #7
def index(g, self, index):
    if sym_help.is_caffe2_aten_fallback():
        return g.at("index", self, index, overload_name="Tensor")

    if sym_help._is_packed_list(index):
        indices = sym_help._unpack_list(index)
    else:
        indices = [index]

    # Handle single mask index.
    if len(indices) == 1:
        index = indices[0]
        if not sym_help._is_none(index) and (
                index.type().scalarType() == "Bool"
                or index.type().scalarType() == "Byte"):
            from torch.onnx.symbolic_opset9 import nonzero
            index = nonzero(g, index)
            return g.op("GatherND", self, index)
    from torch.onnx.symbolic_opset9 import index as index_opset9
    return index_opset9(g, self, index)
Example #8
def gather(g, self, dim, index, sparse_grad=False):
    if symbolic_helper._maybe_get_const(sparse_grad, "i"):
        return symbolic_helper._unimplemented("gather", "sparse_grad == True")
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("gather", self, dim, index, sparse_grad)
    return g.op("GatherElements", self, index, axis_i=dim)
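
Outside the Caffe2 fallback, `gather` maps directly onto ONNX `GatherElements`; sparse gradients have no ONNX counterpart and are reported as unimplemented. For comparison, the eager-mode operation on small made-up tensors:

import torch

x = torch.tensor([[1., 2.], [3., 4.]])
index = torch.tensor([[0, 0], [1, 0]])

# out[i][j] = x[index[i][j]][j] for dim=0.
print(torch.gather(x, 0, index))  # tensor([[1., 2.], [3., 2.]])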
Example #9
def index_put(g, self, indices_list_value, values, accumulate=False):
    if symbolic_helper._is_packed_list(indices_list_value):
        indices_list = symbolic_helper._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if symbolic_helper.is_caffe2_aten_fallback():
        args = [self] + indices_list + [values, accumulate]
        return g.at("index_put", *args)

    accumulate = symbolic_helper._parse_arg(accumulate, "b")

    if len(indices_list) == 0:
        return values

    if len(indices_list) > 1:
        for idx_ in range(len(indices_list)):
            if indices_list[idx_].type().scalarType() == "Bool":
                indices_list[idx_] = g.op("NonZero", indices_list[idx_])
        index = indices_list[0]

        for ind in indices_list[1:]:
            index = opset9.add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            symbolic_helper._unsqueeze_helper(
                g, opset9.expand(g, ind, broadcast_index_shape, None), [-1]
            )
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace the index_put node with masked_scatter or masked_fill
        # when the index list contains a single boolean mask.
        #
        # index_put -> masked_fill
        #   * the index input is a single tensor of Bool type (e.g.: %24 <- %23).
        #   * the value input is a single element (e.g.: %18).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        #
        # index_put -> masked_scatter
        #   * the index input is a single tensor of Bool type (e.g.: %32 <- %31).
        #   * the value input contains multiple elements (e.g.: %28).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        index = indices_list[0]
        bool_inp = index
        if bool_inp.type() is not None and bool_inp.type().scalarType() == "Bool":
            rank = symbolic_helper._get_tensor_rank(values)
            if rank is not None and rank == 0:
                return opset9.masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = symbolic_helper._unsqueeze_helper(g, index, [-1])
    sub_data_shape = symbolic_helper._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize]
    )
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    # If values is a scalar (rank 0), expand it to the target shape before reshaping.
    rank = symbolic_helper._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = opset9.expand(g, values, values_shape, None)
    values = symbolic_helper._reshape_helper(g, values, values_shape)

    dtype = self.type().scalarType()
    if dtype is not None and dtype != values.type().scalarType():
        values = g.op("Cast", values, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype])
    dtype = symbolic_helper.scalar_type_to_onnx.index(
        symbolic_helper.cast_pytorch_to_onnx[dtype]
    )
    dtype = symbolic_helper.scalar_type_to_pytorch_type[dtype]

    if accumulate:
        zeros = g.op(
            "ConstantOfShape",
            g.op("Shape", self),
            value_t=torch.tensor([0], dtype=dtype),
        )
        result = g.op("ScatterND", zeros, index, values)
        result = opset9.add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
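
The comments above describe the single-boolean-mask special case: a scalar value turns the `index_put` into a `masked_fill`, a tensor value into a `masked_scatter`. Both paths correspond to the usual eager-mode idioms, roughly:

import torch

x = torch.zeros(2, 3)
mask = torch.tensor([[True, False, True], [False, True, False]])

# Scalar value -> handled via the masked_fill path above.
a = x.clone()
a[mask] = 1.0

# Tensor value (one entry per True position) -> handled via masked_scatter.
b = x.clone()
b[mask] = torch.tensor([10., 20., 30.])
print(a)
print(b)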
Example #10
def unfold(g, input, dimension, size, step):
    const_size = sym_help._maybe_get_const(size, "i")
    const_step = sym_help._maybe_get_const(step, "i")
    if not sym_help._is_value(const_size) and not sym_help._is_value(
            const_step):
        from torch.onnx.symbolic_opset9 import unfold as _unfold

        return _unfold(g, input, dimension, const_size, const_step)
    if sym_help.is_caffe2_aten_fallback():
        return g.at("unfold",
                    input,
                    dimension_i=dimension,
                    size_i=size,
                    step_i=step)

    sizedim = sym_help._get_tensor_dim_size(input, dimension)
    if sizedim is not None:
        low_start = g.op("Constant", value_t=torch.tensor(0))
        low_end = g.op("Constant", value_t=torch.tensor(sizedim))
        hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1))
        low_indices = g.op("Range", low_start, low_end, step)
        hi_indices = g.op("Range", size, hi_end, step)

        low_size = sym_help._size_helper(
            g, low_indices, g.op("Constant", value_t=torch.tensor(0)))
        hi_size = sym_help._size_helper(
            g, hi_indices, g.op("Constant", value_t=torch.tensor(0)))

        ndim = sym_help._get_tensor_rank(input)
        perm = list(range(0, ndim))
        perm.append(perm.pop(dimension))

        unsqueeze_list = []
        loop_condition = g.op("Constant", value_t=torch.tensor(1))
        loop_condition = g.op("Cast", loop_condition, to_i=9)  # 9 = ONNX BOOL
        loop_len = g.op("Min", low_size, hi_size)
        loop = g.op("Loop", loop_len, loop_condition)

        loop_block = _add_block(loop.node())
        block_input_iter = _add_input_to_block(loop_block)
        cond = _add_input_to_block(loop_block)

        starts = loop_block.op("Gather", low_indices, block_input_iter)
        ends = loop_block.op("Gather", hi_indices, block_input_iter)
        axes = loop_block.op("Constant", value_t=torch.tensor([2]))
        starts = sym_help._unsqueeze_helper(loop_block, starts, [0])
        ends = sym_help._unsqueeze_helper(loop_block, ends, [0])
        stack = loop_block.op("Slice", input, starts, ends, axes)

        unsqueeze = sym_help._unsqueeze_helper(
            loop_block, loop_block.op("Transpose", stack, perm_i=perm),
            [dimension])
        unsqueeze_list.append(unsqueeze)
        concat = loop_block.op("Concat", *unsqueeze_list, axis_i=0)

        cond_out = loop_block.op("Cast", loop_condition, to_i=9)
        _add_output_to_block(loop_block, cond_out)
        _add_output_to_block(loop_block, concat)

        loop_output = loop.node().output()
        perm = [0, 1, 2, 3, 4]
        perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0]
        transpose = g.op("Transpose", loop_output, perm_i=perm)
        squeeze = sym_help._squeeze_helper(g, transpose, [0])

        return squeeze
    else:
        return sym_help._unimplemented("Unfold", "input size not accessible")
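
For reference, eager `Tensor.unfold(dimension, size, step)` extracts sliding windows of length `size`, advancing by `step` along `dimension`, and appends a window axis; the Loop body above slices out one window per iteration and concatenates the results. A small eager example:

import torch

x = torch.arange(7.)

# Windows of length 3, moving 2 elements at a time.
print(x.unfold(0, 3, 2))
# tensor([[0., 1., 2.],
#         [2., 3., 4.],
#         [4., 5., 6.]])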