Example #1
 def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio,
              pool_mode, aligned):
     from torch.onnx.symbolic_opset9 import sub, squeeze
     from torch.onnx.symbolic_helper import _slice_helper
     from torch.onnx import TensorProtoDataType
     # batch_indices = rois[:, 0].long()
     batch_indices = _slice_helper(g, rois, axes=[1], starts=[0], ends=[1])
     batch_indices = squeeze(g, batch_indices, 1)
     batch_indices = g.op('Cast',
                          batch_indices,
                          to_i=TensorProtoDataType.INT64)
     # rois = rois[:, 1:]
     rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5])
     if aligned:
         # rois -= 0.5/spatial_scale
         aligned_offset = g.op('Constant',
                               value_t=torch.tensor([0.5 / spatial_scale],
                                                    dtype=torch.float32))
         rois = sub(g, rois, aligned_offset)
     # roi align
     return g.op('RoiAlign',
                 input,
                 rois,
                 batch_indices,
                 output_height_i=output_size[0],
                 output_width_i=output_size[1],
                 spatial_scale_f=spatial_scale,
                 sampling_ratio_i=max(0, sampling_ratio),
                 mode_s=pool_mode)
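For orientation, the graph ops above trace the same tensor manipulation as this plain-PyTorch sketch (the helper name prepare_rois is hypothetical, added here only for illustration):

import torch

def prepare_rois(rois, spatial_scale, aligned):
    # rois: (N, 5) tensor of [batch_idx, x1, y1, x2, y2]
    batch_indices = rois[:, 0].long()        # Slice + Squeeze + Cast
    boxes = rois[:, 1:5]                     # Slice
    if aligned:
        boxes = boxes - 0.5 / spatial_scale  # Sub with a Constant offset
    return batch_indices, boxes

rois = torch.tensor([[0., 4., 4., 12., 12.],
                     [1., 0., 0., 8., 8.]])
idx, boxes = prepare_rois(rois, spatial_scale=0.5, aligned=True)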
Example #2
def roll(g, input, shifts, dims):
    assert len(shifts) == len(dims)
    result = input
    for i in range(len(shifts)):
        shapes = []
        shape = _slice_helper(g, result, axes=[dims[i]], starts=[-shifts[i]], ends=[maxsize])
        shapes.append(shape)
        shape = _slice_helper(g, result, axes=[dims[i]], starts=[0], ends=[-shifts[i]])
        shapes.append(shape)
        result = g.op("Concat", *shapes, axis_i=dims[i])
    return result
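As a sanity check, the Slice/Slice/Concat decomposition can be mirrored eagerly; a minimal sketch, assuming positive shifts smaller than the axis size:

import torch

def roll_eager(x, shifts, dims):
    for shift, dim in zip(shifts, dims):
        # The last `shift` elements along `dim` move to the front,
        # exactly like the two slices concatenated above.
        front = x.narrow(dim, x.size(dim) - shift, shift)
        back = x.narrow(dim, 0, x.size(dim) - shift)
        x = torch.cat([front, back], dim=dim)
    return x

x = torch.arange(6).reshape(2, 3)
assert torch.equal(roll_eager(x, [1], [1]), torch.roll(x, 1, 1))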
Example #3
    def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio,
                 pool_mode, aligned):
        has_custom_op = False
        try:
            import os.path as osp

            from mmcv.ops import get_onnxruntime_op_path
            ort_op_path = get_onnxruntime_op_path()
            has_custom_op = osp.exists(ort_op_path)
        except ImportError:
            pass
        if has_custom_op:
            return g.op(
                'mmcv::MMCVRoiAlign',
                input,
                rois,
                aligned_height_i=output_size[0],
                aligned_width_i=output_size[1],
                spatial_scale_f=spatial_scale,
                sampling_ratio_i=max(0, sampling_ratio),
                pool_mode_s=pool_mode,
                aligned_i=aligned)

        from torch.onnx.symbolic_opset9 import sub, squeeze
        from torch.onnx.symbolic_helper import _slice_helper
        from torch.onnx import TensorProtoDataType
        # batch_indices = rois[:, 0].long()
        batch_indices = _slice_helper(g, rois, axes=[1], starts=[0], ends=[1])
        batch_indices = squeeze(g, batch_indices, 1)
        batch_indices = g.op(
            'Cast', batch_indices, to_i=TensorProtoDataType.INT64)
        # rois = rois[:, 1:]
        rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5])
        if aligned:
            # rois -= 0.5/spatial_scale
            aligned_offset = g.op(
                'Constant',
                value_t=torch.tensor([0.5 / spatial_scale],
                                     dtype=torch.float32))
            rois = sub(g, rois, aligned_offset)
        # roi align
        return g.op(
            'RoiAlign',
            input,
            rois,
            batch_indices,
            output_height_i=output_size[0],
            output_width_i=output_size[1],
            spatial_scale_f=spatial_scale,
            sampling_ratio_i=max(0, sampling_ratio),
            mode_s=pool_mode)
Example #4
    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = sym_help._get_interpolate_attributes(g, interpolate_mode, args)
        align_corners = sym_help._maybe_get_scalar(align_corners)
        coordinate_transformation_mode = "asymmetric" if interpolate_mode == "nearest" \
            else "align_corners" if align_corners else "pytorch_half_pixel"
        empty_tensor = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))

        if scales is None:
            input_size = g.op("Shape", input)
            input_size_beg = sym_help._slice_helper(g, input_size, axes=[0], ends=[2], starts=[0])
            output_size = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Long"])
            output_size = g.op("Concat", input_size_beg, output_size, axis_i=0)
            scales = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
            return g.op("Resize",
                        input,
                        empty_tensor,  # roi only takes effect with coordinate_transformation_mode="tf_crop_and_resize"
                        scales,  # scales is not needed since we are sending out_size
                        output_size,
                        coordinate_transformation_mode_s=coordinate_transformation_mode,
                        cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                        mode_s=interpolate_mode,  # nearest, linear, or cubic
                        nearest_mode_s="floor")  # only valid when mode="nearest"
        else:
            return g.op("Resize",
                        input,
                        empty_tensor,  # roi only takes effect with coordinate_transformation_mode="tf_crop_and_resize"
                        scales,  # scales is not needed since we are sending out_size
                        coordinate_transformation_mode_s=coordinate_transformation_mode,
                        cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                        mode_s=interpolate_mode,  # nearest, linear, or cubic
                        nearest_mode_s="floor")  # only valid when mode="nearest"
Example #5
    def symbolic_fn(g, input, output_size, align_corners=None):
        if align_corners:
            return _unimplemented(name, "align_corners == True")

        output_size = sym_help._maybe_get_const(output_size, 'is')
        if sym_help._is_value(output_size):
            offset = 2
            offsets = g.op("Constant",
                           value_t=torch.tensor([1. for i in range(offset)]))
            dividend = g.op("Cast",
                            output_size,
                            to_i=sym_help.cast_pytorch_to_onnx["Float"])
            divisor = sym_help._slice_helper(g,
                                             g.op("Shape", input),
                                             axes=[0],
                                             ends=[dim],
                                             starts=[offset])
            divisor = g.op("Cast",
                           divisor,
                           to_i=sym_help.cast_pytorch_to_onnx["Float"])
            scale_dims = g.op("Div", dividend, divisor)
            scales = g.op("Concat", offsets, scale_dims, axis_i=0)
        else:
            scales_constant = [
                1. if i < 2 else float(output_size[-(dim - i)]) /
                float(input.type().sizes()[-(dim - i)]) for i in range(0, dim)
            ]
            scales = g.op("Constant", value_t=torch.tensor(scales_constant))
        return g.op("Resize", input, scales, mode_s=interpolate_mode)
Example #6
def slice(g, self, *args):
    if len(args) == 4:
        # aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor
        dim, start, end, step = args
    elif len(args) == 3:
        # aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[]
        start, end, step = args
        dim = 0
    else:
        raise NotImplementedError("Unknown aten::slice signature")
    is_start_none = start.node().kind() == "prim::Constant" and start.type().kind() == "NoneType"
    is_end_none = end.node().kind() == "prim::Constant" and end.type().kind() == "NoneType"
    is_start_onnx_const = start.node().kind() == "onnx::Constant"
    is_end_onnx_const = end.node().kind() == "onnx::Constant"
    step = sym_help._parse_arg(step, "i")
    if (not is_start_none and not is_start_onnx_const) or \
       (not isinstance(end, int) and not is_end_none and not is_end_onnx_const) or \
       (not isinstance(dim, int) and dim.node().kind() != "onnx::Constant"):
        dynamic_slice = True
        if is_start_none:
            start = g.op("Constant", value_t=torch.tensor(0))
        if is_end_none:
            end = g.op("Constant", value_t=torch.tensor(9223372036854775807))
    else:
        start = [0 if is_start_none else sym_help._parse_arg(start, "i")]
        end = [9223372036854775807 if is_end_none else sym_help._parse_arg(end, "i")]
        dim = [sym_help._parse_arg(dim, "i")]
        dynamic_slice = False
    return sym_help._slice_helper(g, self, axes=dim, starts=start, ends=end, steps=[step], dynamic_slice=dynamic_slice)
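A note on the magic number: 9223372036854775807 is INT64_MAX, and ONNX Slice clamps out-of-range ends to the axis length, so it simply means "slice to the end". A one-line check (assuming a 64-bit CPython build, where sys.maxsize is also INT64_MAX):

import sys

assert 9223372036854775807 == 2**63 - 1 == sys.maxsize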
Example #7
 def symbolic_fn(g, input, output_size, align_corners=None):
     align_corners = sym_help._maybe_get_scalar(align_corners)
     coordinate_transformation_mode = "asymmetric" if interpolate_mode == "nearest" \
         else "align_corners" if align_corners else "pytorch_half_pixel"
     empty_tensor = g.op("Constant",
                         value_t=torch.tensor([], dtype=torch.float32))
     input_size = g.op("Shape", input)
     input_size_beg = sym_help._slice_helper(g,
                                             input_size,
                                             axes=[0],
                                             ends=[2],
                                             starts=[0])
     output_size = g.op("Cast",
                        output_size,
                        to_i=sym_help.cast_pytorch_to_onnx["Long"])
     output_size = g.op("Concat", input_size_beg, output_size, axis_i=0)
     if interpolate_mode == "bilinear":
         return g.op("Plugin",
                     input,
                     name_s="Bilinear",
                     info_s=json.dumps({"scale_factor": [2, 2]}))
     else:
         return g.op("Upsample",
                     input,
                     scales_f=[1, 2, 2],
                     mode_s=interpolate_mode)
Example #8
    def symbolic_fn(g, input, output_size, align_corners=None):
        align_corners = sym_help._maybe_get_scalar(align_corners)
        coordinate_transformation_mode = "asymmetric" if interpolate_mode == "nearest" \
            else "align_corners" if align_corners else "pytorch_half_pixel"
        empty_tensor = g.op("Constant",
                            value_t=torch.tensor([], dtype=torch.float32))
        input_size = g.op("Shape", input)
        input_size_beg = sym_help._slice_helper(g,
                                                input_size,
                                                axes=[0],
                                                ends=[2],
                                                starts=[0])
        output_size = g.op("Cast",
                           output_size,
                           to_i=sym_help.cast_pytorch_to_onnx["Long"])
        output_size = g.op("Concat", input_size_beg, output_size, axis_i=0)

        # ************************************ upsample change
        if interpolate_mode == "bilinear":
            # https://github.com/dlunion/tensorRTIntegrate/blob/7543614a4ff5f6e022526a4e137d47f4e63f5f96/plugin_onnx_export.py#L33
            return g.op("Plugin",
                        input,
                        name_s="Bilinear",
                        info_s=json.dumps({"scale_factor": [2, 2]}))
        else:
            return g.op("Upsample",
                        input,
                        scales_f=[1, 2, 2],
                        mode_s=interpolate_mode)
Example #9
def flip(g, input, dims):
    return sym_help._slice_helper(g,
                                  input,
                                  axes=dims,
                                  starts=[-1] * len(dims),
                                  ends=[-9223372036854775807] * len(dims),
                                  steps=[-1] * len(dims))
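The negative-step Slice walks every requested axis from index -1 backwards, which is exactly a flip; an eager sketch of the equivalence for one axis:

import torch

x = torch.arange(6).reshape(2, 3)
# Reversed gather over axis 0, the eager analogue of steps=[-1]
rev = x.index_select(0, torch.arange(x.size(0) - 1, -1, -1))
assert torch.equal(rev, torch.flip(x, dims=[0]))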
Example #10
def slice(g, self, *args):
    if len(args) == 4:
        # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor
        dim, start, end, step = args
    elif len(args) == 3:
        # aten::slice(t[] l, int start, int end, int step) -> t[]
        start, end, step = args
        dim = 0
    else:
        raise NotImplementedError("Unknown aten::slice signature")

    step = sym_help._parse_arg(step, 'i')
    if (start.node().kind() != 'onnx::Constant' or
        (not isinstance(end, int) and end.node().kind() != 'onnx::Constant') or
        (not isinstance(dim, int) and dim.node().kind() != 'onnx::Constant')):
        dynamic_slice = True
    else:
        start = [sym_help._parse_arg(start, 'i')]
        end = [sym_help._parse_arg(end, 'i')]
        dim = [sym_help._parse_arg(dim, 'i')]
        dynamic_slice = False
    return sym_help._slice_helper(g,
                                  self,
                                  axes=dim,
                                  starts=start,
                                  ends=end,
                                  steps=[step],
                                  dynamic_slice=dynamic_slice)
Example #11
def __interpolate(g, input, size, scale_factor, mode, align_corners, recompute_scale_factor):
    mode = sym_help._maybe_get_const(mode, 's')
    if 'linear' in mode:
        mode = 'linear'
    if 'cubic' in mode:
        mode = 'cubic'
    align_corners = sym_help._maybe_get_const(align_corners, 'b')
    align_corners = False if not isinstance(align_corners, bool) else align_corners
    coordinate_transformation_mode = "asymmetric" if mode == "nearest" \
        else "align_corners" if align_corners else "pytorch_half_pixel"
    # roi only takes effect with coordinate_transformation_mode="tf_crop_and_resize"
    roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))

    if not sym_help._is_none(size):
        input_size = g.op("Shape", input)
        input_size = sym_help._slice_helper(g, input_size, axes=[0], ends=[2], starts=[0])
        # In some cases size is not a packed list but a scalar. We would also need
        # to verify that (sym_help._maybe_get_const(size, 't').dim() == 0), but this
        # information is not always available. Try to get the dim, and if that
        # fails, assume size is not a scalar.
        try:
            is_scalar = not sym_help._is_packed_list(size) and ((sym_help._maybe_get_const(size, 't').dim() == 0))
        except AttributeError:
            is_scalar = not sym_help._is_packed_list(size)
            if not is_scalar:
                warnings.warn("Cannot verify if the output_size is a scalar "
                              "while exporting interpolate. Assuming that it is not a scalar.")

        if is_scalar:
            if not input.type().dim():
                return sym_help._unimplemented("interpolate (with a scalar output_size)",
                                               "missing input shape (try giving an array of output_size values)")
            size = unsqueeze(g, size, 0)
            size = [size for i in range(input.type().dim() - 2)]
            size = g.op("Concat", *size, axis_i=0)
        size = g.op("Cast", size, to_i=sym_help.cast_pytorch_to_onnx['Long'])
        size = g.op("Concat", input_size, size, axis_i=0)
        scales = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
        return g.op("Resize",
                    input,
                    roi,
                    scales,
                    size,
                    coordinate_transformation_mode_s=coordinate_transformation_mode,
                    cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                    mode_s=mode,  # nearest, linear, or cubic
                    nearest_mode_s="floor")
    else:  # if not sym_help._is_none(scales)
        if not input.type().dim():
            return sym_help._unimplemented("interpolate (with scales)", "missing input shape")
        scales = sym_help._interpolate_get_scales(g, scale_factor, input.type().dim())
        return g.op("Resize",
                    input,
                    roi,
                    scales,
                    coordinate_transformation_mode_s=coordinate_transformation_mode,
                    cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                    mode_s=mode,  # nearest, linear, or cubic
                    nearest_mode_s="floor")  # only valid when mode="nearest"
Example #12
def tensordot(g, input_a, input_b, dims_a, dims_b, out=None):
    if out is not None:
        _unimplemented("Tensordot", "Out parameter is not supported for tensordot.")

    dim_count_a = sym_help._get_tensor_rank(input_a)
    if dim_count_a is None:
        raise RuntimeError("Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.")

    dim_count_b = sym_help._get_tensor_rank(input_b)
    if dim_count_b is None:
        raise RuntimeError("Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.")

    dims_a = [(dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i] for i in range(len(dims_a))]
    dims_b = [(dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i] for i in range(len(dims_b))]

    left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)]
    left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)]

    new_input_a = permute(g, input_a, left_dims_a + dims_a)
    new_input_b = permute(g, input_b, dims_b + left_dims_b)

    input_shape = g.op("Shape", new_input_a)
    left_sizes_a = sym_help._slice_helper(g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)])
    shape_sizes = [left_sizes_a, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long))]
    output_a = _reshape_from_tensor(g, new_input_a, shape_sizes)

    input_shape = g.op("Shape", output_a)
    slices = sym_help._slice_helper(g, input_shape, axes=[0], starts=[-1], ends=[maxsize])
    shape_sizes = [g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), slices]
    output_a = _reshape_from_tensor(g, new_input_a, shape_sizes)

    input_shape = g.op("Shape", new_input_b)
    left_sizes_b = sym_help._slice_helper(g, input_shape, axes=[0], starts=[len(dims_b)], ends=[maxsize])
    slices = sym_help._slice_helper(g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)])
    shape_sizes = [slices, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long))]
    output_b = _reshape_from_tensor(g, new_input_b, shape_sizes)

    input_shape = g.op("Shape", output_b)
    slices = sym_help._slice_helper(g, input_shape, axes=[0], starts=[-1], ends=[maxsize])
    shape_sizes = [g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), slices]
    output_b = _reshape_from_tensor(g, new_input_b, shape_sizes)

    output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b]))

    shape_sizes = [left_sizes_a, left_sizes_b]
    return _reshape_from_tensor(g, output, shape_sizes)
Example #13
def narrow(g, input, dim, start, length):
    from torch.onnx.symbolic_helper import _slice_helper
    end = g.op("Add", start, length)
    return _slice_helper(g,
                         input,
                         axes=dim,
                         starts=start,
                         ends=end,
                         dynamic_slice=True)
Example #14
def slice(g, self, dim, start, end, step):
    if (start.node().kind() != 'onnx::Constant' or
       end.node().kind() != 'onnx::Constant' or dim.node().kind() != 'onnx::Constant'):
        dynamic_slice = True
    else:
        start = [sym_help._parse_arg(start, 'i')]
        end = [sym_help._parse_arg(end, 'i')]
        dim = [sym_help._parse_arg(dim, 'i')]
        dynamic_slice = False
    return sym_help._slice_helper(g, self, axes=dim, starts=start, ends=end, steps=[step], dynamic_slice=dynamic_slice)
Example #15
def masked_scatter(g, self, mask, source):
    from torch.onnx.symbolic_opset9 import nonzero, expand_as, view, size
    index = nonzero(g, expand_as(g, mask, self))
    # NOTE: source can have more elements than needed.
    # It could also have arbitrary shape.
    # This is not supported by ONNX::ScatterND, so we need to flatten and slice the source tensor.
    source = view(g, source, torch.LongTensor([-1]))
    source = sym_help._slice_helper(g, source,
                                    axes=torch.LongTensor([0]),
                                    starts=torch.LongTensor([0]),
                                    ends=size(g, index, torch.LongTensor([0])),
                                    dynamic_slice=True)
    return g.op('ScatterND', self, index, source)
Example #16
 def symbolic(self, g, rois, *feats):
     rois = sym_help._slice_helper(g, rois, axes=[1], starts=[1], ends=[5])
     roi_feats, _ = g.op(
         'ExperimentalDetectronROIFeatureExtractor',
         rois,
         *feats,
         output_size_i=self.inner.roi_layers[0].out_size[0],
         pyramid_scales_i=self.inner.featmap_strides,
         sampling_ratio_i=self.inner.roi_layers[0].sample_num,
         image_id_i=0,
         distribute_rois_between_levels_i=1,
         preserve_rois_order_i=1,
         outputs=2)
     return roi_feats
Example #17
def index_put(g, self, indices_list_value, values, accumulate=False):
    indices_list = sym_help._unpack_list(indices_list_value)
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    index = indices_list[0]

    if len(indices_list) > 1:
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            g.op("Unsqueeze",
                 expand(g, ind, broadcast_index_shape, None),
                 axes_i=[-1]) for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        broadcast_index_shape = g.op("Shape", index)
        index = g.op("Unsqueeze", index, axes_i=[-1])
    sub_data_shape = sym_help._slice_helper(g,
                                            g.op("Shape", self),
                                            axes=[0],
                                            starts=[len(indices_list)],
                                            ends=[maxsize])
    values_shape = g.op("Concat",
                        broadcast_index_shape,
                        sub_data_shape,
                        axis_i=0)
    values = g.op("Reshape", values, values_shape)

    if accumulate:
        dtype = self.type().scalarType()
        dtype = sym_help.scalar_type_to_onnx.index(
            sym_help.cast_pytorch_to_onnx[dtype])
        dtype = sym_help.scalar_type_to_pytorch_type[dtype]
        zeros = g.op("ConstantOfShape",
                     g.op("Shape", self),
                     value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
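For reference, the eager semantics that the ScatterND decomposition reproduces; a minimal sketch of the accumulate flag:

import torch

x = torch.zeros(4)
idx = (torch.tensor([1, 1]),)
vals = torch.tensor([2., 3.])
# accumulate=True adds into the target, so both values land on index 1
assert torch.equal(x.index_put(idx, vals, accumulate=True),
                   torch.tensor([0., 5., 0., 0.]))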
Example #18
    def symbolic_fn(g, input, kernel_size, stride, padding, dilation,
                    ceil_mode):
        if not stride:
            stride = kernel_size
        kwargs = {
            "kernel_shape_i": tuple_fn(kernel_size),
            "pads_i": tuple_fn(padding) * 2,
            "strides_i": tuple_fn(stride),
            "ceil_mode_i": ceil_mode,
        }
        if set(tuple_fn(dilation)) != {1}:
            kwargs["dilations_i"] = tuple_fn(dilation)
        # Easy but hacky way to get flattened indices values
        # to be used to convert the indices values to non-flattened.
        # In ONNX the indices are computed as a flattened 1-D tensor,
        # so the values in indices are in [0, N x C x D1 x ... x Dn).
        # To convert the indices to the same format used by PyTorch,
        # we first execute a maxpool with a kernel and stride of 1 on the same input.
        # This will result in a tensor of indices in which each index has its own value.
        # Using this tensor as a reference, we extract the first index of each axis and subtract
        # it from each index of this axis in the indices to convert.
        # This step will result in a tensor where each dimension has values of indices within
        # the dimension it is in.
        # For more information:
        # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
        if return_indices:
            r, indices = g.op("MaxPool", input, outputs=2, **kwargs)
            _, flattened_indices = g.op(
                "MaxPool",
                input,
                outputs=2,
                kernel_shape_i=[1 for _ in range(ndims)],
                strides_i=[1 for _ in range(ndims)],
            )
            # convert indices to have non-flattened indices values
            from torch.onnx.symbolic_opset9 import sub

            s = sym_help._slice_helper(
                g,
                flattened_indices,
                axes=[2 + i for i in range(ndims)],
                starts=tuple_fn(0),
                ends=tuple_fn(1),
            )
            indices = sub(g, indices, s)
            return r, indices
        else:
            r = g.op("MaxPool", input, outputs=1, **kwargs)
            return r
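For contrast with the flattened ONNX indices described in the comment above, eager PyTorch reports pooling indices relative to each (batch, channel) slice, which is the format the subtraction recovers; a small 1-D sketch:

import torch
import torch.nn.functional as F

x = torch.tensor([[[1., 3., 2., 4.]]])
_, idx = F.max_pool1d(x, kernel_size=2, return_indices=True)
# Indices are positions within the last dimension, not global offsets
assert torch.equal(idx, torch.tensor([[[1, 3]]]))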
Example #19
def masked_scatter(g, self, mask, source):
    index = opset9.nonzero(g, opset9.expand_as(g, mask, self))
    # NOTE: source can have more elements than needed.
    # It could also have arbitrary shape.
    # This is not supported by ONNX::ScatterND, so we need to flatten and slice the source tensor.
    source = symbolic_helper._reshape_helper(g, source, torch.LongTensor([-1]))
    source = symbolic_helper._slice_helper(
        g,
        source,
        axes=torch.LongTensor([0]),
        starts=torch.LongTensor([0]),
        ends=opset9.size(g, index, torch.LongTensor([0])),
        dynamic_slice=True,
    )
    return g.op("ScatterND", self, index, source)
Example #20
def sort(g, self, dim, descending, out=None):
    if out is not None:
        _unimplemented("Sort", "Out parameter is not supported for sort")

    # TODO: add descending to ONNX TopK so ascending sort is supported
    if not descending:
        _unimplemented("Sort", "Cannot sort in ascending order")

    shape_ = g.op("Shape", self)
    axis = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
    start = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int64))
    end = g.op("Constant", value_t=torch.tensor(dim + 1, dtype=torch.int64))
    slice_ = sym_help._slice_helper(g,
                                    shape_,
                                    axes=axis,
                                    starts=start,
                                    ends=end,
                                    steps=None,
                                    dynamic_slice=True)
    return g.op("TopK", self, slice_, axis_i=dim, outputs=2)
Example #21
def roi_feature_extractor_symbolics(g,
                                    rois,
                                    *feats,
                                    output_size=1,
                                    featmap_strides=1,
                                    sample_num=1):
    from torch.onnx.symbolic_helper import _slice_helper
    rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5])
    roi_feats = g.op(
        add_domain('ExperimentalDetectronROIFeatureExtractor'),
        rois,
        *feats,
        output_size_i=output_size,
        pyramid_scales_i=featmap_strides,
        sampling_ratio_i=sample_num,
        image_id_i=0,
        distribute_rois_between_levels_i=1,
        preserve_rois_order_i=0,
        aligned_i=1,
        outputs=1,
    )
    return roi_feats
Example #22
def embedding_bag(g, embedding_matrix, indices, offsets, scale_grad_by_freq,
                  mode, sparse, per_sample_weights, include_last_offset,
                  padding_idx):
    if scale_grad_by_freq and sym_help._training_mode:
        return sym_help._onnx_unsupported(
            'embedding_bag with scale_grad_by_freq for training mode')
    if padding_idx is not None and padding_idx >= 0:
        raise RuntimeError('embedding_bag with padding_idx')

    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    loop_condition = g.op("Cast", loop_condition, to_i=9)
    zero = g.op("Constant", value_t=torch.tensor([0]))

    indices_len = sym_help._unsqueeze_helper(
        g,
        sym_help._size_helper(g, indices,
                              g.op("Constant", value_t=torch.tensor(0))), [0])
    if not include_last_offset:
        offsets = [offsets, indices_len]
        offsets = g.op("Concat", *offsets, axis_i=0)

    # `offsets` holds the starting index of each bag, so we slice the indices for each
    # bag (as determined by offsets) into indices_row, then use that subset of indices
    # to gather from the embeddings. The embeddings output is a loop scan output, so we
    # can avoid creating a sequence and inserting elements into it.
    offsets_starts = sym_help._slice_helper(g,
                                            offsets,
                                            axes=[0],
                                            starts=[0],
                                            ends=[maxsize],
                                            steps=[1])
    offsets_ends = sym_help._slice_helper(g,
                                          offsets,
                                          axes=[0],
                                          starts=[1],
                                          ends=[maxsize],
                                          steps=[1])

    loop_len = sym_help._size_helper(g, offsets_ends,
                                     g.op("Constant", value_t=torch.tensor(0)))
    loop = g.op("Loop", loop_len, loop_condition)

    loop_block = _add_block(loop.node())
    block_input_iter = _add_input_to_block(loop_block)
    cond = _add_input_to_block(loop_block)

    indices_start = loop_block.op("Gather",
                                  offsets_starts,
                                  block_input_iter,
                                  axis_i=0)
    indices_end = loop_block.op("Gather",
                                offsets_ends,
                                block_input_iter,
                                axis_i=0)
    indices_start = sym_help._unsqueeze_helper(loop_block, indices_start, [0])
    indices_end = sym_help._unsqueeze_helper(loop_block, indices_end, [0])

    indices_row = loop_block.op("Slice", indices, indices_start, indices_end,
                                zero)
    embeddings = loop_block.op("Gather",
                               embedding_matrix,
                               indices_row,
                               axis_i=0)
    if not sym_help._is_none(per_sample_weights):
        per_sample_weights_row = loop_block.op("Slice", per_sample_weights,
                                               indices_start, indices_end,
                                               zero)
        per_sample_weights_row = sym_help._unsqueeze_helper(
            loop_block, per_sample_weights_row, [1])
        embeddings = loop_block.op("Mul", embeddings, per_sample_weights_row)
    if mode == 0:
        embeddings = sym_help._reducesum_helper(loop_block,
                                                embeddings,
                                                axes_i=[0],
                                                keepdims_i=0)
    elif mode == 1:
        embeddings = loop_block.op("ReduceMean",
                                   embeddings,
                                   axes_i=[0],
                                   keepdims_i=0)
    else:
        embeddings = loop_block.op("ReduceMax",
                                   embeddings,
                                   axes_i=[0],
                                   keepdims_i=0)

    cond_out = loop_block.op("Cast", loop_condition, to_i=9)
    _add_output_to_block(loop_block, cond_out)
    _add_output_to_block(loop_block, embeddings)

    # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
    # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
    return loop.node().output(), None, None, None
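Eagerly, the offsets handling above means bag b covers indices[offsets[b]:offsets[b + 1]], with the total index count appended as the final boundary when include_last_offset is false; a small sketch with two bags:

import torch

indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])          # bags cover [0, 4) and [4, 8)
emb = torch.nn.EmbeddingBag(10, 3, mode="sum")
out = emb(indices, offsets)             # one summed row per bag
assert out.shape == (2, 3)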
Example #23
def index_put(g, self, indices_list_value, values, accumulate=False):
    if symbolic_helper._is_packed_list(indices_list_value):
        indices_list = symbolic_helper._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if symbolic_helper.is_caffe2_aten_fallback():
        args = [self] + indices_list + [values, accumulate]
        return g.at("index_put", *args)

    accumulate = symbolic_helper._parse_arg(accumulate, "b")

    if len(indices_list) == 0:
        return values

    if len(indices_list) > 1:
        for idx_ in range(len(indices_list)):
            if indices_list[idx_].type().scalarType() == "Bool":  # type: ignore[attr-defined]
                # TODO(justinchuby): Remove type ignore after #81112 is checked in.
                indices_list[idx_] = g.op("NonZero", indices_list[idx_])
        index = indices_list[0]

        for ind in indices_list[1:]:
            index = opset9.add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            symbolic_helper._unsqueeze_helper(
                g, opset9.expand(g, ind, broadcast_index_shape, None), [-1]
            )
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace the index_put node with masked_scatter or masked_fill
        # when the inputs to the index_put node contain a single boolean input.
        #
        # index_put -> masked_fill
        #   * input index contains single tensor of Bool type (e.g.: %24 <- %23).
        #   * input value contains single element (e.g.: %18).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        #
        # index_put -> masked_scatter
        #   * input index contains single tensor of Bool type (e.g.: %32 <- %31).
        #   * input value contains multiple elements (e.g.: %28).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        index = indices_list[0]
        bool_inp = index
        if bool_inp.type() is not None and bool_inp.type().scalarType() == "Bool":  # type: ignore[attr-defined]
            # TODO(justinchuby): Remove type ignore after #81112 is checked in.
            rank = symbolic_helper._get_tensor_rank(values)
            if rank is not None and rank == 0:
                return opset9.masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = symbolic_helper._unsqueeze_helper(g, index, [-1])
    sub_data_shape = symbolic_helper._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize]
    )
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    # Check if values is a singular value and expand accordingly
    rank = symbolic_helper._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = opset9.expand(g, values, values_shape, None)
    values = symbolic_helper._reshape_helper(g, values, values_shape)

    dtype = self.type().scalarType()
    if dtype is not None and dtype != values.type().scalarType():
        values = g.op("Cast", values, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype])
    dtype = symbolic_helper.scalar_type_to_onnx.index(
        symbolic_helper.cast_pytorch_to_onnx[dtype]
    )
    dtype = symbolic_helper.scalar_type_to_pytorch_type[dtype]

    if accumulate:
        zeros = g.op(
            "ConstantOfShape",
            g.op("Shape", self),
            value_t=torch.tensor([0], dtype=dtype),
        )
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
Example #24
def index_put(g, self, indices_list_value, values, accumulate=False):
    indices_list = sym_help._unpack_list(indices_list_value)
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    index = indices_list[0]

    if len(indices_list) > 1:
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            sym_help._unsqueeze_helper(
                g, expand(g, ind, broadcast_index_shape, None), [-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace the index_put node with masked_scatter or masked_fill
        # when the inputs to the index_put node contain boolean inputs
        #
        # index_put -> masked_fill
        #
        # before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %6 : None = prim::Constant()
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %8 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::ne(%mask, %some_const)
        #   %26 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]()
        #   %27 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %11 : Device = prim::Constant[value="cpu"]()
        #   %12 : None = prim::Constant()
        #   %28 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %29 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %15 : None = prim::Constant()
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %30 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %22 : int[] = prim::Constant[value=[-1]]()
        #   %23 : Tensor = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %3 : Tensor = onnx::Equal(%0, %some_const)
        #   %4 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%3)
        #   %12 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%4)
        #   %19 : Tensor = onnx::Cast[to=9](%12)
        #   %20 : Tensor = onnx::Constant[value={1}]()
        #   %21 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = onnx::Where(%19, %20, %0)
        #   return (%21)
        #
        # index_put -> masked_scatter
        #
        # before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %6 : None = prim::Constant()
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %34 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]()
        #   %35 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %18 : Device = prim::Constant[value="cpu"]()
        #   %19 : None = prim::Constant()
        #   %36 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %37 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %22 : None = prim::Constant()
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Tensor = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        #
        # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %3 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #               = onnx::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %4 : Tensor = onnx::Equal(%0, %some_const)
        #   %5 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%4)
        #   %13 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%5)
        #   %19 : Tensor = onnx::Shape(%0)
        #   %20 : Tensor = onnx::Expand(%13, %19)
        #   %21 : Tensor = onnx::NonZero(%20)
        #   %22 : Tensor = onnx::Transpose[perm=[1, 0]](%21)
        #   %23 : Tensor = onnx::Constant[value={-1}]()
        #   %24 : Tensor = onnx::Reshape(%3, %23)
        #   %25 : Tensor = onnx::Shape(%22)
        #   %27 : Tensor = onnx::Constant[value={0}]()
        #   %28 : Tensor = onnx::Gather[axis=0](%25, %27)
        #   %29 : Tensor = onnx::Constant[value={0}]()
        #   %30 : Tensor = onnx::Unsqueeze[axes=[0]](%29)
        #   %31 : Tensor = onnx::Unsqueeze[axes=[0]](%28)
        #   %32 : Tensor = onnx::Constant[value={0}]()
        #   %33 : Tensor = onnx::Unsqueeze[axes=[0]](%32)
        #   %34 : Tensor = onnx::Slice(%24, %30, %31, %33)
        #   %35 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = onnx::ScatterND(%0, %22, %34)
        #   return (%35)

        bool_inp = list(index.node().inputs())[0]
        if bool_inp.type() is not None and bool_inp.type().scalarType(
        ) == 'Bool':
            rank = sym_help._get_tensor_rank(values)
            if rank is not None and rank == 0:
                from torch.onnx.symbolic_opset9 import masked_fill
                return masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = sym_help._unsqueeze_helper(g, index, [-1])
    sub_data_shape = sym_help._slice_helper(g,
                                            g.op("Shape", self),
                                            axes=[0],
                                            starts=[len(indices_list)],
                                            ends=[maxsize])
    values_shape = g.op("Concat",
                        broadcast_index_shape,
                        sub_data_shape,
                        axis_i=0)
    values = g.op("Reshape", values, values_shape)

    if accumulate:
        dtype = self.type().scalarType()
        dtype = sym_help.scalar_type_to_onnx.index(
            sym_help.cast_pytorch_to_onnx[dtype])
        dtype = sym_help.scalar_type_to_pytorch_type[dtype]
        zeros = g.op("ConstantOfShape",
                     g.op("Shape", self),
                     value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
Example #25
def index_put(g, self, indices_list_value, values, accumulate=False):
    if sym_help._is_packed_list(indices_list_value):
        indices_list = sym_help._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    if len(indices_list) == 0:
        return values

    index = indices_list[0]

    if len(indices_list) > 1:
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            sym_help._unsqueeze_helper(
                g, expand(g, ind, broadcast_index_shape, None), [-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace the index_put node with masked_scatter or masked_fill
        # when the inputs to the index_put node contain boolean inputs
        #
        # index_put -> masked_fill
        #   * input index contains single tensor of Bool type (e.g.: %24 <- %23).
        #   * input value contains single element (e.g.: %18).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        #
        # index_put -> masked_scatter
        #   * input index contains single tensor of Bool type (e.g.: %32 <- %31).
        #   * input value contains multiple elements (e.g.: %28).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        bool_inp = index
        if bool_inp.type() is not None and bool_inp.type().scalarType(
        ) == 'Bool':
            rank = sym_help._get_tensor_rank(values)
            if rank is not None and rank == 0:
                from torch.onnx.symbolic_opset9 import masked_fill
                return masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = sym_help._unsqueeze_helper(g, index, [-1])
    sub_data_shape = sym_help._slice_helper(g,
                                            g.op("Shape", self),
                                            axes=[0],
                                            starts=[len(indices_list)],
                                            ends=[maxsize])
    values_shape = g.op("Concat",
                        broadcast_index_shape,
                        sub_data_shape,
                        axis_i=0)
    # Check if values is a singular value and expand accordingly
    rank = sym_help._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = expand(g, values, values_shape, None)
    values = g.op("Reshape", values, values_shape)

    dtype = self.type().scalarType()
    if dtype is not None and dtype != values.type().scalarType():
        values = g.op("Cast",
                      values,
                      to_i=sym_help.cast_pytorch_to_onnx[dtype])
    dtype = sym_help.scalar_type_to_onnx.index(
        sym_help.cast_pytorch_to_onnx[dtype])
    dtype = sym_help.scalar_type_to_pytorch_type[dtype]

    if accumulate:
        zeros = g.op("ConstantOfShape",
                     g.op("Shape", self),
                     value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
Example #26
def narrow(g, input, dim, start, length):
    end = g.op("Add", start, length)
    return symbolic_helper._slice_helper(
        g, input, axes=dim, starts=start, ends=end, dynamic_slice=True
    )
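Eagerly, this Add + Slice pair is just the half-open slice [start, start + length) along dim; a one-line check:

import torch

x = torch.arange(10)
assert torch.equal(torch.narrow(x, 0, 2, 3), x[2:2 + 3])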