def _index_fill_reshape_helper(g, self, dim, index):
    # 1. reshape index => [1, ..., 1, dim, 1, ..., 1]
    # 2. expand index => [..., dim, ...], same shape as self except for dim.
    # 3. expand value as well.
    # 4. apply onnx::scatter.
    from torch.onnx.symbolic_opset9 import expand
    if _export_onnx_opset_version <= 10:
        from torch.onnx.symbolic_opset9 import scatter
    else:
        from torch.onnx.symbolic_opset11 import scatter

    if self.type().dim() is None:
        return _unimplemented("index_fill", "input rank not accessible")
    self_dim = self.type().dim()
    dim_value = _parse_arg(dim, 'i')
    unsqueezed_index = g.op("Unsqueeze", index,
                            axes_i=[i for i in range(self_dim) if i != dim_value])
    expanded_index_shape = scatter(g, g.op("Shape", self), 0,
                                   g.op("Unsqueeze", dim, axes_i=[0]), g.op("Shape", index))
    expanded_index = expand(g, unsqueezed_index, expanded_index_shape, None)
    return expanded_index_shape, expanded_index
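# Illustrative sketch (not part of this module): the eager-mode analogue of the
# reshape/expand that _index_fill_reshape_helper encodes, assuming an integer
# `dim` and a 1-D `index` tensor. The helper above builds the same result with
# onnx::Unsqueeze, a scatter on the shape tensor, and onnx::Expand.
import torch

def _example_expanded_index(self: torch.Tensor, dim: int, index: torch.Tensor) -> torch.Tensor:
    # Unsqueeze index on every axis except `dim`: [k] -> [1, ..., k, ..., 1]
    shape = [1] * self.dim()
    shape[dim] = -1
    unsqueezed = index.reshape(shape)
    # Target shape equals self's shape, except `dim` is replaced by len(index)
    expanded_shape = list(self.shape)
    expanded_shape[dim] = index.numel()
    return unsqueezed.expand(*expanded_shape)

# _example_expanded_index(torch.zeros(3, 4, 5), 1, torch.tensor([0, 2])).shape == (3, 2, 5)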
def index_fill(g, self, dim, index, value):
    dim_value = sym_help._parse_arg(dim, 'i')
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", self, index, value, dim_i=dim_value, operator_s="index_fill")
    expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
    value = sym_help._maybe_get_scalar(value)
    value = sym_help._if_scalar_type_as(g, value, self)
    expanded_value = expand(g, value, expanded_index_shape, None)

    return scatter(g, self, dim, expanded_index, expanded_value)
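# Minimal usage sketch (module and file names are illustrative only): exporting a
# model that calls Tensor.index_fill routes through the index_fill symbolic above
# when opset_version >= 11.
import torch

class _IndexFillExample(torch.nn.Module):
    def forward(self, x):
        return x.index_fill(1, torch.tensor([0, 2]), -1.0)

# torch.onnx.export(_IndexFillExample(), torch.randn(3, 4), "index_fill.onnx", opset_version=11)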
def chunk(g, self, chunks, dim):
    # Calculate chunk size for dynamic chunk
    dim_size = g.op("Gather", g.op("Shape", self), dim, axis_i=0)
    chunk_size_s = g.op("Sub", chunks, g.op("Constant", value_t=torch.tensor([1], dtype=torch.long)))
    chunk_size = g.op("Div", g.op("Add", dim_size, chunk_size_s), chunks)
    # Create splits vector
    chunk_vec = [expand(g, chunk_size, chunk_size_s, None),
                 g.op("Sub", dim_size, g.op("Mul", chunk_size, chunk_size_s))]
    chunk_vec = g.op("Concat", *chunk_vec, axis_i=0)
    return split(g, self, chunk_vec, dim)
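# Hedged sketch of the split-size arithmetic that the chunk symbolic above builds
# with onnx::Sub/Add/Div/Mul (integer division): chunks - 1 full chunks of size
# ceil(dim_size / chunks), followed by whatever remains on the last chunk.
def _example_chunk_splits(dim_size: int, chunks: int) -> list:
    chunk_size = (dim_size + (chunks - 1)) // chunks
    return [chunk_size] * (chunks - 1) + [dim_size - chunk_size * (chunks - 1)]

assert _example_chunk_splits(10, 3) == [4, 4, 2]  # matches torch.chunk on a length-10 dim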
def index_put(g, self, indices_list_value, values, accumulate=False):
    indices_list = sym_help._unpack_list(indices_list_value)
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    index = indices_list[0]

    if len(indices_list) > 1:
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            g.op("Unsqueeze", expand(g, ind, broadcast_index_shape, None), axes_i=[-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        broadcast_index_shape = g.op("Shape", index)
        index = g.op("Unsqueeze", index, axes_i=[-1])
    sub_data_shape = sym_help._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[maxsize])
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    values = g.op("Reshape", values, values_shape)

    if accumulate:
        dtype = self.type().scalarType()
        dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
        dtype = sym_help.scalar_type_to_pytorch_type[dtype]
        zeros = g.op("ConstantOfShape", g.op("Shape", self), value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
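# Illustrative eager-mode analogue (not part of this module) of the accumulate
# branch above: because onnx::ScatterND has no reduction attribute at opset 11,
# the symbolic scatters `values` into a zeros tensor and adds the result to
# `self`, which matches `self[indices] += values` when the indices are unique.
import torch

def _example_scatter_accumulate(data: torch.Tensor, index: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    zeros = torch.zeros_like(data)
    zeros[index] = values  # stand-in for onnx::ScatterND into the zeros tensor
    return data + zeros    # stand-in for the trailing onnx::Add with `self`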
def index_fill(g, self, dim, index, value):
    dim_value = sym_help._parse_arg(dim, "i")
    if sym_help.is_caffe2_aten_fallback():
        return g.at("index_fill", self, index, value, overload_name="int_Scalar", dim_i=dim_value)

    expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(g, self, dim, index)
    value = sym_help._maybe_get_scalar(value)
    value = sym_help._if_scalar_type_as(g, value, self)
    expanded_value = expand(g, value, expanded_index_shape, None)
    return scatter(g, self, dim, expanded_index, expanded_value)
def index_put(g, self, indices_list_value, values, accumulate=False):
    if sym_help._is_packed_list(indices_list_value):
        indices_list = sym_help._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    if len(indices_list) == 0:
        return values

    index = indices_list[0]

    if len(indices_list) > 1:
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            sym_help._unsqueeze_helper(g, expand(g, ind, broadcast_index_shape, None), [-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace the index_put node with masked_scatter or masked_fill
        # when the inputs to the index_put node contain boolean inputs.
        #
        # index_put -> masked_fill
        #   * input index contains single tensor of Bool type (e.g.: %24 <- %23).
        #   * input value contains single element (e.g.: %18).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        #
        # index_put -> masked_scatter
        #   * input index contains single tensor of Bool type (e.g.: %32 <- %31).
        #   * input value contains multiple elements (e.g.: %28).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        bool_inp = index
        if bool_inp.type() is not None and bool_inp.type().scalarType() == 'Bool':
            rank = sym_help._get_tensor_rank(values)
            if rank is not None and rank == 0:
                from torch.onnx.symbolic_opset9 import masked_fill
                return masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = sym_help._unsqueeze_helper(g, index, [-1])
    sub_data_shape = sym_help._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[maxsize])
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    # Check if values is a singular value and expand accordingly
    rank = sym_help._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = expand(g, values, values_shape, None)
    values = g.op("Reshape", values, values_shape)

    dtype = self.type().scalarType()
    if dtype is not None and dtype != values.type().scalarType():
        values = g.op("Cast", values, to_i=sym_help.cast_pytorch_to_onnx[dtype])
    dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
    dtype = sym_help.scalar_type_to_pytorch_type[dtype]

    if accumulate:
        zeros = g.op("ConstantOfShape", g.op("Shape", self), value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
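# Usage sketch (module and tensor names are illustrative): a boolean-mask
# index_put with a scalar value is lowered through the masked_fill branch above,
# a tensor value takes the masked_scatter branch, and integer indices fall
# through to the ScatterND path.
import torch

class _MaskedFillExample(torch.nn.Module):
    def forward(self, x):
        y = x.clone()
        y[y > 0.5] = 0.0  # aten::index_put with a Bool index and a scalar value
        return y

# torch.onnx.export(_MaskedFillExample(), torch.rand(2, 2, 2), "masked_fill.onnx", opset_version=11)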
def repeat_interleave(g, self, repeats, dim=None):
    from torch.onnx.symbolic_opset9 import reshape
    input = self
    final_dim = dim
    # if dim is None flatten
    # By default, use the flattened input array, and return a flat output array
    if sym_help._is_none(dim):
        input = reshape(g, self, g.op("Constant", value_t=torch.tensor([-1])))
        dim = 0
    else:
        dim = sym_help._maybe_get_scalar(dim)

    repeats_dim = sym_help._get_tensor_rank(repeats)
    repeats_sizes = sym_help._get_tensor_sizes(repeats)
    input_sizes = sym_help._get_tensor_sizes(input)
    if repeats_dim is None:
        raise RuntimeError('Unsupported: ONNX export of repeat_interleave for unknown '
                           'repeats rank.')
    if repeats_sizes is None:
        raise RuntimeError('Unsupported: ONNX export of repeat_interleave for unknown '
                           'repeats size.')
    if input_sizes is None:
        raise RuntimeError('Unsupported: ONNX export of repeat_interleave for unknown '
                           'input size.')
    # Handle cases where dim is negative
    if dim < 0:
        dim += len(input_sizes)

    output_sizes = input_sizes.copy()
    perm_i = [0]
    for idx, input_size in enumerate(input_sizes):
        perm_i.append(idx + 1)
        if input_size is None:
            output_sizes[idx], input_sizes[idx] = 0, -1
    perm_i[0], perm_i[dim] = perm_i[dim], perm_i[0]

    # Cases when repeats is a single value tensor and dim has unknown input size
    if (repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1)) and output_sizes[dim] == 0:
        if not sym_help._is_tensor(repeats):
            repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
        reps = sym_help._size_helper(g, input, dim)
        reps = unsqueeze(g, reps, 0)
        repeats = g.op("Expand", repeats, reps)
    # There are cases when the repeats are 1-d tensor with multiple repeats, but dim
    # provided along one of the dynamic axes provided. A simple example would be
    # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2
    # Now, repeat interleaving can be performed in pytorch when the value of * matches
    # with the number of elements in repeat, for example if * -> 2, number of repeats
    # should be 2 as well.
    else:
        return torch.onnx.symbolic_opset9.repeat_interleave(g, self, repeats, final_dim)

    reps_like = g.op("ConstantOfShape", g.op("Shape", repeats),
                     value_t=torch.tensor([1], dtype=torch.long))
    r_splits = split(g, repeats, reps_like, 0)
    i_splits = split(g, input, reps_like, dim)

    output_sizes[dim], input_sizes[dim] = -1, 1

    # Create a loop to iterate over each value along the dimension
    # and perform individual interleaving using the repeats tensor
    # Loop is of the following pattern
    # input (trip_count, cond)
    #   int trip_count = ...;
    #   bool cond = ...;
    #   for (int i=0; i < trip_count && cond; ++i) {
    #     cond = ...;
    #   }

    # Loop conditions
    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    loop_condition = g.op("Cast", loop_condition, to_i=9)
    loop_len = reps
    loop = g.op("Loop", loop_len, loop_condition)

    # Loop inputs
    loop_block = _add_block(loop.node())
    block_input_iter = _add_input_to_block(loop_block)
    cond = _add_input_to_block(loop_block)

    r_split = loop_block.op("SequenceAt", r_splits, block_input_iter)
    i_split = loop_block.op("SequenceAt", i_splits, block_input_iter)

    i_split = unsqueeze(loop_block, i_split, dim + 1)
    r_concat = [loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[:dim + 1])),
                r_split,
                loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1:]))]
    r_concat = loop_block.op("Concat", *r_concat, axis_i=0)
    i_split = expand(loop_block, i_split, r_concat, None)
    i_split = reshape(loop_block, i_split, g.op("Constant", value_t=torch.LongTensor(output_sizes)))

    # Loop outputs
    cond_out = loop_block.op("Cast", loop_condition, to_i=9)
    _add_output_to_block(loop_block, cond_out)
    _add_output_to_block(loop_block, i_split)
    loop_out = loop.node().output()

    # In this loop, the outputs are scan outputs and are concatenated along
    # the zero'th dimension (by default). In order to avoid this and concatenate
    # along the dimension provided, some post-processing is required
    loop_out = g.op("Transpose", loop_out, perm_i=perm_i)
    return reshape(g, loop_out, g.op("Constant", value_t=torch.LongTensor(output_sizes)))
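# Reference behaviour the loop above reproduces (a fact about the eager op, shown
# for orientation): each slice along `dim` is repeated by its entry in `repeats`.
import torch

assert torch.equal(
    torch.repeat_interleave(torch.tensor([[1, 2], [3, 4]]), torch.tensor([1, 2]), dim=0),
    torch.tensor([[1, 2], [3, 4], [3, 4]]),
)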
def repeat_interleave(g, self, repeats, dim=None, output_size=None):
    input = self
    final_dim = dim
    # if dim is None flatten
    # By default, use the flattened input array, and return a flat output array
    if sym_help._is_none(dim):
        input = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1])))
        dim = 0
    else:
        dim = sym_help._maybe_get_scalar(dim)

    repeats_dim = sym_help._get_tensor_rank(repeats)
    repeats_sizes = sym_help._get_tensor_sizes(repeats)
    input_sizes = sym_help._get_tensor_sizes(input)
    if repeats_dim is None:
        raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
                           "repeats rank.")
    if repeats_sizes is None:
        raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
                           "repeats size.")
    if input_sizes is None:
        raise RuntimeError("Unsupported: ONNX export of repeat_interleave for unknown "
                           "input size.")
    # Handle cases where dim is negative
    if dim < 0:
        dim += len(input_sizes)

    output_sizes = input_sizes.copy()
    for idx, input_size in enumerate(input_sizes):
        if input_size is None:
            output_sizes[idx], input_sizes[idx] = 0, -1

    cond_dynamic_repeats = (repeats_dim == 1 and repeats_sizes[0] is None)
    # If input size is dynamic or repeats vector is dynamic
    if output_sizes[dim] == 0 or cond_dynamic_repeats:
        reps = sym_help._size_helper(g, input, dim)
        reps = unsqueeze(g, reps, 0)
        # Check if repeats vector is a single integer value
        # or a single dimension tensor with non-dynamic values
        if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
            if not sym_help._is_tensor(repeats):
                repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
            repeats = g.op("Expand", repeats, reps)
        # Check if repeats is dynamic
        # As repeats is dynamic, we use a where node as a substitute for the if statement
        # If repeats_dim = 1, expand repeats otherwise use original tensor
        elif cond_dynamic_repeats:
            repeat_dim = sym_help._size_helper(g, repeats, g.op("Constant", value_t=torch.LongTensor([0])))
            repeat_cond = g.op("Equal", repeat_dim, g.op("Constant", value_t=torch.LongTensor([1])))
            repeats = where(g, repeat_cond, g.op("Expand", repeats, reps), repeats)
    # There are cases when the repeats are 1-d tensor with multiple repeats, but dim
    # provided along one of the dynamic axes provided. A simple example would be
    # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2
    # Now, repeat interleaving can be performed in pytorch when the value of * matches
    # with the number of elements in repeat, for example if * -> 2, number of repeats
    # should be 2 as well.
    else:
        return torch.onnx.symbolic_opset9.repeat_interleave(g, self, repeats, final_dim)

    reps_like = g.op("ConstantOfShape", g.op("Shape", repeats),
                     value_t=torch.tensor([1], dtype=torch.long))
    r_splits = split(g, repeats, reps_like, 0)
    i_splits = split(g, input, reps_like, dim)

    output_sizes[dim], input_sizes[dim] = -1, 1

    # Create a loop to iterate over each value along the dimension
    # and perform individual interleaving using the repeats tensor
    # Loop is of the following pattern
    # input (trip_count, cond)
    #   int trip_count = ...;
    #   bool cond = ...;
    #   for (int i=0; i < trip_count && cond; ++i) {
    #     cond = ...;
    #   }

    # Loop conditions
    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    loop_condition = g.op("Cast", loop_condition, to_i=9)
    loop_len = reps

    # Create an empty sequence to store final expansions
    final_splits = g.op("SequenceEmpty")
    loop = g.op("Loop", loop_len, loop_condition, final_splits)

    # Loop inputs
    loop_block = _add_block(loop.node())
    block_input_iter = _add_input_to_block(loop_block)
    cond = _add_input_to_block(loop_block)
    final_splits = _add_input_to_block(loop_block)

    r_split = loop_block.op("SequenceAt", r_splits, block_input_iter)
    i_split = loop_block.op("SequenceAt", i_splits, block_input_iter)

    i_split = unsqueeze(loop_block, i_split, dim + 1)
    r_concat = [loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[:dim + 1])),
                r_split,
                loop_block.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1:]))]
    r_concat = loop_block.op("Concat", *r_concat, axis_i=0)
    i_split = expand(loop_block, i_split, r_concat, None)
    i_split = sym_help._reshape_helper(loop_block, i_split,
                                       g.op("Constant", value_t=torch.LongTensor(output_sizes)))
    final_splits = loop_block.op("SequenceInsert", final_splits, i_split)

    # Loop outputs
    cond_out = loop_block.op("Cast", loop_condition, to_i=9)
    _add_output_to_block(loop_block, cond_out)
    _add_output_to_block(loop_block, final_splits)

    loop_out = loop.node().output()
    loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim)
    return loop_out
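# Reference behaviour for the dim=None path above (a fact about the eager op):
# with no dim the input is flattened and every element is repeated. This newer
# symbolic gathers the per-slice expansions into a Sequence and joins them with
# onnx::ConcatFromSequence along `dim`, rather than relying on Loop scan outputs
# plus a Transpose as the earlier version does.
import torch

assert torch.equal(
    torch.repeat_interleave(torch.tensor([[1, 2], [3, 4]]), torch.tensor(2)),
    torch.tensor([1, 1, 2, 2, 3, 3, 4, 4]),
)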
def index_put(g, self, indices_list_value, values, accumulate=False):
    indices_list = sym_help._unpack_list(indices_list_value)
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    index = indices_list[0]

    if len(indices_list) > 1:
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            sym_help._unsqueeze_helper(g, expand(g, ind, broadcast_index_shape, None), [-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace the index_put node with masked_scatter or masked_fill
        # when the inputs to the index_put node contain boolean inputs.
        #
        # index_put -> masked_fill
        #
        # before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %6 : None = prim::Constant()
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %8 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::ne(%mask, %some_const)
        #   %26 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]()
        #   %27 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %11 : Device = prim::Constant[value="cpu"]()
        #   %12 : None = prim::Constant()
        #   %28 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %29 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %15 : None = prim::Constant()
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %30 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %22 : int[] = prim::Constant[value=[-1]]()
        #   %23 : Tensor = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %3 : Tensor = onnx::Equal(%0, %some_const)
        #   %4 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%3)
        #   %12 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%4)
        #   %19 : Tensor = onnx::Cast[to=9](%12)
        #   %20 : Tensor = onnx::Constant[value={1}]()
        #   %21 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = onnx::Where(%19, %20, %0)
        #   return (%21)
        #
        # index_put -> masked_scatter
        #
        # before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %6 : None = prim::Constant()
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %34 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]()
        #   %35 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %18 : Device = prim::Constant[value="cpu"]()
        #   %19 : None = prim::Constant()
        #   %36 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %37 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %22 : None = prim::Constant()
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Tensor = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        #
        # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %3 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = onnx::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]()
        #   %4 : Tensor = onnx::Equal(%0, %some_const)
        #   %5 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%4)
        #   %13 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%5)
        #   %19 : Tensor = onnx::Shape(%0)
        #   %20 : Tensor = onnx::Expand(%13, %19)
        #   %21 : Tensor = onnx::NonZero(%20)
        #   %22 : Tensor = onnx::Transpose[perm=[1, 0]](%21)
        #   %23 : Tensor = onnx::Constant[value={-1}]()
        #   %24 : Tensor = onnx::Reshape(%3, %23)
        #   %25 : Tensor = onnx::Shape(%22)
        #   %27 : Tensor = onnx::Constant[value={0}]()
        #   %28 : Tensor = onnx::Gather[axis=0](%25, %27)
        #   %29 : Tensor = onnx::Constant[value={0}]()
        #   %30 : Tensor = onnx::Unsqueeze[axes=[0]](%29)
        #   %31 : Tensor = onnx::Unsqueeze[axes=[0]](%28)
        #   %32 : Tensor = onnx::Constant[value={0}]()
        #   %33 : Tensor = onnx::Unsqueeze[axes=[0]](%32)
        #   %34 : Tensor = onnx::Slice(%24, %30, %31, %33)
        #   %35 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = onnx::ScatterND(%0, %22, %34)
        #   return (%35)
        bool_inp = list(index.node().inputs())[0]
        if bool_inp.type() is not None and bool_inp.type().scalarType() == 'Bool':
            rank = sym_help._get_tensor_rank(values)
            if rank is not None and rank == 0:
                from torch.onnx.symbolic_opset9 import masked_fill
                return masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = sym_help._unsqueeze_helper(g, index, [-1])
    sub_data_shape = sym_help._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[maxsize])
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    values = g.op("Reshape", values, values_shape)

    if accumulate:
        dtype = self.type().scalarType()
        dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
        dtype = sym_help.scalar_type_to_pytorch_type[dtype]
        zeros = g.op("ConstantOfShape", g.op("Shape", self), value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
def index_put(g, self, indices_list_value, values, accumulate=False):
    if symbolic_helper._is_packed_list(indices_list_value):
        indices_list = symbolic_helper._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if symbolic_helper.is_caffe2_aten_fallback():
        args = [self] + indices_list + [values, accumulate]
        return g.at("index_put", *args)

    accumulate = symbolic_helper._parse_arg(accumulate, "b")

    if len(indices_list) == 0:
        return values

    if len(indices_list) > 1:
        for idx_ in range(len(indices_list)):
            if indices_list[idx_].type().scalarType() == "Bool":  # type: ignore[attr-defined]
                # TODO(justinchuby): Remove type ignore after #81112 is checked in.
                indices_list[idx_] = g.op("NonZero", indices_list[idx_])
        index = indices_list[0]

        for ind in indices_list[1:]:
            index = opset9.add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            symbolic_helper._unsqueeze_helper(
                g, opset9.expand(g, ind, broadcast_index_shape, None), [-1]
            )
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace the index_put node with masked_scatter or masked_fill
        # when the inputs to the index_put node contain a single boolean input.
        #
        # index_put -> masked_fill
        #   * input index contains single tensor of Bool type (e.g.: %24 <- %23).
        #   * input value contains single element (e.g.: %18).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        #
        # index_put -> masked_scatter
        #   * input index contains single tensor of Bool type (e.g.: %32 <- %31).
        #   * input value contains multiple elements (e.g.: %28).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        index = indices_list[0]
        bool_inp = index
        if bool_inp.type() is not None and bool_inp.type().scalarType() == "Bool":  # type: ignore[attr-defined]
            # TODO(justinchuby): Remove type ignore after #81112 is checked in.
            rank = symbolic_helper._get_tensor_rank(values)
            if rank is not None and rank == 0:
                return opset9.masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = symbolic_helper._unsqueeze_helper(g, index, [-1])
    sub_data_shape = symbolic_helper._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize]
    )
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    # Check if values is a singular value and expand accordingly
    rank = symbolic_helper._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = opset9.expand(g, values, values_shape, None)
    values = symbolic_helper._reshape_helper(g, values, values_shape)

    dtype = self.type().scalarType()
    if dtype is not None and dtype != values.type().scalarType():
        values = g.op("Cast", values, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype])
    dtype = symbolic_helper.scalar_type_to_onnx.index(
        symbolic_helper.cast_pytorch_to_onnx[dtype]
    )
    dtype = symbolic_helper.scalar_type_to_pytorch_type[dtype]

    if accumulate:
        zeros = g.op(
            "ConstantOfShape",
            g.op("Shape", self),
            value_t=torch.tensor([0], dtype=dtype),
        )
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
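# Usage sketch (module name is illustrative): multi-tensor integer indexing with
# accumulate=True is lowered by the function above into Concat'ed indices plus
# the ScatterND-into-zeros / Add pattern.
import torch

class _IndexPutExample(torch.nn.Module):
    def forward(self, x, rows, cols, values):
        return x.index_put((rows, cols), values, accumulate=True)

# torch.onnx.export(
#     _IndexPutExample(),
#     (torch.zeros(3, 3), torch.tensor([0, 2]), torch.tensor([1, 1]), torch.tensor([1.0, 2.0])),
#     "index_put.onnx",
#     opset_version=11,
# )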