예제 #1
0
def binary_cross_entropy_with_logits(g, input, target, weight, pos_weight,
                                     reduction):
    """Export ``binary_cross_entropy_with_logits`` as elementary ONNX ops.

    Computes ``-(target * log(sigmoid(input)) [* pos_weight]
    + (1 - target) * log(1 - sigmoid(input)))``, optionally scaled by
    ``weight`` and reduced according to ``reduction``.

    Args:
        g: ONNX graph context used to emit nodes.
        input: logits Value.
        target: target Value.
        weight: optional per-element weight (Python None or a None Value).
        pos_weight: optional positive-class weight (Python None or a None Value).
        reduction: 0 = none, 1 = mean, 2 = sum (constant-folded via
            ``sym_help._maybe_get_const``).

    Returns:
        The loss Value. For mean/sum this is a scalar (rank 0), matching
        PyTorch; for any other reduction, the result of
        ``sym_help._onnx_unsupported``.
    """
    from torch.onnx.symbolic_opset9 import sigmoid, log, sub, neg, mul, add
    p = g.op("Constant", value_t=torch.tensor([1]))
    sig_x = sigmoid(g, input)
    log_sig_x = log(g, sig_x)
    sub_1_x = sub(g, p, sig_x)
    sub_1_y = sub(g, p, target)
    log_1_x = log(g, sub_1_x)
    if pos_weight is None or sym_help._is_none(pos_weight):
        output = neg(
            g, add(g, mul(g, target, log_sig_x), mul(g, sub_1_y, log_1_x)))
    else:
        output = neg(
            g,
            add(g, mul(g, mul(g, target, log_sig_x), pos_weight),
                mul(g, sub_1_y, log_1_x)))

    if weight is not None and not sym_help._is_none(weight):
        output = mul(g, weight, output)

    reduction = sym_help._maybe_get_const(reduction, 'i')
    if reduction == 0:
        return output
    elif reduction == 1:
        # keepdims=0: ONNX's default (keepdims=1) would keep every reduced
        # axis as size 1 instead of producing PyTorch's 0-d scalar loss.
        return g.op("ReduceMean", output, keepdims_i=0)
    elif reduction == 2:
        return g.op("ReduceSum", output, keepdims_i=0)
    else:
        return sym_help._onnx_unsupported(
            "binary_cross_entropy_with_logits with reduction other than none, mean, or sum"
        )
예제 #2
0
def index_put(g, self, indices_list_value, values, accumulate=False):
    """Export ``aten::index_put`` (``self[indices] = values``) via ScatterND.

    Args:
        g: ONNX graph context used to emit nodes.
        self: tensor being written into.
        indices_list_value: packed list Value of index tensors.
        values: tensor to write; a 0-d value is broadcast to the indexed
            region before the reshape.
        accumulate: when true, values are added to the existing entries
            instead of overwriting them.

    Returns:
        The updated tensor Value (or an ATen op under ATen fallback).
    """
    indices_list = sym_help._unpack_list(indices_list_value)
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    index = indices_list[0]

    if len(indices_list) > 1:
        # Broadcast all index tensors to a common shape (the shape of their
        # sum), then stack them on a new trailing axis for ScatterND.
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            g.op("Unsqueeze",
                 expand(g, ind, broadcast_index_shape, None),
                 axes_i=[-1]) for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        broadcast_index_shape = g.op("Shape", index)
        index = g.op("Unsqueeze", index, axes_i=[-1])
    # Dimensions of `self` that are not indexed are copied whole.
    sub_data_shape = sym_help._slice_helper(g,
                                            g.op("Shape", self),
                                            axes=[0],
                                            starts=[len(indices_list)],
                                            ends=[maxsize])
    values_shape = g.op("Concat",
                        broadcast_index_shape,
                        sub_data_shape,
                        axis_i=0)
    # A 0-d value (e.g. `x[idx] = 1.0`) holds one element and cannot be
    # reshaped to a multi-element shape; broadcast it first.
    rank = sym_help._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = expand(g, values, values_shape, None)
    values = g.op("Reshape", values, values_shape)

    if accumulate:
        # ConstantOfShape needs a torch dtype for its fill value; derive it
        # from self's scalar type via the ONNX type tables.
        dtype = self.type().scalarType()
        dtype = sym_help.scalar_type_to_onnx.index(
            sym_help.cast_pytorch_to_onnx[dtype])
        dtype = sym_help.scalar_type_to_pytorch_type[dtype]
        zeros = g.op("ConstantOfShape",
                     g.op("Shape", self),
                     value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
예제 #3
0
    def add(g, x, y, op_scale, op_zero_point):
        """Quantized add: dequantize both operands, add in float, requantize.

        Only the dequantized tensor from each ``dequantize_helper`` result is
        used. The result is quantized with ``op_scale`` / ``op_zero_point``.

        NOTE(review): the inner ``add`` call presumably resolves to the
        module-level float ``add`` symbolic (true if this is defined in a
        class body) — confirm it is not a recursive call to this method.
        """
        lhs, _, _, _ = sym_help.dequantize_helper(g, x)
        rhs, _, _, _ = sym_help.dequantize_helper(g, y)
        return sym_help.quantize_helper(
            g, add(g, lhs, rhs), op_scale, op_zero_point
        )
예제 #4
0
def group_norm_symbolic(g, input, num_groups, weight, bias, eps,
                        cudnn_enabled):
    """Export GroupNorm to ONNX.

    When every group holds exactly one channel, the op is emitted as
    ``InstanceNormalization``. Otherwise the input is regrouped, normalized
    with ``MeanVarianceNormalization`` and then scaled/shifted per channel.

    NOTE: ``eps`` is forwarded only to InstanceNormalization; the
    MeanVarianceNormalization path does not receive it. ``cudnn_enabled`` is
    unused. The reshape targets assume a 4-D input ``[n, c, h, w]`` — TODO
    confirm callers never pass other ranks.
    """
    from torch.onnx.symbolic_opset9 import reshape, mul, add, reshape_as

    # Requires a complete static type on `input` to read the channel count.
    channels_num = input.type().sizes()[1]

    if num_groups == channels_num:
        return g.op('InstanceNormalization',
                    input,
                    weight,
                    bias,
                    epsilon_f=eps)

    def per_channel(param):
        # Lay a 1-D affine parameter out as [1, C, 1, 1] so it broadcasts.
        return reshape(g, param, [1, channels_num, 1, 1])

    # [n, g * cg, h, w] -> [n, g, cg * h, w] -> [1, n * g, cg * h, w]
    grouped = reshape(g, input, [0, num_groups, -1, 0])
    grouped = reshape(g, grouped, [1, -1, 0, 0])
    normalized = g.op('MeanVarianceNormalization', grouped, axes_i=[2, 3])
    restored = reshape_as(g, normalized, input)
    scaled = mul(g, restored, per_channel(weight))
    return add(g, scaled, per_channel(bias))
예제 #5
0
def normal(g, loc, scale, seed):
    """Export ``normal(loc, scale)`` via a location-scale transform.

    If ``x`` is drawn from a distribution with mean 0 and variance 1, then
    ``scale * x + loc`` has mean ``loc`` and variance ``scale ** 2``. The
    standard sample is generated with ``RandomNormalLike``, so it takes
    ``loc``'s shape. ``seed`` is not used.
    """
    standard_sample = g.op("RandomNormalLike", loc)
    scaled_sample = mul(g, scale, standard_sample)
    return add(g, scaled_sample, loc)
예제 #6
0
def full(g, sizes, value, dtype, layout, device, pin_memory=False):
    """Export ``aten::full``: a tensor of shape ``sizes`` filled with ``value``.

    A constant fill value goes through ``_constant_fill``. A dynamic fill
    value is added (``alpha=1``) onto a zeros tensor of the requested shape.
    ``pin_memory`` is unused.
    """
    const_value = symbolic_helper._maybe_get_const(value, "t")
    if not symbolic_helper._is_value(const_value):
        dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        return _constant_fill(g, sizes, dtype, const_value)

    zeros_tensor = zeros(g, sizes, dtype, layout, device)
    alpha = g.op("Constant", value_t=torch.tensor(1))
    return opset9.add(g, zeros_tensor, value, alpha)
예제 #7
0
    def add(g, x, y, op_scale, op_zero_point):
        """Quantized add: dequantize both inputs, float-add them, requantize.

        The float addition is delegated to ``opset9.add``. The result is
        quantized with ``op_scale`` / ``op_zero_point``.
        """
        lhs, _, _, _ = symbolic_helper.dequantize_helper(g, x)
        rhs, _, _, _ = symbolic_helper.dequantize_helper(g, y)
        float_sum = opset9.add(g, lhs, rhs)
        return symbolic_helper.quantize_helper(
            g, float_sum, op_scale, op_zero_point
        )
예제 #8
0
def full(g, sizes, value, dtype, layout, device, pin_memory=False):
    """Export ``aten::full`` (fill a ``sizes``-shaped tensor with ``value``).

    Constant fill values use ``_constant_fill``. A dynamic ``value`` is added
    (with ``alpha=1``) onto a zeros tensor. ``pin_memory`` is ignored.
    """
    const_value = sym_help._maybe_get_const(value, 't')
    if not sym_help._is_value(const_value):
        dtype = sym_help._get_const(dtype, 'i', 'dtype')
        return _constant_fill(g, sizes, dtype, const_value)

    zeros_tensor = zeros(g, sizes, dtype, layout, device)
    alpha = g.op("Constant", value_t=torch.tensor(1))
    return sym_opset9.add(g, zeros_tensor, value, alpha)
예제 #9
0
def binary_cross_entropy_with_logits(g, input, target, weight, pos_weight,
                                     reduction):
    """Export ``binary_cross_entropy_with_logits`` using opset-9 building blocks.

    loss = -(target * log(sigmoid(input)) [* pos_weight]
             + (1 - target) * log(1 - sigmoid(input))) [* weight]

    ``reduction`` 0/1/2 selects none/mean/sum. Mean and sum collapse to a
    scalar (keepdims=0). Any other value is reported through
    ``symbolic_helper._onnx_unsupported``.
    """
    one = g.op("Constant", value_t=torch.tensor([1]))
    probs = opset9.sigmoid(g, input)
    log_probs = opset9.log(g, probs)
    one_minus_probs = opset9.sub(g, one, probs)
    one_minus_target = opset9.sub(g, one, target)
    log_one_minus_probs = opset9.log(g, one_minus_probs)

    positive_term = opset9.mul(g, target, log_probs)
    if not (pos_weight is None or symbolic_helper._is_none(pos_weight)):
        positive_term = opset9.mul(g, positive_term, pos_weight)
    negative_term = opset9.mul(g, one_minus_target, log_one_minus_probs)
    loss = opset9.neg(g, opset9.add(g, positive_term, negative_term))

    has_weight = weight is not None and not symbolic_helper._is_none(weight)
    if has_weight:
        loss = opset9.mul(g, weight, loss)

    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    if reduction == 0:
        return loss
    if reduction == 1:
        return g.op("ReduceMean", loss, keepdims_i=0)
    if reduction == 2:
        return g.op("ReduceSum", loss, keepdims_i=0)
    return symbolic_helper._onnx_unsupported(
        "binary_cross_entropy_with_logits with reduction other than none, mean, or sum",
        input,
    )
예제 #10
0
def add(g, self, other, alpha=None):
    """Export ``aten::add``, with support for concatenating tensor lists.

    If ``self`` is a tensor-list Value, each tensor of the statically built
    list ``other`` is appended with ``SequenceInsert``. Otherwise the call is
    forwarded to ``opset9.add``.
    """
    is_sequence = symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self)
    if not is_sequence:
        return opset9.add(g, self, other, alpha)

    # Only a prim::ListConstruct can be unpacked at export time.
    if other.node().kind() != "prim::ListConstruct":
        return symbolic_helper._unimplemented(
            "add", "does not support adding dynamic tensor list to another"
        )

    sequence = self
    for tensor in symbolic_helper._unpack_list(other):
        sequence = g.op("SequenceInsert", sequence, tensor)
    return sequence
예제 #11
0
def addcmul_symbolic(g, self, tensor1, tensor2, value=1, out=None):
    """Export ``addcmul``: ``self + value * tensor1 * tensor2``.

    Args:
        g: ONNX graph context used to emit nodes.
        self, tensor1, tensor2: input Values.
        value: scalar multiplier; the extra ``Mul`` is skipped when it equals 1.
        out: unsupported; when given, export is reported as unimplemented.

    Returns:
        The result Value, or the result of ``sym_help._unimplemented`` when
        ``out`` is provided.
    """
    from torch.onnx.symbolic_opset9 import add, mul

    if out is not None:
        # Bail out instead of silently exporting a graph that ignores `out`.
        return sym_help._unimplemented(
            "addcmul", "Out parameter is not supported for addcmul")

    x = mul(g, tensor1, tensor2)
    value = sym_help._maybe_get_scalar(value)
    if sym_help._scalar(value) != 1:
        value = sym_help._if_scalar_type_as(g, value, x)
        # Python numbers still need to be materialized as a graph constant.
        if not sym_help._is_value(value):
            value = g.op("Constant",
                         value_t=torch.tensor(value, dtype=torch.float32))
        x = mul(g, x, value)
    return add(g, self, x)
def index_put(g, self, indices_list_value, values, accumulate=False):
    """Export ``aten::index_put`` (``self[indices] = values``) to ONNX.

    Single boolean masks are lowered to ``masked_fill`` (scalar values) or
    ``masked_scatter``. All other indexing goes through ``ScatterND``.

    Args:
        g: ONNX graph context used to emit nodes.
        self: tensor being written into.
        indices_list_value: packed list of index tensors, or a single index.
        values: tensor to write; 0-d values are broadcast to the indexed region.
        accumulate: when true, values are added to the existing entries.

    Returns:
        The updated tensor Value.
    """
    if sym_help._is_packed_list(indices_list_value):
        indices_list = sym_help._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    if len(indices_list) == 0:
        return values

    index = indices_list[0]

    if len(indices_list) > 1:
        # Broadcast all index tensors to a common shape, then stack them on a
        # trailing axis so each row is one ScatterND coordinate.
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            sym_help._unsqueeze_helper(
                g, expand(g, ind, broadcast_index_shape, None), [-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace index_put node with masked_scatter or masked_fill
        # when inputs to the index_put node contains boolean inputs
        #
        # index_put -> masked_fill
        #   * input index contains single tensor of Bool type (e.g.: %24 <- %23).
        #   * input value contains single element (e.g.: %18).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        #
        # index_put -> masked_scatter
        #   * input index contains single tensor of Bool type (e.g.: %32 <- %31).
        #   * input value contains multiple elements (e.g.: %28).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        bool_inp = index
        if bool_inp.type() is not None and bool_inp.type().scalarType(
        ) == 'Bool':
            rank = sym_help._get_tensor_rank(values)
            if rank is not None and rank == 0:
                from torch.onnx.symbolic_opset9 import masked_fill
                return masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = sym_help._unsqueeze_helper(g, index, [-1])
    # Trailing dimensions of `self` that are not indexed are copied whole.
    sub_data_shape = sym_help._slice_helper(g,
                                            g.op("Shape", self),
                                            axes=[0],
                                            starts=[len(indices_list)],
                                            ends=[maxsize])
    values_shape = g.op("Concat",
                        broadcast_index_shape,
                        sub_data_shape,
                        axis_i=0)
    # Check if values is a singular value and expand accordingly
    rank = sym_help._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = expand(g, values, values_shape, None)
    values = g.op("Reshape", values, values_shape)

    dtype = self.type().scalarType()
    if dtype is not None and dtype != values.type().scalarType():
        values = g.op("Cast",
                      values,
                      to_i=sym_help.cast_pytorch_to_onnx[dtype])

    if accumulate:
        # The torch dtype is only needed for the zeros fill value. Converting it
        # unconditionally raised KeyError for untyped `self` even when nothing
        # needed it.
        onnx_type_index = sym_help.scalar_type_to_onnx.index(
            sym_help.cast_pytorch_to_onnx[dtype])
        torch_dtype = sym_help.scalar_type_to_pytorch_type[onnx_type_index]
        zeros = g.op("ConstantOfShape",
                     g.op("Shape", self),
                     value_t=torch.tensor([0], dtype=torch_dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
예제 #13
0
def index_put(g, self, indices_list_value, values, accumulate=False):
    """Export ``aten::index_put`` (``self[indices] = values``) to ONNX.

    A single boolean mask is lowered to ``masked_fill`` (0-d values) or
    ``masked_scatter``. Everything else is lowered to ``ScatterND``.

    Args:
        g: ONNX graph context used to emit nodes.
        self: tensor being written into.
        indices_list_value: packed list Value of index tensors.
        values: tensor to write; 0-d values are broadcast before the reshape.
        accumulate: when true, values are added to the existing entries.

    Returns:
        The updated tensor Value (or an ATen op under ATen fallback).
    """
    indices_list = sym_help._unpack_list(indices_list_value)
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    index = indices_list[0]

    if len(indices_list) > 1:
        # Broadcast the index tensors to a common shape and stack them on a
        # trailing axis so each row is one ScatterND coordinate.
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            sym_help._unsqueeze_helper(
                g, expand(g, ind, broadcast_index_shape, None), [-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace index_put node with masked_scatter or masked_fill
        # when inputs to the index_put node contains boolean inputs
        #
        # index_put -> masked_fill
        #
        # before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %6 : None = prim::Constant()
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %8 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::ne(%mask, %some_const)
        #   %26 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]()
        #   %27 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %11 : Device = prim::Constant[value="cpu"]()
        #   %12 : None = prim::Constant()
        #   %28 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %29 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %15 : None = prim::Constant()
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %30 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %22 : int[] = prim::Constant[value=[-1]]()
        #   %23 : Tensor = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %3 : Tensor = onnx::Equal(%0, %some_const)
        #   %4 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%3)
        #   %12 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%4)
        #   %19 : Tensor = onnx::Cast[to=9](%12)
        #   %20 : Tensor = onnx::Constant[value={1}]()
        #   %21 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = onnx::Where(%19, %20, %0)
        #   return (%21)
        #
        # index_put -> masked_scatter
        #
        # before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %6 : None = prim::Constant()
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %34 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]()
        #   %35 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %18 : Device = prim::Constant[value="cpu"]()
        #   %19 : None = prim::Constant()
        #   %36 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %37 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %22 : None = prim::Constant()
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Tensor = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        #
        # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %3 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #               = onnx::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %4 : Tensor = onnx::Equal(%0, %some_const)
        #   %5 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%4)
        #   %13 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%5)
        #   %19 : Tensor = onnx::Shape(%0)
        #   %20 : Tensor = onnx::Expand(%13, %19)
        #   %21 : Tensor = onnx::NonZero(%20)
        #   %22 : Tensor = onnx::Transpose[perm=[1, 0]](%21)
        #   %23 : Tensor = onnx::Constant[value={-1}]()
        #   %24 : Tensor = onnx::Reshape(%3, %23)
        #   %25 : Tensor = onnx::Shape(%22)
        #   %27 : Tensor = onnx::Constant[value={0}]()
        #   %28 : Tensor = onnx::Gather[axis=0](%25, %27)
        #   %29 : Tensor = onnx::Constant[value={0}]()
        #   %30 : Tensor = onnx::Unsqueeze[axes=[0]](%29)
        #   %31 : Tensor = onnx::Unsqueeze[axes=[0]](%28)
        #   %32 : Tensor = onnx::Constant[value={0}]()
        #   %33 : Tensor = onnx::Unsqueeze[axes=[0]](%32)
        #   %34 : Tensor = onnx::Slice(%24, %30, %31, %33)
        #   %35 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = onnx::ScatterND(%0, %22, %34)
        #   return (%35)

        # As in the IR above, the mask reaches index_put through an
        # aten::view whose output carries no scalar type, so the producer's
        # first input is inspected. Producers without inputs (graph inputs,
        # constants) used to raise IndexError here; inspect the index
        # itself in that case.
        # NOTE(review): any producer whose first input is Bool (e.g.
        # `mask.long()`) is also routed to the mask path — confirm intended.
        producer_inputs = list(index.node().inputs())
        bool_inp = producer_inputs[0] if producer_inputs else index
        if bool_inp.type() is not None and bool_inp.type().scalarType(
        ) == 'Bool':
            rank = sym_help._get_tensor_rank(values)
            if rank is not None and rank == 0:
                from torch.onnx.symbolic_opset9 import masked_fill
                return masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = sym_help._unsqueeze_helper(g, index, [-1])
    # Trailing dimensions of `self` that are not indexed are copied whole.
    sub_data_shape = sym_help._slice_helper(g,
                                            g.op("Shape", self),
                                            axes=[0],
                                            starts=[len(indices_list)],
                                            ends=[maxsize])
    values_shape = g.op("Concat",
                        broadcast_index_shape,
                        sub_data_shape,
                        axis_i=0)
    # A 0-d value holds a single element and cannot be reshaped to a
    # multi-element shape; broadcast it over the indexed region first.
    rank = sym_help._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = expand(g, values, values_shape, None)
    values = g.op("Reshape", values, values_shape)

    if accumulate:
        # ConstantOfShape needs a torch dtype for its zero fill value.
        dtype = self.type().scalarType()
        dtype = sym_help.scalar_type_to_onnx.index(
            sym_help.cast_pytorch_to_onnx[dtype])
        dtype = sym_help.scalar_type_to_pytorch_type[dtype]
        zeros = g.op("ConstantOfShape",
                     g.op("Shape", self),
                     value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
예제 #14
0
def multiclass_nms_core_symbolic(g,
                                 multi_bboxes,
                                 multi_scores,
                                 score_thr,
                                 nms_cfg,
                                 max_num=-1):
    """Emit ONNX ops for multi-class NMS built on ``NonMaxSuppression``.

    Runs ONNX NonMaxSuppression and gathers the selected boxes and scores.
    It appends a class-index column and keeps at most ``max_num`` rows,
    ordered by score via TopK.

    Args:
        g: ONNX graph context used to emit nodes.
        multi_bboxes: box tensor. It is reshaped to ``[0, -1, 4]`` and
            transposed with ``perm=[1, 0, 2]`` — presumably
            ``(num_boxes, num_classes * 4)``; TODO confirm against caller.
        multi_scores: score tensor. It is transposed with ``perm=[1, 0]`` and
            reshaped to ``[dim0(boxes), -1, dim1(boxes)]`` — presumably
            ``(num_boxes, num_classes)``; TODO confirm.
        score_thr: Python float; becomes NonMaxSuppression's score threshold.
        nms_cfg: dict. ``nms_cfg.get('type', 'nms')`` must be ``'nms'``, and
            ``nms_cfg['iou_thr']`` must lie in ``[0, 1]``.
        max_num: Python int cap on the number of returned boxes. It is also
            used as NonMaxSuppression's max_output_boxes_per_class. It must be
            > 0: the default ``-1`` fails the ``assert`` below.

    Returns:
        Value with 6 columns per row: the 4 box values, the score, and the
        class index cast to float.

    NOTE(review): validation uses ``assert``, which is stripped under
    ``python -O``.
    """

    from torch.onnx.symbolic_opset9 import reshape, squeeze
    from torch.onnx.symbolic_opset10 import _slice

    def cast(x, dtype):
        # `dtype` is a PyTorch scalar type name such as 'Long' or 'Float'.
        return g.op('Cast', x, to_i=sym_help.cast_pytorch_to_onnx[dtype])

    def get_size(x, dim):
        # Returns x.shape[dim] as a 1-element Long tensor (Shape + Slice).
        shape = g.op('Shape', x)
        dim = _slice(g, shape, axes=[0], starts=[dim], ends=[dim + 1])
        return cast(dim, 'Long')

    nms_op_type = nms_cfg.get('type', 'nms')
    assert nms_op_type == 'nms'
    assert 'iou_thr' in nms_cfg
    iou_threshold = nms_cfg['iou_thr']
    assert 0 <= iou_threshold <= 1

    # Transpose and reshape input tensors to fit ONNX NonMaxSuppression.
    # After this, dim 0 of multi_bboxes is used as NonMaxSuppression's batch
    # axis. If the input is (num_boxes, num_classes * 4), that axis is the
    # class axis, so each class is suppressed independently — presumably the
    # intent; confirm.
    multi_bboxes = reshape(g, multi_bboxes, [0, -1, 4])
    multi_bboxes = g.op('Transpose', multi_bboxes, perm_i=[1, 0, 2])

    batches_num = get_size(multi_bboxes, 0)
    spatial_num = get_size(multi_bboxes, 1)

    # Scores become [batches_num, -1, spatial_num] to match the box layout.
    multi_scores = g.op('Transpose', multi_scores, perm_i=[1, 0])
    scores_shape = g.op('Concat',
                        batches_num,
                        g.op('Constant', value_t=torch.LongTensor([-1])),
                        spatial_num,
                        axis_i=0)
    multi_scores = reshape(g, multi_scores, scores_shape)
    classes_num = get_size(multi_scores, 1)

    assert max_num > 0

    indices = g.op(
        'NonMaxSuppression', multi_bboxes, multi_scores,
        g.op('Constant', value_t=torch.LongTensor([max_num])),
        g.op('Constant', value_t=torch.FloatTensor([iou_threshold])),
        g.op('Constant', value_t=torch.FloatTensor([score_thr])))

    # Flatten bboxes and scores.
    multi_bboxes_flat = reshape(g, multi_bboxes, [-1, 4])
    multi_scores_flat = reshape(g, multi_scores, [
        -1,
    ])

    # Flatten indices.
    # Columns 0/1/2 of the NonMaxSuppression output are sliced out as batch,
    # class and box indices respectively.
    batch_indices = _slice(g, indices, axes=[1], starts=[0], ends=[1])
    class_indices = _slice(g, indices, axes=[1], starts=[1], ends=[2])
    box_indices = _slice(g, indices, axes=[1], starts=[2], ends=[3])

    # Local Add/Mul helpers that cast the result (to Long by default).
    # These shadow any module-level `add`/`mul` inside this function.
    def add(*args, dtype='Long'):
        x = g.op('Add', args[0], args[1])
        if dtype is not None:
            x = cast(x, dtype)
        return x

    def mul(*args, dtype='Long'):
        x = g.op('Mul', args[0], args[1])
        if dtype is not None:
            x = cast(x, dtype)
        return x

    # Row-major flat offsets into the flattened box and score tensors:
    #   box:   batch * spatial_num + box
    #   score: (batch * classes_num + class) * spatial_num + box
    flat_box_indices = add(mul(batch_indices, spatial_num), box_indices)
    flat_score_indices = add(
        mul(add(mul(batch_indices, classes_num), class_indices), spatial_num),
        box_indices)

    # Select bboxes.
    out_bboxes = reshape(
        g, g.op('Gather', multi_bboxes_flat, flat_box_indices, axis_i=0),
        [-1, 4])
    out_scores = reshape(
        g, g.op('Gather', multi_scores_flat, flat_score_indices, axis_i=0),
        [-1, 1])
    # Having either batch size or number of classes here equal to one is the limitation of implementation.
    # Summing the two indices yields the label only when one of them is
    # always 0, which is why the limitation above exists.
    class_indices = reshape(g, cast(add(class_indices, batch_indices),
                                    'Float'), [-1, 1])

    # Combine bboxes, scores and labels into a single tensor.
    # This a workaround for a PyTorch bug (feature?),
    # limiting ONNX operations to output only single tensor.
    out_combined_bboxes = g.op('Concat',
                               out_bboxes,
                               out_scores,
                               class_indices,
                               axis_i=1)

    # Get the top scored bboxes only.
    # k = min(max_num, number of selected rows).
    elements_num = sym_help._size_helper(g,
                                         out_scores,
                                         dim=g.op('Constant',
                                                  value_t=torch.LongTensor(
                                                      [0])))
    max_num = g.op('Constant', value_t=torch.LongTensor([max_num]))
    # Before opset 12 the minimum is taken with Concat + ReduceMin instead of
    # Min — presumably because Min did not accept int64 inputs until opset 12.
    if sym_help._export_onnx_opset_version < 12:
        kn = g.op('Concat', max_num, elements_num, axis_i=0)
        kn = g.op('ReduceMin', kn, keepdims_i=0)
    else:
        kn = g.op('Min', max_num, elements_num)
    _, top_indices = sym_help._topk_helper(g, out_scores, kn, dim=0)
    # top_indices = squeeze(g, top_indices, dim=1)
    top_indices = reshape(g, top_indices, [
        -1,
    ])
    out_combined_bboxes = g.op('Gather',
                               out_combined_bboxes,
                               top_indices,
                               axis_i=0)

    return out_combined_bboxes
예제 #15
0
def index_put(g, self, indices_list_value, values, accumulate=False):
    """Export ``aten::index_put`` (``self[indices] = values``) to ONNX.

    A single boolean mask is lowered to ``masked_fill`` (0-d values) or
    ``masked_scatter``. Other indexing, including several indices where
    boolean ones are first converted with ``NonZero``, goes through
    ``ScatterND``.

    Args:
        g: ONNX graph context used to emit nodes.
        self: tensor being written into.
        indices_list_value: packed list of index tensors, or a single index.
        values: tensor to write; 0-d values are broadcast to the indexed region.
        accumulate: when true, values are added to the existing entries.

    Returns:
        The updated tensor Value (or an ATen op under Caffe2 ATen fallback).
    """
    if symbolic_helper._is_packed_list(indices_list_value):
        indices_list = symbolic_helper._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if symbolic_helper.is_caffe2_aten_fallback():
        args = [self] + indices_list + [values, accumulate]
        return g.at("index_put", *args)

    accumulate = symbolic_helper._parse_arg(accumulate, "b")

    if len(indices_list) == 0:
        return values

    if len(indices_list) > 1:
        # Boolean indices are converted to coordinates with NonZero.
        # TODO(justinchuby): Remove type ignore after #81112 is checked in.
        indices_list = [
            g.op("NonZero", ind)
            if ind.type().scalarType() == "Bool"  # type: ignore[attr-defined]
            else ind
            for ind in indices_list
        ]
        index = indices_list[0]

        # Broadcast all index tensors to a common shape, then stack them on a
        # trailing axis so each row is one ScatterND coordinate.
        for ind in indices_list[1:]:
            index = opset9.add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            symbolic_helper._unsqueeze_helper(
                g, opset9.expand(g, ind, broadcast_index_shape, None), [-1]
            )
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace index_put node with masked_scatter or masked_fill
        # when inputs to the index_put node contains a single boolean input.
        #
        # index_put -> masked_fill
        #   * input index contains single tensor of Bool type (e.g.: %24 <- %23).
        #   * input value contains single element (e.g.: %18).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        #
        # index_put -> masked_scatter
        #   * input index contains single tensor of Bool type (e.g.: %32 <- %31).
        #   * input value contains multiple elements (e.g.: %28).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        index = indices_list[0]
        bool_inp = index
        if bool_inp.type() is not None and bool_inp.type().scalarType() == "Bool":  # type: ignore[attr-defined]
            # TODO(justinchuby): Remove type ignore after #81112 is checked in.
            rank = symbolic_helper._get_tensor_rank(values)
            if rank is not None and rank == 0:
                return opset9.masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = symbolic_helper._unsqueeze_helper(g, index, [-1])
    # Trailing dimensions of `self` that are not indexed are copied whole.
    sub_data_shape = symbolic_helper._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize]
    )
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    # Check if values is a singular value and expand accordingly
    rank = symbolic_helper._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = opset9.expand(g, values, values_shape, None)
    values = symbolic_helper._reshape_helper(g, values, values_shape)

    dtype = self.type().scalarType()
    if dtype is not None and dtype != values.type().scalarType():
        values = g.op("Cast", values, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype])

    if accumulate:
        # The torch dtype is only needed for the zeros fill value. Converting it
        # unconditionally raised KeyError for an untyped `self` even when
        # nothing needed it.
        onnx_type_index = symbolic_helper.scalar_type_to_onnx.index(
            symbolic_helper.cast_pytorch_to_onnx[dtype]
        )
        torch_dtype = symbolic_helper.scalar_type_to_pytorch_type[onnx_type_index]
        zeros = g.op(
            "ConstantOfShape",
            g.op("Shape", self),
            value_t=torch.tensor([0], dtype=torch_dtype),
        )
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result