# Module-level imports reconstructed for this excerpt. `arange`, `add`, and
# `masked_scatter` are opset11 symbolics defined elsewhere in this file.
import sys

import torch

from torch.onnx import symbolic_helper
from torch.onnx import symbolic_opset9 as opset9
from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block


def index_copy(g, self, dim, index, source):
    dim_value = symbolic_helper._parse_arg(dim, "i")
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("index_copy", self, index, source, dim_i=dim_value)
    expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
        g, self, dim, index
    )
    return scatter(g, self, dim, expanded_index, source)
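# Illustrative sketch (not part of the original file; the _demo_* name is
# hypothetical): eager-mode semantics of the index_copy symbolic above.
def _demo_index_copy():
    x = torch.zeros(5, 3)
    index = torch.tensor([0, 4, 2])
    source = torch.arange(9, dtype=torch.float32).reshape(3, 3)
    # Rows 0, 4 and 2 of x receive rows 0, 1 and 2 of source.
    return x.index_copy(0, index, source)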
def scatter_add(g, self, dim, index, src):
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("scatter", self, dim, index, src, overload_name="src")
    src_type = src.type().scalarType()
    src_sizes = symbolic_helper._get_tensor_sizes(src)
    index_sizes = symbolic_helper._get_tensor_sizes(index)
    if src_sizes != index_sizes:
        return symbolic_helper._unimplemented(
            "scatter_add",
            f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})",
        )
    src = symbolic_helper._maybe_get_scalar(src)
    if symbolic_helper._is_value(src):
        return g.op(
            "ScatterElements", self, index, src, axis_i=dim, reduction_s="add"
        )
    else:
        # Check if scalar "src" has same type as self (PyTorch allows different
        # type for scalar src (but not when src is tensor)). If not, insert Cast node.
        if self.type().scalarType() != src_type:
            src = g.op(
                "Cast",
                src,
                to_i=symbolic_helper.cast_pytorch_to_onnx[self.type().scalarType()],
            )
        return g.op(
            "ScatterElements",
            self,
            index,
            src,
            axis_i=dim,
            reduction_s="add",
        )
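# Illustrative sketch (not part of the original file; the _demo_* name is
# hypothetical): eager-mode semantics of scatter_add, which the symbolic above
# lowers to ONNX ScatterElements with reduction="add".
def _demo_scatter_add():
    x = torch.ones(3, 4)
    index = torch.tensor([[0, 1, 2, 0]])
    src = torch.full((1, 4), 10.0)
    # index and src have identical shapes, satisfying the check in the
    # symbolic; each src element is added into x at the indexed row.
    return x.scatter_add(0, index, src)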
def _dim_arange(g, like, dim):
    like_shape = g.op("Shape", like)
    stop = g.op(
        "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0
    )
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.op("_caffe2::Range", stop)
    return arange(g, stop, 4, None, None, None)
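# Illustrative sketch (not part of the original file; the _demo_* name is
# hypothetical): _dim_arange emits Shape -> Gather -> arange, i.e. the eager
# equivalent below (dtype code 4 in the arange call above is ScalarType::Long).
def _demo_dim_arange(like, dim):
    return torch.arange(like.size(dim), dtype=torch.int64)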
def index_fill(g, self, dim, index, value):
    dim_value = symbolic_helper._parse_arg(dim, "i")
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at(
            "index_fill",
            self,
            index,
            value,
            overload_name="int_Scalar",
            dim_i=dim_value,
        )
    expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
        g, self, dim, index
    )
    value = symbolic_helper._maybe_get_scalar(value)
    value = symbolic_helper._if_scalar_type_as(g, value, self)
    expanded_value = opset9.expand(g, value, expanded_index_shape, None)
    return scatter(g, self, dim, expanded_index, expanded_value)
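# Illustrative sketch (not part of the original file; the _demo_* name is
# hypothetical): eager-mode semantics of index_fill (the symbolic expands the
# scalar value and scatters it).
def _demo_index_fill():
    x = torch.zeros(3, 4)
    index = torch.tensor([0, 2])
    # Columns 0 and 2 of every row are filled with -1.
    return x.index_fill(1, index, -1.0)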
def index(g, self, index):
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("index", self, index, overload_name="Tensor")

    if symbolic_helper._is_packed_list(index):
        indices = symbolic_helper._unpack_list(index)
    else:
        indices = [index]

    # Handle single mask index.
    if len(indices) == 1:
        index = indices[0]
        if not symbolic_helper._is_none(index) and (
            index.type().scalarType() == "Bool"
            or index.type().scalarType() == "Byte"
        ):
            index = opset9.nonzero(g, index)
            return g.op("GatherND", self, index)
    return opset9.index(g, self, index)
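# Illustrative sketch (not part of the original file; the _demo_* name is
# hypothetical): the single-bool-mask case that the index symbolic above
# routes through NonZero + GatherND.
def _demo_bool_mask_index():
    x = torch.arange(6.0).reshape(2, 3)
    mask = x > 2
    # Eager equivalent: advanced indexing with a boolean mask.
    return x[mask]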
def scatter(g, self, dim, index, src):
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("scatter", self, dim, index, src, overload_name="src")
    src_type = src.type().scalarType()
    src = symbolic_helper._maybe_get_scalar(src)
    if symbolic_helper._is_value(src):
        return g.op("ScatterElements", self, index, src, axis_i=dim)
    else:
        # Check if scalar "src" has same type as self (PyTorch allows different
        # type for scalar src (but not when src is tensor)). If not, insert Cast node.
        if self.type().scalarType() != src_type:
            src = g.op(
                "Cast",
                src,
                to_i=symbolic_helper.cast_pytorch_to_onnx[self.type().scalarType()],
            )
        return g.op(
            "ScatterElements", self, index, opset9.expand_as(g, src, index), axis_i=dim
        )
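# Illustrative sketch (not part of the original file; the _demo_* name is
# hypothetical): eager-mode semantics of scatter, including the scalar-src
# overload that triggers the Cast + expand_as path in the symbolic.
def _demo_scatter():
    x = torch.zeros(3, 4)
    index = torch.tensor([[0, 1, 2, 0]])
    # A scalar src may have a different type than x; the symbolic casts it to
    # x's type and expands it to the shape of index.
    return x.scatter(0, index, 7.0)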
def gather(g, self, dim, index, sparse_grad=False):
    if symbolic_helper._maybe_get_const(sparse_grad, "i"):
        return symbolic_helper._unimplemented("gather", "sparse_grad == True")
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("gather", self, dim, index, sparse_grad)
    return g.op("GatherElements", self, index, axis_i=dim)
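# Illustrative sketch (not part of the original file; the _demo_* name is
# hypothetical): eager-mode semantics of gather (ONNX GatherElements).
def _demo_gather():
    x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    index = torch.tensor([[0, 0], [1, 0]])
    # For dim=0: out[i][j] = x[index[i][j]][j].
    return x.gather(0, index)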
def index_put(g, self, indices_list_value, values, accumulate=False):
    if symbolic_helper._is_packed_list(indices_list_value):
        indices_list = symbolic_helper._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if symbolic_helper.is_caffe2_aten_fallback():
        args = [self] + indices_list + [values, accumulate]
        return g.at("index_put", *args)

    accumulate = symbolic_helper._parse_arg(accumulate, "b")

    if len(indices_list) == 0:
        return values

    if len(indices_list) > 1:
        for idx_ in range(len(indices_list)):
            if indices_list[idx_].type().scalarType() == "Bool":
                indices_list[idx_] = g.op("NonZero", indices_list[idx_])
        index = indices_list[0]

        for ind in indices_list[1:]:
            index = opset9.add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            symbolic_helper._unsqueeze_helper(
                g, opset9.expand(g, ind, broadcast_index_shape, None), [-1]
            )
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace index_put node with masked_scatter or masked_fill
        # when inputs to the index_put node contains a single boolean input.
        #
        # index_put -> masked_fill
        #   * input index contains single tensor of Bool type (e.g.: %24 <- %23).
        #   * input value contains single element (e.g.: %18).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        # index_put -> masked_scatter
        #   * input index contains single tensor of Bool type (e.g.: %32 <- %31).
        #   * input value contains multiple elements (e.g.: %28).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        index = indices_list[0]
        bool_inp = index
        if bool_inp.type() is not None and bool_inp.type().scalarType() == "Bool":
            rank = symbolic_helper._get_tensor_rank(values)
            if rank is not None and rank == 0:
                return opset9.masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = symbolic_helper._unsqueeze_helper(g, index, [-1])
    sub_data_shape = symbolic_helper._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize]
    )
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    # Check if values is a singular value and expand accordingly.
    rank = symbolic_helper._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = opset9.expand(g, values, values_shape, None)
    values = symbolic_helper._reshape_helper(g, values, values_shape)

    dtype = self.type().scalarType()
    if dtype is not None and dtype != values.type().scalarType():
        values = g.op("Cast", values, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype])
    dtype = symbolic_helper.scalar_type_to_onnx.index(
        symbolic_helper.cast_pytorch_to_onnx[dtype]
    )
    dtype = symbolic_helper.scalar_type_to_pytorch_type[dtype]

    if accumulate:
        zeros = g.op(
            "ConstantOfShape",
            g.op("Shape", self),
            value_t=torch.tensor([0], dtype=dtype),
        )
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
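# Illustrative sketch (not part of the original file; the _demo_* name is
# hypothetical): the two eager-mode patterns handled by index_put above --
# the single-bool-mask shortcut (masked_fill / masked_scatter) and the
# general ScatterND path with accumulate=True.
def _demo_index_put():
    x = torch.zeros(2, 2)
    mask = torch.tensor([[True, False], [False, True]])
    # Single bool index with a 0-rank value: equivalent to masked_fill.
    filled = x.index_put((mask,), torch.tensor(1.0))
    # Integer indices with accumulate=True: equivalent to the
    # ScatterND-on-zeros plus add that the symbolic emits.
    rows = torch.tensor([0, 0])
    cols = torch.tensor([1, 1])
    accumulated = x.index_put((rows, cols), torch.tensor([2.0, 3.0]), accumulate=True)
    return filled, accumulated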
def unfold(g, input, dimension, size, step):
    const_size = symbolic_helper._maybe_get_const(size, "i")
    const_step = symbolic_helper._maybe_get_const(step, "i")
    if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value(
        const_step
    ):
        return opset9.unfold(g, input, dimension, const_size, const_step)
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step)

    sizedim = symbolic_helper._get_tensor_dim_size(input, dimension)
    if sizedim is not None:
        low_start = g.op("Constant", value_t=torch.tensor(0))
        low_end = g.op("Constant", value_t=torch.tensor(sizedim))
        hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1))
        low_indices = g.op("Range", low_start, low_end, step)
        hi_indices = g.op("Range", size, hi_end, step)

        low_size = symbolic_helper._size_helper(
            g, low_indices, g.op("Constant", value_t=torch.tensor(0))
        )
        hi_size = symbolic_helper._size_helper(
            g, hi_indices, g.op("Constant", value_t=torch.tensor(0))
        )

        ndim = symbolic_helper._get_tensor_rank(input)
        perm = list(range(0, ndim))
        perm.append(perm.pop(dimension))

        unsqueeze_list = []
        loop_condition = g.op("Constant", value_t=torch.tensor(1))
        loop_condition = g.op("Cast", loop_condition, to_i=9)
        loop_len = g.op("Min", low_size, hi_size)
        loop = g.op("Loop", loop_len, loop_condition)

        loop_block = _add_block(loop.node())
        block_input_iter = _add_input_to_block(loop_block)
        cond = _add_input_to_block(loop_block)

        starts = loop_block.op("Gather", low_indices, block_input_iter)
        ends = loop_block.op("Gather", hi_indices, block_input_iter)
        axes = loop_block.op("Constant", value_t=torch.tensor([2]))
        starts = symbolic_helper._unsqueeze_helper(loop_block, starts, [0])
        ends = symbolic_helper._unsqueeze_helper(loop_block, ends, [0])
        stack = loop_block.op("Slice", input, starts, ends, axes)

        unsqueeze = symbolic_helper._unsqueeze_helper(
            loop_block, loop_block.op("Transpose", stack, perm_i=perm), [dimension]
        )
        unsqueeze_list.append(unsqueeze)
        concat = loop_block.op("Concat", *unsqueeze_list, axis_i=0)

        cond_out = loop_block.op("Cast", loop_condition, to_i=9)
        _add_output_to_block(loop_block, cond_out)
        _add_output_to_block(loop_block, concat)

        loop_output = loop.node().output()
        perm = [0, 1, 2, 3, 4]
        perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0]
        transpose = g.op("Transpose", loop_output, perm_i=perm)
        squeeze = symbolic_helper._squeeze_helper(g, transpose, [0])

        return squeeze
    else:
        return symbolic_helper._unimplemented("Unfold", "input size not accessible")
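# Illustrative sketch (not part of the original file; the _demo_* name is
# hypothetical): eager-mode semantics of unfold, which the Loop/Slice/Concat
# construction above emulates.
def _demo_unfold():
    x = torch.arange(7.0)
    # Sliding windows of length 2 with step 2 along dim 0 -> shape (3, 2).
    return x.unfold(0, 2, 2)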