def split(g, self, split_size_or_sizes, dim, _outputs=None):
    """Emit the ONNX nodes for splitting ``self`` along ``dim``.

    When ``sym_help._is_split_static`` reports the split as non-static, a
    ``SplitToSequence`` node is emitted and, if ``_outputs`` is given, it is
    unpacked into ``_outputs`` values: through ``Slice`` nodes when the split
    sizes form a packed list of exactly ``_outputs`` entries, through
    ``SequenceAt`` nodes otherwise.  For a static split a single ``Split``
    node is emitted.

    Args:
        g: Graph context providing ``op``.
        self: Value to split.
        split_size_or_sizes: Value holding a scalar split size or the split sizes.
        dim: Axis to split along.
        _outputs: Number of outputs expected, or ``None`` when unknown.

    Returns:
        The ``SplitToSequence`` output when the split is non-static and
        ``_outputs`` is ``None``; a list of ``_outputs`` values on the other
        non-static paths; otherwise the result of the ``Split`` op.

    Raises:
        RuntimeError: If the size of ``dim`` is not statically known and
            ``_outputs`` is ``None``.
    """
    if not sym_help._is_split_static(split_size_or_sizes, _outputs):
        split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
        if _outputs is None:
            return split_out
        # Convert to multiple slice nodes iff number of splits and number of
        # outputs are statically known.
        if sym_help._is_packed_list(split_size_or_sizes) and \
                len(sym_help._unpack_list(split_size_or_sizes)) == _outputs:
            split_sizes = [sym_help._unsqueeze_helper(g, v, [0])
                           for v in sym_help._unpack_list(split_size_or_sizes)]
            start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
            axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
            res = []
            # split_sizes has exactly _outputs entries (checked above); each
            # slice starts where the previous one ended.
            for split_len in split_sizes:
                end = g.op("Add", start, split_len)
                res.append(g.op("Slice", self, start, end, axis))
                start = end
            return res
        return [g.op("SequenceAt", split_out,
                     g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)))
                for i in range(_outputs)]

    split_val = split_size_or_sizes.node()['value']
    if split_val.dim() > 0:
        # Already a tensor of sizes: hand it to Split unchanged.
        return g.op("Split", self, split_size_or_sizes, axis_i=dim, outputs=_outputs)
    split_size = sym_help._get_const(split_size_or_sizes, 'i', 'split_size')

    # ``sizes()`` yields None when the input shape is not statically known;
    # indexing it directly would fail with a TypeError.
    sizes = self.type().sizes()
    size = sizes[dim] if sizes is not None else None
    if size is None:
        if _outputs is not None:
            # NOTE(review): assumes every output chunk has the full
            # split_size (no shorter trailing chunk) -- confirm with callers.
            size = split_size * _outputs
        else:
            raise RuntimeError("Unknown dimension size not supported")

    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)
    splits = g.op("Constant", value_t=torch.tensor(splits))
    return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
def cat(g, tensor_list, dim):
    """Emit the ONNX nodes for concatenating ``tensor_list`` along ``dim``.

    A packed list is delegated to the opset 9 ``cat`` symbolic; anything else
    is concatenated with a ``ConcatFromSequence`` node whose axis is the
    constant integer extracted from ``dim``.
    """
    if not sym_help._is_packed_list(tensor_list):
        axis = sym_help._get_const(dim, 'i', 'dim')
        return g.op("ConcatFromSequence", tensor_list, axis_i=axis)

    from torch.onnx.symbolic_opset9 import cat as cat_opset9
    return cat_opset9(g, tensor_list, dim)
def split(g, self, split_size_or_sizes, dim, _outputs=None):
    """Emit the ONNX nodes for splitting ``self`` along ``dim``.

    Static splits (as reported by ``sym_help._is_split_static``) are delegated
    to the opset 9 ``split`` symbolic.  Otherwise a ``SplitToSequence`` node is
    emitted; with ``_outputs`` unknown its output is returned directly, and
    with ``_outputs`` known it is unpacked into a list of ``_outputs`` values.
    """
    if sym_help._is_split_static(split_size_or_sizes, _outputs):
        return torch.onnx.symbolic_opset9.split(g, self, split_size_or_sizes, dim, _outputs)

    sequence = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
    if _outputs is None:
        return sequence

    # Convert to multiple slice nodes iff number of splits and number of
    # outputs are statically known.
    if sym_help._is_packed_list(split_size_or_sizes):
        if len(sym_help._unpack_list(split_size_or_sizes)) == _outputs:
            lengths = [
                sym_help._unsqueeze_helper(g, item, [0])
                for item in sym_help._unpack_list(split_size_or_sizes)
            ]
            begin = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
            axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
            pieces = []
            # lengths holds exactly _outputs entries; each slice begins where
            # the previous one stopped.
            for length in lengths:
                stop = g.op("Add", begin, length)
                pieces.append(g.op("Slice", self, begin, stop, axis))
                begin = stop
            return pieces

    # Fall back to picking each element out of the sequence by position.
    picked = []
    for position in range(_outputs):
        position_const = g.op(
            "Constant", value_t=torch.tensor([position], dtype=torch.long)
        )
        picked.append(g.op("SequenceAt", sequence, position_const))
    return picked
def split(g, self, split_size_or_sizes, dim, _outputs=None):
    """Emit the ONNX nodes for splitting ``self`` along ``dim``.

    Non-static splits (per ``symbolic_helper._is_split_static``) go through a
    ``SplitToSequence`` node, which is returned as-is when ``_outputs`` is
    ``None`` and otherwise unpacked into a list of ``_outputs`` values.
    Static splits produce a single ``Split`` node.

    Raises:
        RuntimeError: If the size of ``dim`` is unknown and ``_outputs`` is
            ``None``.
    """
    if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
        sequence = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
        if _outputs is None:
            return sequence

        # Convert to multiple slice nodes iff number of splits and number of
        # outputs are statically known.
        if symbolic_helper._is_packed_list(split_size_or_sizes):
            if len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs:
                lengths = [
                    symbolic_helper._unsqueeze_helper(g, item, [0])
                    for item in symbolic_helper._unpack_list(split_size_or_sizes)
                ]
                begin = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
                axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
                pieces = []
                # lengths holds exactly _outputs entries; each slice begins
                # where the previous one stopped.
                for length in lengths:
                    stop = g.op("Add", begin, length)
                    pieces.append(g.op("Slice", self, begin, stop, axis))
                    begin = stop
                return pieces

        # Fall back to picking each element out of the sequence by position.
        picked = []
        for position in range(_outputs):
            position_const = g.op(
                "Constant", value_t=torch.tensor([position], dtype=torch.long)
            )
            picked.append(g.op("SequenceAt", sequence, position_const))
        return picked

    split_val = split_size_or_sizes.node()["value"]
    if split_val.dim() > 0:
        # Already a tensor of sizes: hand it to Split unchanged.
        return g.op("Split", self, split_size_or_sizes, axis_i=dim, outputs=_outputs)

    split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size")
    size = symbolic_helper._get_tensor_dim_size(self, dim)
    if size is None:
        if _outputs is None:
            raise RuntimeError("Unknown dimension size not supported")
        # Dimension unknown: derive a size from the expected output count.
        size = split_size * _outputs

    # Full-size chunks, followed by one shorter chunk when there is a remainder.
    full_chunks, remainder = divmod(size, split_size)
    chunk_lengths = [split_size] * full_chunks
    if remainder:
        chunk_lengths.append(remainder)
    lengths_const = g.op("Constant", value_t=torch.tensor(chunk_lengths))
    return g.op("Split", self, lengths_const, axis_i=dim, outputs=_outputs)
def stack(g, tensor_list, dim):
    """Emit the ONNX nodes for stacking ``tensor_list`` along a new axis.

    A packed list is delegated to the opset 9 ``stack`` symbolic; anything
    else becomes a ``ConcatFromSequence`` node with ``new_axis`` set to 1 and
    the axis taken from the constant integer in ``dim``.
    """
    if not sym_help._is_packed_list(tensor_list):
        axis = sym_help._get_const(dim, "i", "dim")
        return g.op("ConcatFromSequence", tensor_list, axis_i=axis, new_axis_i=1)

    from torch.onnx.symbolic_opset9 import stack as stack_opset9
    return stack_opset9(g, tensor_list, dim)
def repeat(g, self, repeats):
    """Emit an ONNX ``Tile`` node repeating ``self`` according to ``repeats``.

    ``repeats`` is wrapped in a ``Constant`` node when it is not already a
    graph value.  When ``self`` is a complete tensor with fewer dimensions
    than ``repeats`` has entries, it is first viewed with leading size-1
    dimensions so that both have the same length.
    """
    if not sym_help._is_value(repeats):
        repeats = g.op("Constant", value_t=torch.LongTensor(repeats))

    # Number of entries in ``repeats``.
    repeat_count = (
        len(sym_help._unpack_list(repeats))
        if sym_help._is_packed_list(repeats)
        else len(sym_help._maybe_get_const(repeats, 'is'))
    )

    if self.isCompleteTensor():
        input_sizes = self.type().sizes()
        missing_dims = repeat_count - len(input_sizes)
        if missing_dims > 0:
            # Prepend size-1 dimensions to match the length of ``repeats``.
            padded_shape = [1] * missing_dims + input_sizes
            self = sym_opset9.view(g, self, padded_shape)

    return g.op("Tile", self, repeats)
def _prepare_onnx_paddings(g, input, pad):
    """Build the 1-D int64 paddings value from ``pad``, reordered for ONNX.

    The desired order of paddings is
    ``dim_0_begin, dim_1_begin, ..., dim_0_end, ..., dim_n_end``,
    where n is the dimension of ``input``.  ``pad`` is first extended with
    zeros up to ``2 * rank(input)`` entries, then regrouped into that order.

    Args:
        g: Graph context providing ``op``.
        input: Value being padded; only its rank is used here.
        pad: Padding amounts, ordered as
            ``[dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ...]``.

    Returns:
        The reordered paddings value, cast to Long.
    """
    needs_packing = (
        not symbolic_helper._is_packed_list(pad)
        and symbolic_helper._is_list(pad)
        and symbolic_helper._is_scalar_list(pad)
    )
    if needs_packing:
        pad = g.op("ConcatFromSequence", pad, axis_i=0, new_axis_i=1)

    # Assume zero-dimensions in the beginning, pad the "pad" sequence with
    # zeros in the beginning.
    first_dim = g.op("Constant", value_t=torch.tensor([0]))
    pad_len = opset9.size(g, pad, first_dim)

    # Set extension = [0] * (dim * 2 - len(pad))
    static_rank = symbolic_helper._get_tensor_rank(input)
    if static_rank is None:
        rank = g.op("Size", g.op("Shape", input))
    else:
        rank = g.op("Constant", value_t=torch.tensor(static_rank, dtype=torch.int64))
    two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))
    doubled_rank = g.op("Mul", rank, two)
    extension = g.op("Sub", doubled_rank, pad_len)

    # Concat pad with extension:
    # paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ... ]
    # Currently ONNX only supports int64 type for Pad
    long_type = symbolic_helper.cast_pytorch_to_onnx["Long"]
    pad = g.op("Cast", pad, to_i=long_type)
    zero_fill = g.op(
        "ConstantOfShape", extension, value_t=torch.tensor([0], dtype=torch.int64)
    )
    paddings = g.op("Concat", pad, zero_fill, axis_i=0)

    # Reshape and reverse order and collate first beginnings and then ends
    # paddings = [[..., 0, dim_n-1_begin, dim_n_begin],
    #             [..., 0, dim_n-1_end, dim_n_end]]
    pairs_shape = g.op("Constant", value_t=torch.tensor([-1, 2]))
    paddings = symbolic_helper._reshape_helper(g, paddings, pairs_shape)
    reversed_pairs = opset10.flip(g, paddings, [0])
    paddings = g.op("Transpose", reversed_pairs, perm_i=[1, 0])

    # Reshape back to 1-D
    # paddings = [..., 0, dim_n - 1_begin, dim_n_begin, ..., 0, dim_n - 1_end, dim_n_end]
    flat_shape = g.op("Constant", value_t=torch.tensor([-1]))
    paddings = symbolic_helper._reshape_helper(g, paddings, flat_shape)

    return g.op("Cast", paddings, to_i=symbolic_helper.cast_pytorch_to_onnx["Long"])
def index(g, self, index):
    """Emit the ONNX nodes for indexing ``self`` with ``index``.

    Under the Caffe2 ATen fallback an ``index`` ATen op is emitted.  A single
    Bool or Byte index is treated as a mask and exported as ``NonZero``
    followed by ``GatherND``.  Every other case is delegated to the opset 9
    ``index`` symbolic.
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("index", self, index, overload_name="Tensor")

    indices = (
        symbolic_helper._unpack_list(index)
        if symbolic_helper._is_packed_list(index)
        else [index]
    )

    # Handle single mask index.
    if len(indices) == 1:
        index = indices[0]
        is_mask = not symbolic_helper._is_none(index) and (
            index.type().scalarType() in ("Bool", "Byte")
        )
        if is_mask:
            return g.op("GatherND", self, opset9.nonzero(g, index))

    return opset9.index(g, self, index)
def index(g, self, index):
    """Emit the ONNX nodes for indexing ``self`` with ``index``.

    When the operator export type is ``ONNX_ATEN_FALLBACK`` an ``ATen`` node
    with operator ``index`` is emitted.  A single Bool or Byte index is
    treated as a mask and exported as ``NonZero`` followed by ``GatherND``.
    Every other case is delegated to the opset 9 ``index`` symbolic.
    """
    aten_fallback = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    if sym_help._operator_export_type == aten_fallback:
        return g.op("ATen", self, index, operator_s="index")

    if sym_help._is_packed_list(index):
        candidates = sym_help._unpack_list(index)
    else:
        candidates = [index]

    # Handle single mask index.
    if len(candidates) == 1:
        index = candidates[0]
        if not sym_help._is_none(index):
            if index.type().scalarType() in ("Bool", "Byte"):
                from torch.onnx.symbolic_opset9 import nonzero
                mask_positions = nonzero(g, index)
                return g.op("GatherND", self, mask_positions)

    from torch.onnx.symbolic_opset9 import index as index_opset9
    return index_opset9(g, self, index)
def index(g, self, index):
    """Emit the ONNX nodes for indexing ``self`` with ``index``.

    Under the Caffe2 ATen fallback an ``index`` ATen op is emitted.  A single
    Bool or Byte index is treated as a mask and exported as ``NonZero``
    followed by ``GatherND``.  Every other case is delegated to the opset 9
    ``index`` symbolic.
    """
    if sym_help.is_caffe2_aten_fallback():
        return g.at("index", self, index, overload_name="Tensor")

    unpacked = sym_help._unpack_list(index) if sym_help._is_packed_list(index) else [index]

    # Handle single mask index.
    if len(unpacked) == 1:
        (index,) = unpacked
        looks_like_mask = not sym_help._is_none(index) and (
            index.type().scalarType() in ("Bool", "Byte")
        )
        if looks_like_mask:
            from torch.onnx.symbolic_opset9 import nonzero
            return g.op("GatherND", self, nonzero(g, index))

    from torch.onnx.symbolic_opset9 import index as index_opset9
    return index_opset9(g, self, index)
def index_put(g, self, indices_list_value, values, accumulate=False):
    """Emit the ONNX nodes for writing ``values`` into ``self`` at given indices.

    Export strategy, in order:

    * ``ONNX_ATEN_FALLBACK`` export type: a single ``ATen`` ``index_put`` node.
    * No indices: ``values`` is returned unchanged.
    * One Bool index: delegated to ``masked_fill`` (rank-0 ``values``) or
      ``masked_scatter``.
    * Otherwise: the indices are combined into one index value and the write
      is exported as ``ScatterND``.  With ``accumulate`` the values are
      scattered into zeros and added to ``self``.
    """
    indices_list = (
        sym_help._unpack_list(indices_list_value)
        if sym_help._is_packed_list(indices_list_value)
        else [indices_list_value]
    )
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        aten_args = [self, *indices_list, values, accumulate]
        return g.op("ATen", *aten_args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    if not indices_list:
        return values

    index = indices_list[0]
    if len(indices_list) == 1:
        # Replace index_put node with masked_scatter or masked_fill
        # when inputs to the index_put node contains boolean inputs
        #
        # index_put -> masked_fill
        #   * input index contains single tensor of Bool type (e.g.: %24 <- %23).
        #   * input value contains single element (e.g.: %18).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        #
        # index_put -> masked_scatter
        #   * input index contains single tensor of Bool type (e.g.: %32 <- %31).
        #   * input value contains multiple elements (e.g.: %28).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        mask = index
        if mask.type() is not None and mask.type().scalarType() == 'Bool':
            values_rank = sym_help._get_tensor_rank(values)
            if values_rank is not None and values_rank == 0:
                from torch.onnx.symbolic_opset9 import masked_fill
                return masked_fill(g, self, mask, values)
            return masked_scatter(g, self, mask, values)
        broadcast_index_shape = g.op("Shape", index)
        index = sym_help._unsqueeze_helper(g, index, [-1])
    else:
        # Adding all indices together yields a value with their broadcast shape.
        for other in indices_list[1:]:
            index = add(g, index, other)
        broadcast_index_shape = g.op("Shape", index)
        # Expand every index to that shape and join them along a new last axis.
        expanded = []
        for ind in indices_list:
            grown = expand(g, ind, broadcast_index_shape, None)
            expanded.append(sym_help._unsqueeze_helper(g, grown, [-1]))
        indices_list = expanded
        index = g.op("Concat", *indices_list, axis_i=-1)

    # Shape of the dimensions of ``self`` that are not addressed by an index.
    trailing_shape = sym_help._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[maxsize]
    )
    values_shape = g.op("Concat", broadcast_index_shape, trailing_shape, axis_i=0)

    # Check if values is a singular value and expand accordingly
    values_rank = sym_help._get_tensor_rank(values)
    if values_rank is not None and values_rank == 0:
        values = expand(g, values, values_shape, None)
    values = g.op("Reshape", values, values_shape)

    dtype = self.type().scalarType()
    if dtype is not None and dtype != values.type().scalarType():
        values = g.op("Cast", values, to_i=sym_help.cast_pytorch_to_onnx[dtype])
    # Map the scalar type name to the matching torch dtype.
    onnx_type_position = sym_help.scalar_type_to_onnx.index(
        sym_help.cast_pytorch_to_onnx[dtype]
    )
    torch_dtype = sym_help.scalar_type_to_pytorch_type[onnx_type_position]

    if not accumulate:
        return g.op("ScatterND", self, index, values)

    # Accumulate: scatter into zeros, then add the result onto ``self``.
    zeros = g.op(
        "ConstantOfShape",
        g.op("Shape", self),
        value_t=torch.tensor([0], dtype=torch_dtype),
    )
    scattered = g.op("ScatterND", zeros, index, values)
    return add(g, self, scattered)
def __interpolate(g, input, size, scale_factor, mode, align_corners):
    """Emit an ONNX ``Resize`` node for an interpolate call.

    The output is described either by ``size`` (when it is not none) or by
    ``scale_factor``.  Any mode containing ``linear`` or ``cubic`` is
    normalised to that word.  The coordinate transformation mode is
    ``asymmetric`` for ``nearest``, ``align_corners`` when ``align_corners``
    is a true bool, and ``pytorch_half_pixel`` otherwise.

    Returns:
        The ``Resize`` output, or the result of ``sym_help._unimplemented``
        when the input's ``dim()`` is falsy and is needed.
    """
    mode = sym_help._maybe_get_const(mode, 's')
    if 'linear' in mode:
        mode = 'linear'
    elif 'cubic' in mode:
        mode = 'cubic'

    align_corners = sym_help._maybe_get_const(align_corners, 'b')
    if not isinstance(align_corners, bool):
        align_corners = False

    if mode == "nearest":
        coordinate_transformation_mode = "asymmetric"
    elif align_corners:
        coordinate_transformation_mode = "align_corners"
    else:
        coordinate_transformation_mode = "pytorch_half_pixel"

    # roi only takes effect with coordinate_transformation_mode="tf_crop_and_resize"
    roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))

    # Attributes shared by both Resize variants below.
    resize_attrs = dict(
        coordinate_transformation_mode_s=coordinate_transformation_mode,
        cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
        mode_s=mode,  # nearest, linear, or cubic
        nearest_mode_s="floor",  # only valid when mode="nearest"
    )

    if sym_help._is_none(size):
        # No explicit size: resize by scale factors.
        if not input.type().dim():
            return sym_help._unimplemented("interpolate (with scales)", "missing input shape")
        scales = sym_help._interpolate_get_scales(g, scale_factor, input.type().dim())
        return g.op("Resize", input, roi, scales, **resize_attrs)

    # The first two entries of the input shape are kept unchanged.
    input_size = g.op("Shape", input)
    input_size = sym_help._slice_helper(g, input_size, axes=[0], ends=[2], starts=[0])

    # in some cases size is not a packed list but size is a scalar
    # We need to also verify that (sym_help._maybe_get_const(size, 't').dim() == 0)
    # but this information is not always available. Try to get the dim,
    # and if not assume that it is not a scalar.
    try:
        is_scalar = not sym_help._is_packed_list(size) and (
            sym_help._maybe_get_const(size, 't').dim() == 0
        )
    except AttributeError:
        is_scalar = not sym_help._is_packed_list(size)
        if not is_scalar:
            warnings.warn(
                "Cannot verify if the output_size is a scalar "
                "while exporting interpolate. "
                "Assuming that it is not a scalar."
            )

    if is_scalar:
        if not input.type().dim():
            return sym_help._unimplemented(
                "interpolate (with a scalar output_size)",
                "missing input shape (try giving an array of output_size values)",
            )
        # Repeat the scalar once per dimension after the first two.
        size = unsqueeze(g, size, 0)
        repeated = [size for _ in range(input.type().dim() - 2)]
        size = g.op("Concat", *repeated, axis_i=0)

    size = g.op("Cast", size, to_i=sym_help.cast_pytorch_to_onnx['Long'])
    size = g.op("Concat", input_size, size, axis_i=0)
    # Empty scales input: the output is given by the sizes input instead.
    scales = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
    return g.op("Resize", input, roi, scales, size, **resize_attrs)
def stack(g, tensor_list, dim):
    """Emit the ONNX nodes for stacking ``tensor_list`` along a new axis.

    A packed list is delegated to the opset 9 ``stack`` symbolic; anything
    else becomes a ``ConcatFromSequence`` node with ``new_axis`` set to 1 and
    the axis taken from the constant integer in ``dim``.
    """
    if not symbolic_helper._is_packed_list(tensor_list):
        axis = symbolic_helper._get_const(dim, "i", "dim")
        return g.op("ConcatFromSequence", tensor_list, axis_i=axis, new_axis_i=1)
    return opset9.stack(g, tensor_list, dim)
def index_put(g, self, indices_list_value, values, accumulate=False):
    """Emit the ONNX nodes for writing ``values`` into ``self`` at given indices.

    Export strategy, in order:

    * Caffe2 ATen fallback: a single ``index_put`` ATen op.
    * No indices: ``values`` is returned unchanged.
    * One Bool index: delegated to ``masked_fill`` (rank-0 ``values``) or
      ``masked_scatter``.
    * Otherwise: Bool indices (multi-index case) are converted with
      ``NonZero``, the indices are combined into one index value, and the
      write is exported as ``ScatterND``.  With ``accumulate`` the values are
      scattered into zeros and added to ``self``.

    Args:
        g: Graph context providing ``op`` and ``at``.
        self: Value being written into.
        indices_list_value: A packed list of index values, or a single index.
        values: Value(s) to write.
        accumulate: Whether to add to the existing contents; parsed as a bool.

    Returns:
        The graph value holding the updated result.
    """
    if symbolic_helper._is_packed_list(indices_list_value):
        indices_list = symbolic_helper._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if symbolic_helper.is_caffe2_aten_fallback():
        args = [self] + indices_list + [values, accumulate]
        return g.at("index_put", *args)

    accumulate = symbolic_helper._parse_arg(accumulate, "b")

    if len(indices_list) == 0:
        return values

    if len(indices_list) > 1:
        for idx_, candidate in enumerate(indices_list):
            if candidate.type().scalarType() == "Bool":  # type: ignore[attr-defined]
                # TODO(justinchuby): Remove type ignore after #81112 is checked in.
                indices_list[idx_] = g.op("NonZero", candidate)
        index = indices_list[0]

        # Adding all indices together yields a value with their broadcast shape.
        for ind in indices_list[1:]:
            index = opset9.add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        # Expand every index to that shape and join them along a new last axis.
        indices_list = [
            symbolic_helper._unsqueeze_helper(
                g, opset9.expand(g, ind, broadcast_index_shape, None), [-1]
            )
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace index_put node with masked_scatter or masked_fill
        # when inputs to the index_put node contains a single boolean input.
        #
        # index_put -> masked_fill
        #   * input index contains single tensor of Bool type (e.g.: %24 <- %23).
        #   * input value contains single element (e.g.: %18).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        #
        # index_put -> masked_scatter
        #   * input index contains single tensor of Bool type (e.g.: %32 <- %31).
        #   * input value contains multiple elements (e.g.: %28).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        index = indices_list[0]
        bool_inp = index
        if bool_inp.type() is not None and bool_inp.type().scalarType() == "Bool":  # type: ignore[attr-defined]
            # TODO(justinchuby): Remove type ignore after #81112 is checked in.
            rank = symbolic_helper._get_tensor_rank(values)
            if rank is not None and rank == 0:
                return opset9.masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = symbolic_helper._unsqueeze_helper(g, index, [-1])

    # Shape of the dimensions of ``self`` that are not addressed by an index.
    sub_data_shape = symbolic_helper._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize]
    )
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    # Check if values is a singular value and expand accordingly
    rank = symbolic_helper._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = opset9.expand(g, values, values_shape, None)
    values = symbolic_helper._reshape_helper(g, values, values_shape)

    dtype = self.type().scalarType()
    if dtype is not None and dtype != values.type().scalarType():
        values = g.op("Cast", values, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype])
    # Map the scalar type name to the matching torch dtype.
    dtype = symbolic_helper.scalar_type_to_onnx.index(
        symbolic_helper.cast_pytorch_to_onnx[dtype]
    )
    dtype = symbolic_helper.scalar_type_to_pytorch_type[dtype]

    if accumulate:
        # Scatter into zeros, then add the result onto ``self``.
        zeros = g.op(
            "ConstantOfShape",
            g.op("Shape", self),
            value_t=torch.tensor([0], dtype=dtype),
        )
        result = g.op("ScatterND", zeros, index, values)
        # Use the opset 9 ``add`` symbolic through the ``opset9`` module, as
        # the rest of this function does; a bare ``add`` is not imported here.
        result = opset9.add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result