def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio,
             pool_mode, aligned):
    from torch.onnx.symbolic_opset9 import sub, squeeze
    from torch.onnx.symbolic_helper import _slice_helper
    from torch.onnx import TensorProtoDataType
    # batch_indices = rois[:, 0].long()
    batch_indices = _slice_helper(g, rois, axes=[1], starts=[0], ends=[1])
    batch_indices = squeeze(g, batch_indices, 1)
    batch_indices = g.op('Cast', batch_indices, to_i=TensorProtoDataType.INT64)
    # rois = rois[:, 1:]
    rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5])
    if aligned:
        # rois -= 0.5 / spatial_scale
        aligned_offset = g.op(
            'Constant',
            value_t=torch.tensor([0.5 / spatial_scale], dtype=torch.float32))
        rois = sub(g, rois, aligned_offset)
    # roi align
    return g.op(
        'RoiAlign',
        input,
        rois,
        batch_indices,
        output_height_i=output_size[0],
        output_width_i=output_size[1],
        spatial_scale_f=spatial_scale,
        sampling_ratio_i=max(0, sampling_ratio),
        mode_s=pool_mode)

def roll(g, input, shifts, dims):
    assert len(shifts) == len(dims)

    result = input
    for i in range(len(shifts)):
        shapes = []
        shape = _slice_helper(g, result, axes=[dims[i]],
                              starts=[-shifts[i]], ends=[maxsize])
        shapes.append(shape)
        shape = _slice_helper(g, result, axes=[dims[i]],
                              starts=[0], ends=[-shifts[i]])
        shapes.append(shape)
        result = g.op("Concat", *shapes, axis_i=dims[i])
    return result

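# Illustrative sanity check (not part of the exporter code): the symbolic
# above decomposes torch.roll into a tail slice and a head slice joined by
# Concat along the rolled dim, which matches this eager-mode equivalence.
import torch

def roll_by_slicing(x, shift, dim):
    # tail x[..., -shift:] followed by head x[..., :-shift] along `dim`
    tail = x.narrow(dim, x.size(dim) - shift, shift)
    head = x.narrow(dim, 0, x.size(dim) - shift)
    return torch.cat([tail, head], dim=dim)

x = torch.arange(12).reshape(3, 4)
assert torch.equal(roll_by_slicing(x, 1, 1), torch.roll(x, 1, 1))
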
def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio,
             pool_mode, aligned):
    has_custom_op = False
    try:
        import os.path as osp
        from mmcv.ops import get_onnxruntime_op_path
        ort_op_path = get_onnxruntime_op_path()
        has_custom_op = osp.exists(ort_op_path)
    except ImportError:
        pass
    if has_custom_op:
        return g.op(
            'mmcv::MMCVRoiAlign',
            input,
            rois,
            aligned_height_i=output_size[0],
            aligned_width_i=output_size[1],
            spatial_scale_f=spatial_scale,
            sampling_ratio_i=max(0, sampling_ratio),
            pool_mode_s=pool_mode,
            aligned_i=aligned)

    from torch.onnx.symbolic_opset9 import sub, squeeze
    from torch.onnx.symbolic_helper import _slice_helper
    from torch.onnx import TensorProtoDataType
    # batch_indices = rois[:, 0].long()
    batch_indices = _slice_helper(g, rois, axes=[1], starts=[0], ends=[1])
    batch_indices = squeeze(g, batch_indices, 1)
    batch_indices = g.op(
        'Cast', batch_indices, to_i=TensorProtoDataType.INT64)
    # rois = rois[:, 1:]
    rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5])
    if aligned:
        # rois -= 0.5 / spatial_scale
        aligned_offset = g.op(
            'Constant',
            value_t=torch.tensor([0.5 / spatial_scale], dtype=torch.float32))
        rois = sub(g, rois, aligned_offset)
    # roi align
    return g.op(
        'RoiAlign',
        input,
        rois,
        batch_indices,
        output_height_i=output_size[0],
        output_width_i=output_size[1],
        spatial_scale_f=spatial_scale,
        sampling_ratio_i=max(0, sampling_ratio),
        mode_s=pool_mode)

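# Illustrative sketch (values made up, not from the source file) of the rois
# preprocessing the fallback path above emits: split off the batch-index
# column, then apply the half-pixel offset when aligned=True.
import torch

rois = torch.tensor([[0., 4., 4., 12., 12.],
                     [1., 0., 0., 8., 8.]])   # assumed (N, 5) layout
batch_indices = rois[:, 0].long()             # what Slice + Squeeze + Cast emit
boxes = rois[:, 1:5]                          # what the second Slice emits
spatial_scale, aligned = 0.25, True
if aligned:
    boxes = boxes - 0.5 / spatial_scale       # offset subtracted before RoiAlign
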
def symbolic_fn(g, input, output_size, *args):
    scales, align_corners = sym_help._get_interpolate_attributes(g, interpolate_mode, args)
    align_corners = sym_help._maybe_get_scalar(align_corners)
    coordinate_transformation_mode = "asymmetric" if interpolate_mode == "nearest" \
        else "align_corners" if align_corners else "pytorch_half_pixel"
    empty_tensor = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
    if scales is None:
        input_size = g.op("Shape", input)
        input_size_beg = sym_help._slice_helper(g, input_size, axes=[0], ends=[2], starts=[0])
        output_size = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Long"])
        output_size = g.op("Concat", input_size_beg, output_size, axis_i=0)
        scales = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
        return g.op("Resize",
                    input,
                    empty_tensor,  # roi only takes effect with coordinate_transformation_mode="tf_crop_and_resize"
                    scales,  # scales is not needed since we are sending out_size
                    output_size,
                    coordinate_transformation_mode_s=coordinate_transformation_mode,
                    cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                    mode_s=interpolate_mode,  # nearest, linear, or cubic
                    nearest_mode_s="floor")  # only valid when mode="nearest"
    else:
        return g.op("Resize",
                    input,
                    empty_tensor,  # roi only takes effect with coordinate_transformation_mode="tf_crop_and_resize"
                    scales,  # scales carries the resize factors in this branch
                    coordinate_transformation_mode_s=coordinate_transformation_mode,
                    cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                    mode_s=interpolate_mode,  # nearest, linear, or cubic
                    nearest_mode_s="floor")  # only valid when mode="nearest"

def symbolic_fn(g, input, output_size, align_corners=None):
    if align_corners:
        return _unimplemented(name, "align_corners == True")

    output_size = sym_help._maybe_get_const(output_size, 'is')
    if sym_help._is_value(output_size):
        offset = 2
        offsets = g.op("Constant", value_t=torch.tensor([1. for i in range(offset)]))
        dividend = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Float"])
        divisor = sym_help._slice_helper(g, g.op("Shape", input), axes=[0],
                                         ends=[dim], starts=[offset])
        divisor = g.op("Cast", divisor, to_i=sym_help.cast_pytorch_to_onnx["Float"])
        scale_dims = g.op("Div", dividend, divisor)
        scales = g.op("Concat", offsets, scale_dims, axis_i=0)
    else:
        scales_constant = [
            1. if i < 2 else
            float(output_size[-(dim - i)]) / float(input.type().sizes()[-(dim - i)])
            for i in range(0, dim)
        ]
        scales = g.op("Constant", value_t=torch.tensor(scales_constant))
    return g.op("Resize", input, scales, mode_s=interpolate_mode)

def slice(g, self, *args):
    if len(args) == 4:
        # aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor
        dim, start, end, step = args
    elif len(args) == 3:
        # aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[]
        start, end, step = args
        dim = 0
    else:
        raise NotImplementedError("Unknown aten::slice signature")
    is_start_none = start.node().kind() == "prim::Constant" and start.type().kind() == "NoneType"
    is_end_none = end.node().kind() == "prim::Constant" and end.type().kind() == "NoneType"
    is_start_onnx_const = start.node().kind() == "onnx::Constant"
    is_end_onnx_const = end.node().kind() == "onnx::Constant"
    step = sym_help._parse_arg(step, "i")
    if (not is_start_none and not is_start_onnx_const) or \
            (not isinstance(end, int) and not is_end_none and not is_end_onnx_const) or \
            (not isinstance(dim, int) and dim.node().kind() != "onnx::Constant"):
        dynamic_slice = True
        if is_start_none:
            start = g.op("Constant", value_t=torch.tensor(0))
        if is_end_none:
            end = g.op("Constant", value_t=torch.tensor(9223372036854775807))
    else:
        start = [0 if is_start_none else sym_help._parse_arg(start, "i")]
        end = [9223372036854775807 if is_end_none else sym_help._parse_arg(end, "i")]
        dim = [sym_help._parse_arg(dim, "i")]
        dynamic_slice = False
    return sym_help._slice_helper(g, self, axes=dim, starts=start, ends=end,
                                  steps=[step], dynamic_slice=dynamic_slice)

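# Illustrative note (not part of the exporter code): the magic constant used
# above for an omitted end is INT64_MAX; ONNX Slice clamps ends past the
# dimension size, so it behaves like Python's open-ended seq[start:].
assert 9223372036854775807 == 2**63 - 1
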
def symbolic_fn(g, input, output_size, align_corners=None):
    align_corners = sym_help._maybe_get_scalar(align_corners)
    coordinate_transformation_mode = "asymmetric" if interpolate_mode == "nearest" \
        else "align_corners" if align_corners else "pytorch_half_pixel"
    empty_tensor = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
    input_size = g.op("Shape", input)
    input_size_beg = sym_help._slice_helper(g, input_size, axes=[0], ends=[2], starts=[0])
    output_size = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Long"])
    output_size = g.op("Concat", input_size_beg, output_size, axis_i=0)
    if interpolate_mode == "bilinear":
        return g.op("Plugin", input, name_s="Bilinear",
                    info_s=json.dumps({"scale_factor": [2, 2]}))
    else:
        return g.op("Upsample", input, scales_f=[1, 2, 2], mode_s=interpolate_mode)

def symbolic_fn(g, input, output_size, align_corners=None):
    align_corners = sym_help._maybe_get_scalar(align_corners)
    coordinate_transformation_mode = "asymmetric" if interpolate_mode == "nearest" \
        else "align_corners" if align_corners else "pytorch_half_pixel"
    empty_tensor = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
    input_size = g.op("Shape", input)
    input_size_beg = sym_help._slice_helper(g, input_size, axes=[0], ends=[2], starts=[0])
    output_size = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Long"])
    output_size = g.op("Concat", input_size_beg, output_size, axis_i=0)
    # ************************************ upsample change
    if interpolate_mode == "bilinear":
        # https://github.com/dlunion/tensorRTIntegrate/blob/7543614a4ff5f6e022526a4e137d47f4e63f5f96/plugin_onnx_export.py#L33
        return g.op("Plugin", input, name_s="Bilinear",
                    info_s=json.dumps({"scale_factor": [2, 2]}))
    else:
        return g.op("Upsample", input, scales_f=[1, 2, 2], mode_s=interpolate_mode)

def flip(g, input, dims):
    return sym_help._slice_helper(g, input, axes=dims,
                                  starts=[-1] * len(dims),
                                  ends=[-9223372036854775807] * len(dims),
                                  steps=[-1] * len(dims))

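# Illustrative sketch (not part of the exporter code): flip is expressed as a
# Slice with step -1, start -1 and end -9223372036854775807 (INT64_MIN + 1),
# mirroring Python's seq[-1::-1] reversal.
xs = [1, 2, 3, 4]
assert xs[-1::-1] == [4, 3, 2, 1]
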
def slice(g, self, *args):
    if len(args) == 4:
        # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor
        dim, start, end, step = args
    elif len(args) == 3:
        # aten::slice(t[] l, int start, int end, int step) -> t[]
        start, end, step = args
        dim = 0
    else:
        raise NotImplementedError("Unknown aten::slice signature")
    step = sym_help._parse_arg(step, 'i')
    if (start.node().kind() != 'onnx::Constant' or
            (not isinstance(end, int) and end.node().kind() != 'onnx::Constant') or
            (not isinstance(dim, int) and dim.node().kind() != 'onnx::Constant')):
        dynamic_slice = True
    else:
        start = [sym_help._parse_arg(start, 'i')]
        end = [sym_help._parse_arg(end, 'i')]
        dim = [sym_help._parse_arg(dim, 'i')]
        dynamic_slice = False
    return sym_help._slice_helper(g, self, axes=dim, starts=start, ends=end,
                                  steps=[step], dynamic_slice=dynamic_slice)

def __interpolate(g, input, size, scale_factor, mode, align_corners, recompute_scale_factor):
    mode = sym_help._maybe_get_const(mode, 's')
    if 'linear' in mode:
        mode = 'linear'
    if 'cubic' in mode:
        mode = 'cubic'
    align_corners = sym_help._maybe_get_const(align_corners, 'b')
    align_corners = False if not isinstance(align_corners, bool) else align_corners
    coordinate_transformation_mode = "asymmetric" if mode == "nearest" \
        else "align_corners" if align_corners else "pytorch_half_pixel"
    # roi only takes effect with coordinate_transformation_mode="tf_crop_and_resize"
    roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))

    if not sym_help._is_none(size):
        input_size = g.op("Shape", input)
        input_size = sym_help._slice_helper(g, input_size, axes=[0], ends=[2], starts=[0])
        # in some cases size is not a packed list but size is a scalar
        # We need to also verify that (sym_help._maybe_get_const(size, 't').dim() == 0)
        # but this information is not always available. Try to get the dim,
        # and if not assume that it is not a scalar.
        try:
            is_scalar = not sym_help._is_packed_list(size) and ((sym_help._maybe_get_const(size, 't').dim() == 0))
        except AttributeError:
            is_scalar = not sym_help._is_packed_list(size)
            if not is_scalar:
                warnings.warn("Cannot verify if the output_size is a scalar "
                              "while exporting interpolate. Assuming that it is not a scalar.")

        if is_scalar:
            if not input.type().dim():
                return sym_help._unimplemented("interpolate (with a scalar output_size)",
                                               "missing input shape (try giving an array of output_size values)")
            size = unsqueeze(g, size, 0)
            size = [size for i in range(input.type().dim() - 2)]
            size = g.op("Concat", *size, axis_i=0)
        size = g.op("Cast", size, to_i=sym_help.cast_pytorch_to_onnx['Long'])
        size = g.op("Concat", input_size, size, axis_i=0)
        scales = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
        return g.op("Resize",
                    input,
                    roi,
                    scales,
                    size,
                    coordinate_transformation_mode_s=coordinate_transformation_mode,
                    cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                    mode_s=mode,  # nearest, linear, or cubic
                    nearest_mode_s="floor")
    else:  # if not sym_help._is_none(scales)
        if not input.type().dim():
            return sym_help._unimplemented("interpolate (with scales)", "missing input shape")
        scales = sym_help._interpolate_get_scales(g, scale_factor, input.type().dim())
        return g.op("Resize",
                    input,
                    roi,
                    scales,
                    coordinate_transformation_mode_s=coordinate_transformation_mode,
                    cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                    mode_s=mode,  # nearest, linear, or cubic
                    nearest_mode_s="floor")  # only valid when mode="nearest"

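# Illustrative sketch (not part of the exporter code): the two Resize paths
# above mirror the two ways F.interpolate can be called, with an explicit
# output size or with scale factors; both reach the same output shape.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
a = F.interpolate(x, size=(16, 16), mode='nearest')     # size path
b = F.interpolate(x, scale_factor=2.0, mode='nearest')  # scales path
assert a.shape == b.shape == (1, 3, 16, 16)
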
def tensordot(g, input_a, input_b, dims_a, dims_b, out=None):
    if out is not None:
        _unimplemented("Tensordot", "Out parameter is not supported for tensordot.")

    dim_count_a = sym_help._get_tensor_rank(input_a)
    if dim_count_a is None:
        raise RuntimeError("Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.")

    dim_count_b = sym_help._get_tensor_rank(input_b)
    if dim_count_b is None:
        raise RuntimeError("Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.")

    dims_a = [(dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i] for i in range(len(dims_a))]
    dims_b = [(dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i] for i in range(len(dims_b))]

    left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)]
    left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)]

    new_input_a = permute(g, input_a, left_dims_a + dims_a)
    new_input_b = permute(g, input_b, dims_b + left_dims_b)

    input_shape = g.op("Shape", new_input_a)
    left_sizes_a = sym_help._slice_helper(g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)])
    shape_sizes = [left_sizes_a, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long))]
    output_a = _reshape_from_tensor(g, new_input_a, shape_sizes)

    input_shape = g.op("Shape", output_a)
    slices = sym_help._slice_helper(g, input_shape, axes=[0], starts=[-1], ends=[maxsize])
    shape_sizes = [g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), slices]
    output_a = _reshape_from_tensor(g, new_input_a, shape_sizes)

    input_shape = g.op("Shape", new_input_b)
    left_sizes_b = sym_help._slice_helper(g, input_shape, axes=[0], starts=[len(dims_b)], ends=[maxsize])
    slices = sym_help._slice_helper(g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)])
    shape_sizes = [slices, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long))]
    output_b = _reshape_from_tensor(g, new_input_b, shape_sizes)

    input_shape = g.op("Shape", output_b)
    slices = sym_help._slice_helper(g, input_shape, axes=[0], starts=[-1], ends=[maxsize])
    shape_sizes = [g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), slices]
    output_b = _reshape_from_tensor(g, new_input_b, shape_sizes)

    output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b]))

    shape_sizes = [left_sizes_a, left_sizes_b]
    return _reshape_from_tensor(g, output, shape_sizes)

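# Illustrative eager-mode sketch of the decomposition performed above: move
# the contracted dims of A to the back and of B to the front, flatten both
# into matrices, multiply, then restore the kept dims. Shapes are made up.
import torch

a, b = torch.randn(2, 3, 4), torch.randn(4, 3, 5)
dims_a, dims_b = [1, 2], [1, 0]  # a.dim1 <-> b.dim1, a.dim2 <-> b.dim0
left_a, left_b = [0], [2]        # dims kept in the output
mat_a = a.permute(left_a + dims_a).reshape(2, -1)
mat_b = b.permute(dims_b + left_b).reshape(-1, 5)
out = (mat_a @ mat_b).reshape(2, 5)
assert torch.allclose(out, torch.tensordot(a, b, dims=(dims_a, dims_b)))
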
def narrow(g, input, dim, start, length):
    from torch.onnx.symbolic_helper import _slice_helper
    end = g.op("Add", start, length)
    return _slice_helper(g, input, axes=dim, starts=start, ends=end,
                         dynamic_slice=True)

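# Illustrative eager-mode equivalence the symbolic relies on (not part of the
# exporter code): narrow(dim, start, length) is the slice
# [start : start + length) along `dim`, hence the Add followed by Slice.
import torch

x = torch.arange(10)
assert torch.equal(x.narrow(0, 2, 5), x[2:7])
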
def slice(g, self, dim, start, end, step):
    if (start.node().kind() != 'onnx::Constant' or
            end.node().kind() != 'onnx::Constant' or
            dim.node().kind() != 'onnx::Constant'):
        dynamic_slice = True
    else:
        start = [sym_help._parse_arg(start, 'i')]
        end = [sym_help._parse_arg(end, 'i')]
        dim = [sym_help._parse_arg(dim, 'i')]
        dynamic_slice = False
    return sym_help._slice_helper(g, self, axes=dim, starts=start, ends=end,
                                  steps=[step], dynamic_slice=dynamic_slice)

def masked_scatter(g, self, mask, source):
    from torch.onnx.symbolic_opset9 import nonzero, expand_as, view, size
    index = nonzero(g, expand_as(g, mask, self))
    # NOTE: source can have more elements than needed.
    # It could also have arbitrary shape.
    # This is not supported by ONNX::ScatterND, so we need to flatten and slice source tensor.
    source = view(g, source, torch.LongTensor([-1]))
    source = sym_help._slice_helper(g, source,
                                    axes=torch.LongTensor([0]),
                                    starts=torch.LongTensor([0]),
                                    ends=size(g, index, torch.LongTensor([0])),
                                    dynamic_slice=True)
    return g.op('ScatterND', self, index, source)

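# Why the flatten-and-slice above: torch.masked_scatter consumes only the
# first mask.sum() elements of source (in row-major order), so the symbolic
# trims source to the number of nonzero mask positions before ScatterND.
# Illustrative check, not part of the exporter code:
import torch

x = torch.zeros(2, 2)
mask = torch.tensor([[True, False], [False, True]])
source = torch.tensor([1., 2., 3., 4.])  # only the first 2 values are used
assert torch.equal(x.masked_scatter(mask, source),
                   torch.tensor([[1., 0.], [0., 2.]]))
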
def symbolic(self, g, rois, *feats):
    rois = sym_help._slice_helper(g, rois, axes=[1], starts=[1], ends=[5])
    roi_feats, _ = g.op(
        'ExperimentalDetectronROIFeatureExtractor',
        rois,
        *feats,
        output_size_i=self.inner.roi_layers[0].out_size[0],
        pyramid_scales_i=self.inner.featmap_strides,
        sampling_ratio_i=self.inner.roi_layers[0].sample_num,
        image_id_i=0,
        distribute_rois_between_levels_i=1,
        preserve_rois_order_i=1,
        outputs=2)
    return roi_feats

def index_put(g, self, indices_list_value, values, accumulate=False):
    indices_list = sym_help._unpack_list(indices_list_value)
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    index = indices_list[0]

    if len(indices_list) > 1:
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            g.op("Unsqueeze", expand(g, ind, broadcast_index_shape, None), axes_i=[-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        broadcast_index_shape = g.op("Shape", index)
        index = g.op("Unsqueeze", index, axes_i=[-1])
    sub_data_shape = sym_help._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[maxsize])
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    values = g.op("Reshape", values, values_shape)

    if accumulate:
        dtype = self.type().scalarType()
        dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
        dtype = sym_help.scalar_type_to_pytorch_type[dtype]
        zeros = g.op("ConstantOfShape", g.op("Shape", self),
                     value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result

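# Illustrative sketch of the accumulate=True trick used above (not part of
# the exporter code): ONNX ScatterND overwrites, so accumulation is emulated
# by scattering the updates into a zero tensor and adding that to the input.
import torch

x = torch.tensor([10., 20., 30.])
indices = torch.tensor([1, 1])
values = torch.tensor([1., 2.])
scattered = torch.zeros_like(x).index_add(0, indices, values)  # zeros + updates
expected = x.index_put((indices,), values, accumulate=True)
assert torch.equal(x + scattered, expected)
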
def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
    if not stride:
        stride = kernel_size
    kwargs = {
        "kernel_shape_i": tuple_fn(kernel_size),
        "pads_i": tuple_fn(padding) * 2,
        "strides_i": tuple_fn(stride),
        "ceil_mode_i": ceil_mode,
    }
    if set(tuple_fn(dilation)) != {1}:
        kwargs["dilations_i"] = tuple_fn(dilation)
    # easy but hacky way to get flattened indices values
    # to be used to convert the indices values to non-flattened.
    # In ONNX the indices are computed as a flattened 1-D tensor,
    # so the values in indices are in [0, N x C x D1 x ... x Dn).
    # To convert the indices to the same format used by PyTorch,
    # we first execute a maxpool with a kernel and stride of 1 on the same input.
    # This will result in a tensor of indices in which each index will have its own value.
    # Using this tensor as a reference, we extract the first index of each axis and subtract
    # it from each index of this axis in the indices to convert.
    # This step will result in a tensor where each dimension has values of indices within
    # the dimension it is in.
    # For more information:
    # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
    if return_indices:
        r, indices = g.op("MaxPool", input, outputs=2, **kwargs)
        _, flattened_indices = g.op(
            "MaxPool",
            input,
            outputs=2,
            kernel_shape_i=[1 for _ in range(ndims)],
            strides_i=[1 for _ in range(ndims)],
        )
        # convert indices to have non-flattened indices values
        from torch.onnx.symbolic_opset9 import sub
        s = sym_help._slice_helper(
            g,
            flattened_indices,
            axes=[2 + i for i in range(ndims)],
            starts=tuple_fn(0),
            ends=tuple_fn(1),
        )
        indices = sub(g, indices, s)
        return r, indices
    else:
        r = g.op("MaxPool", input, outputs=1, **kwargs)
        return r

def masked_scatter(g, self, mask, source):
    index = opset9.nonzero(g, opset9.expand_as(g, mask, self))
    # NOTE: source can have more elements than needed.
    # It could also have arbitrary shape.
    # This is not supported by ONNX::ScatterND, so we need to flatten and slice source tensor.
    source = symbolic_helper._reshape_helper(g, source, torch.LongTensor([-1]))
    source = symbolic_helper._slice_helper(
        g,
        source,
        axes=torch.LongTensor([0]),
        starts=torch.LongTensor([0]),
        ends=opset9.size(g, index, torch.LongTensor([0])),
        dynamic_slice=True,
    )
    return g.op("ScatterND", self, index, source)

def sort(g, self, dim, descending, out=None):
    if out is not None:
        _unimplemented("Sort", "Out parameter is not supported for sort")
    # TODO: add descending to ONNX TopK so ascending sort is supported
    if not descending:
        _unimplemented("Sort", "Cannot sort in ascending order")
    shape_ = g.op("Shape", self)
    axis = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
    start = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int64))
    end = g.op("Constant", value_t=torch.tensor(dim + 1, dtype=torch.int64))
    slice_ = sym_help._slice_helper(g, shape_, axes=axis, starts=start, ends=end,
                                    steps=None, dynamic_slice=True)
    return g.op("TopK", self, slice_, axis_i=dim, outputs=2)

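# Illustrative sketch of the reduction used above (not part of the exporter
# code): a full descending sort along a dim is TopK with k equal to the size
# of that dim, which is exactly what the Shape + Slice computes.
import torch

x = torch.tensor([3., 1., 2.])
values, indices = torch.topk(x, k=x.size(0))
sorted_values, sorted_indices = torch.sort(x, descending=True)
assert torch.equal(values, sorted_values) and torch.equal(indices, sorted_indices)
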
def roi_feature_extractor_symbolics(g, rois, *feats, output_size=1,
                                    featmap_strides=1, sample_num=1):
    from torch.onnx.symbolic_helper import _slice_helper
    rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5])
    roi_feats = g.op(
        add_domain('ExperimentalDetectronROIFeatureExtractor'),
        rois,
        *feats,
        output_size_i=output_size,
        pyramid_scales_i=featmap_strides,
        sampling_ratio_i=sample_num,
        image_id_i=0,
        distribute_rois_between_levels_i=1,
        preserve_rois_order_i=0,
        aligned_i=1,
        outputs=1,
    )
    return roi_feats

def embedding_bag(g, embedding_matrix, indices, offsets, scale_grad_by_freq,
                  mode, sparse, per_sample_weights, include_last_offset,
                  padding_idx):
    if scale_grad_by_freq and sym_help._training_mode:
        return sym_help._onnx_unsupported(
            'embedding_bag with scale_grad_by_freq for training mode')
    if padding_idx is not None and padding_idx >= 0:
        raise RuntimeError('embedding_bag with padding_idx')

    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    loop_condition = g.op("Cast", loop_condition, to_i=9)
    zero = g.op("Constant", value_t=torch.tensor([0]))

    indices_len = sym_help._unsqueeze_helper(
        g,
        sym_help._size_helper(g, indices, g.op("Constant", value_t=torch.tensor(0))),
        [0])
    if not include_last_offset:
        offsets = [offsets, indices_len]
        offsets = g.op("Concat", *offsets, axis_i=0)

    # Offsets holds the starting index position of each bag. So we create a list of the indices slices (determined by
    # offsets) and gather those indices in indices_row. Then we use this subset of indices to gather from embeddings.
    # The embeddings output is a loop scan output, so we can avoid creating a sequence and inserting elements in.
    offsets_starts = sym_help._slice_helper(g, offsets, axes=[0],
                                            starts=[0], ends=[maxsize], steps=[1])
    offsets_ends = sym_help._slice_helper(g, offsets, axes=[0],
                                          starts=[1], ends=[maxsize], steps=[1])

    loop_len = sym_help._size_helper(g, offsets_ends,
                                     g.op("Constant", value_t=torch.tensor(0)))
    loop = g.op("Loop", loop_len, loop_condition)

    loop_block = _add_block(loop.node())
    block_input_iter = _add_input_to_block(loop_block)
    cond = _add_input_to_block(loop_block)

    indices_start = loop_block.op("Gather", offsets_starts, block_input_iter, axis_i=0)
    indices_end = loop_block.op("Gather", offsets_ends, block_input_iter, axis_i=0)
    indices_start = sym_help._unsqueeze_helper(loop_block, indices_start, [0])
    indices_end = sym_help._unsqueeze_helper(loop_block, indices_end, [0])

    indices_row = loop_block.op("Slice", indices, indices_start, indices_end, zero)
    embeddings = loop_block.op("Gather", embedding_matrix, indices_row, axis_i=0)
    if not sym_help._is_none(per_sample_weights):
        per_sample_weights_row = loop_block.op("Slice", per_sample_weights,
                                               indices_start, indices_end, zero)
        per_sample_weights_row = sym_help._unsqueeze_helper(
            loop_block, per_sample_weights_row, [1])
        embeddings = loop_block.op("Mul", embeddings, per_sample_weights_row)
    if mode == 0:
        embeddings = sym_help._reducesum_helper(loop_block, embeddings,
                                                axes_i=[0], keepdims_i=0)
    elif mode == 1:
        embeddings = loop_block.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
    else:
        embeddings = loop_block.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)

    cond_out = loop_block.op("Cast", loop_condition, to_i=9)
    _add_output_to_block(loop_block, cond_out)
    _add_output_to_block(loop_block, embeddings)

    # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
    # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
    return loop.node().output(), None, None, None

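# Illustrative reference for the offsets-driven loop above (not part of the
# exporter code): each bag i gathers indices[offsets[i]:offsets[i + 1]] and
# reduces the embedding rows (sum for mode=0, mean for 1, max for 2).
import torch
import torch.nn.functional as F

weight = torch.arange(12., dtype=torch.float32).reshape(6, 2)
indices = torch.tensor([0, 2, 1, 5])
offsets = torch.tensor([0, 2])  # bag 0 -> indices[0:2], bag 1 -> indices[2:4]
bags = torch.stack([weight[indices[0:2]].sum(0), weight[indices[2:4]].sum(0)])
ref = F.embedding_bag(indices, weight, offsets, mode='sum')
assert torch.equal(bags, ref)
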
def index_put(g, self, indices_list_value, values, accumulate=False):
    if symbolic_helper._is_packed_list(indices_list_value):
        indices_list = symbolic_helper._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if symbolic_helper.is_caffe2_aten_fallback():
        args = [self] + indices_list + [values, accumulate]
        return g.at("index_put", *args)

    accumulate = symbolic_helper._parse_arg(accumulate, "b")

    if len(indices_list) == 0:
        return values

    if len(indices_list) > 1:
        for idx_ in range(len(indices_list)):
            if indices_list[idx_].type().scalarType() == "Bool":  # type: ignore[attr-defined]
                # TODO(justinchuby): Remove type ignore after #81112 is checked in.
                indices_list[idx_] = g.op("NonZero", indices_list[idx_])
        index = indices_list[0]

        for ind in indices_list[1:]:
            index = opset9.add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            symbolic_helper._unsqueeze_helper(
                g, opset9.expand(g, ind, broadcast_index_shape, None), [-1]
            )
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace index_put node with masked_scatter or masked_fill
        # when inputs to the index_put node contains a single boolean input.
        #
        # index_put -> masked_fill
        #   * input index contains single tensor of Bool type (e.g.: %24 <- %23).
        #   * input value contains single element (e.g.: %18).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        # index_put -> masked_scatter
        #   * input index contains single tensor of Bool type (e.g.: %32 <- %31).
        #   * input value contains multiple elements (e.g.: %28).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        index = indices_list[0]
        bool_inp = index
        if bool_inp.type() is not None and bool_inp.type().scalarType() == "Bool":  # type: ignore[attr-defined]
            # TODO(justinchuby): Remove type ignore after #81112 is checked in.
            rank = symbolic_helper._get_tensor_rank(values)
            if rank is not None and rank == 0:
                return opset9.masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = symbolic_helper._unsqueeze_helper(g, index, [-1])
    sub_data_shape = symbolic_helper._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize]
    )
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    # Check if values is a singular value and expand accordingly
    rank = symbolic_helper._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = opset9.expand(g, values, values_shape, None)
    values = symbolic_helper._reshape_helper(g, values, values_shape)

    dtype = self.type().scalarType()
    if dtype is not None and dtype != values.type().scalarType():
        values = g.op("Cast", values, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype])
    dtype = symbolic_helper.scalar_type_to_onnx.index(
        symbolic_helper.cast_pytorch_to_onnx[dtype]
    )
    dtype = symbolic_helper.scalar_type_to_pytorch_type[dtype]

    if accumulate:
        zeros = g.op(
            "ConstantOfShape",
            g.op("Shape", self),
            value_t=torch.tensor([0], dtype=dtype),
        )
        result = g.op("ScatterND", zeros, index, values)
        result = opset9.add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result

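# Illustrative eager-mode picture of the dispatch above (not part of the
# exporter code): a single Bool index makes index_put equivalent to
# masked_fill (scalar value) or masked_scatter (tensor value).
import torch

x = torch.zeros(2, 2)
mask = torch.tensor([[True, False], [False, True]])
assert torch.equal(x.index_put((mask,), torch.tensor(5.)),
                   x.masked_fill(mask, 5.))
vals = torch.tensor([1., 2.])
assert torch.equal(x.index_put((mask,), vals), x.masked_scatter(mask, vals))
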
def index_put(g, self, indices_list_value, values, accumulate=False):
    indices_list = sym_help._unpack_list(indices_list_value)
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    index = indices_list[0]

    if len(indices_list) > 1:
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            sym_help._unsqueeze_helper(
                g, expand(g, ind, broadcast_index_shape, None), [-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace index_put node with masked_scatter or masked_fill
        # when inputs to the index_put node contains boolean inputs
        #
        # index_put -> masked_fill
        #
        # before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %6 : None = prim::Constant()
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %8 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::ne(%mask, %some_const)
        #   %26 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]()
        #   %27 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %11 : Device = prim::Constant[value="cpu"]()
        #   %12 : None = prim::Constant()
        #   %28 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %29 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %15 : None = prim::Constant()
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %30 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %22 : int[] = prim::Constant[value=[-1]]()
        #   %23 : Tensor = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %3 : Tensor = onnx::Equal(%0, %some_const)
        #   %4 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%3)
        #   %12 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%4)
        #   %19 : Tensor = onnx::Cast[to=9](%12)
        #   %20 : Tensor = onnx::Constant[value={1}]()
        #   %21 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = onnx::Where(%19, %20, %0)
        #   return (%21)
        #
        # index_put -> masked_scatter
        #
        # before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %6 : None = prim::Constant()
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %34 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]()
        #   %35 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %18 : Device = prim::Constant[value="cpu"]()
        #   %19 : None = prim::Constant()
        #   %36 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %37 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %22 : None = prim::Constant()
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Tensor = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        #
        # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %3 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = onnx::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]()
        #   %4 : Tensor = onnx::Equal(%0, %some_const)
        #   %5 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%4)
        #   %13 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%5)
        #   %19 : Tensor = onnx::Shape(%0)
        #   %20 : Tensor = onnx::Expand(%13, %19)
        #   %21 : Tensor = onnx::NonZero(%20)
        #   %22 : Tensor = onnx::Transpose[perm=[1, 0]](%21)
        #   %23 : Tensor = onnx::Constant[value={-1}]()
        #   %24 : Tensor = onnx::Reshape(%3, %23)
        #   %25 : Tensor = onnx::Shape(%22)
        #   %27 : Tensor = onnx::Constant[value={0}]()
        #   %28 : Tensor = onnx::Gather[axis=0](%25, %27)
        #   %29 : Tensor = onnx::Constant[value={0}]()
        #   %30 : Tensor = onnx::Unsqueeze[axes=[0]](%29)
        #   %31 : Tensor = onnx::Unsqueeze[axes=[0]](%28)
        #   %32 : Tensor = onnx::Constant[value={0}]()
        #   %33 : Tensor = onnx::Unsqueeze[axes=[0]](%32)
        #   %34 : Tensor = onnx::Slice(%24, %30, %31, %33)
        #   %35 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = onnx::ScatterND(%0, %22, %34)
        #   return (%35)
        bool_inp = list(index.node().inputs())[0]
        if bool_inp.type() is not None and bool_inp.type().scalarType() == 'Bool':
            rank = sym_help._get_tensor_rank(values)
            if rank is not None and rank == 0:
                from torch.onnx.symbolic_opset9 import masked_fill
                return masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = sym_help._unsqueeze_helper(g, index, [-1])
    sub_data_shape = sym_help._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[maxsize])
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    values = g.op("Reshape", values, values_shape)

    if accumulate:
        dtype = self.type().scalarType()
        dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
        dtype = sym_help.scalar_type_to_pytorch_type[dtype]
        zeros = g.op("ConstantOfShape", g.op("Shape", self),
                     value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result

def index_put(g, self, indices_list_value, values, accumulate=False):
    if sym_help._is_packed_list(indices_list_value):
        indices_list = sym_help._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    if len(indices_list) == 0:
        return values

    index = indices_list[0]

    if len(indices_list) > 1:
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            sym_help._unsqueeze_helper(
                g, expand(g, ind, broadcast_index_shape, None), [-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace index_put node with masked_scatter or masked_fill
        # when inputs to the index_put node contains boolean inputs
        #
        # index_put -> masked_fill
        #   * input index contains single tensor of Bool type (e.g.: %24 <- %23).
        #   * input value contains single element (e.g.: %18).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        # index_put -> masked_scatter
        #   * input index contains single tensor of Bool type (e.g.: %32 <- %31).
        #   * input value contains multiple elements (e.g.: %28).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        bool_inp = index
        if bool_inp.type() is not None and bool_inp.type().scalarType() == 'Bool':
            rank = sym_help._get_tensor_rank(values)
            if rank is not None and rank == 0:
                from torch.onnx.symbolic_opset9 import masked_fill
                return masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = sym_help._unsqueeze_helper(g, index, [-1])
    sub_data_shape = sym_help._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[maxsize])
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    # Check if values is a singular value and expand accordingly
    rank = sym_help._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = expand(g, values, values_shape, None)
    values = g.op("Reshape", values, values_shape)

    dtype = self.type().scalarType()
    if dtype is not None and dtype != values.type().scalarType():
        values = g.op("Cast", values, to_i=sym_help.cast_pytorch_to_onnx[dtype])
    dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
    dtype = sym_help.scalar_type_to_pytorch_type[dtype]

    if accumulate:
        zeros = g.op("ConstantOfShape", g.op("Shape", self),
                     value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result

def narrow(g, input, dim, start, length):
    end = g.op("Add", start, length)
    return symbolic_helper._slice_helper(
        g, input, axes=dim, starts=start, ends=end, dynamic_slice=True
    )