def binary_cross_entropy_with_logits(g, input, target, weight, pos_weight, reduction):
    """Export ``binary_cross_entropy_with_logits`` to ONNX.

    Builds -[pos_weight *] target*log(sigmoid(x)) - (1-target)*log(1-sigmoid(x)),
    optionally rescaled by ``weight``, then applies the requested reduction.

    Args:
        g: ONNX graph context.
        input: logits tensor value.
        target: target tensor value (same shape as ``input`` — assumed, not checked).
        weight: optional per-element rescaling weight (may be a None-value).
        pos_weight: optional weight for the positive term (may be a None-value).
        reduction: constant int — 0 = none, 1 = mean, 2 = sum.

    Returns:
        The ONNX value holding the (possibly reduced) loss.
    """
    from torch.onnx.symbolic_opset9 import sigmoid, log, sub, neg, mul, add
    p = g.op("Constant", value_t=torch.tensor([1]))
    sig_x = sigmoid(g, input)
    log_sig_x = log(g, sig_x)
    sub_1_x = sub(g, p, sig_x)
    sub_1_y = sub(g, p, target)
    log_1_x = log(g, sub_1_x)
    if pos_weight is None or sym_help._is_none(pos_weight):
        output = neg(
            g,
            add(g, mul(g, target, log_sig_x), mul(g, sub_1_y, log_1_x)))
    else:
        # pos_weight scales only the positive (target * log sigmoid) term.
        output = neg(
            g,
            add(g, mul(g, mul(g, target, log_sig_x), pos_weight),
                mul(g, sub_1_y, log_1_x)))
    if weight is not None and not sym_help._is_none(weight):
        output = mul(g, weight, output)
    reduction = sym_help._maybe_get_const(reduction, 'i')
    if reduction == 0:
        return output
    elif reduction == 1:
        # BUG FIX: ONNX Reduce* defaults to keepdims=1, which would yield a
        # rank-preserving [1, 1, ...] tensor instead of the scalar loss that
        # aten::binary_cross_entropy_with_logits returns for reduction=mean.
        return g.op("ReduceMean", output, keepdims_i=0)
    elif reduction == 2:
        return g.op("ReduceSum", output, keepdims_i=0)
    else:
        return sym_help._onnx_unsupported(
            "binary_cross_entropy_with_logits with reduction other than none, mean, or sum"
        )
def index_put(g, self, indices_list_value, values, accumulate=False):
    """Export ``aten::index_put`` via ONNX ``ScatterND``.

    ``indices_list_value`` is a packed list of per-dimension index tensors.
    The indices are broadcast together, stacked into an N-d index for
    ScatterND, and ``values`` is reshaped to match. With ``accumulate`` the
    scattered values are added to ``self`` instead of overwriting.

    NOTE(review): assumes at least one index tensor is present
    (``indices_list[0]`` would raise IndexError otherwise) — confirm callers.
    """
    indices_list = sym_help._unpack_list(indices_list_value)
    # ATen fallback: emit an opaque ATen op instead of decomposing.
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    index = indices_list[0]

    if len(indices_list) > 1:
        # Summing the index tensors broadcasts them to a common shape; only
        # that broadcast shape is used, not the summed values.
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        # Expand each index to the broadcast shape and stack along a new
        # trailing axis to form the ScatterND-style coordinate tensor.
        indices_list = [
            g.op("Unsqueeze", expand(g, ind, broadcast_index_shape, None), axes_i=[-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        broadcast_index_shape = g.op("Shape", index)
        index = g.op("Unsqueeze", index, axes_i=[-1])
    # Shape of the data slice that each index selects: dims of `self` beyond
    # the indexed dimensions.
    sub_data_shape = sym_help._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[maxsize])
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    values = g.op("Reshape", values, values_shape)

    if accumulate:
        # Scatter into zeros, then add to `self`, so existing elements are
        # accumulated rather than replaced.
        dtype = self.type().scalarType()
        dtype = sym_help.scalar_type_to_onnx.index(
            sym_help.cast_pytorch_to_onnx[dtype])
        dtype = sym_help.scalar_type_to_pytorch_type[dtype]
        zeros = g.op("ConstantOfShape",
                     g.op("Shape", self),
                     value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)
    return result
def add(g, x, y, op_scale, op_zero_point):
    """Export quantized add: dequantize both inputs, add in float, requantize.

    Args:
        g: ONNX graph context.
        x, y: quantized input values.
        op_scale: output quantization scale.
        op_zero_point: output quantization zero point.

    Returns:
        The requantized sum.
    """
    # BUG FIX: the original called the bare name ``add`` for the float sum,
    # which re-enters this quantized ``add`` (wrong arity, infinite
    # recursion). Delegate to the opset9 float add instead, as the other
    # quantized add symbolic in this file does.
    from torch.onnx.symbolic_opset9 import add as _float_add
    x, _, _, _ = sym_help.dequantize_helper(g, x)
    y, _, _, _ = sym_help.dequantize_helper(g, y)
    output = _float_add(g, x, y)
    return sym_help.quantize_helper(g, output, op_scale, op_zero_point)
def group_norm_symbolic(g, input, num_groups, weight, bias, eps, cudnn_enabled):
    """Export GroupNorm as ONNX ops.

    When every channel is its own group this is exactly
    InstanceNormalization; otherwise the groups are folded into the batch
    axis, normalized with MeanVarianceNormalization, and the per-channel
    affine transform is applied afterwards.
    """
    from torch.onnx.symbolic_opset9 import add, mul, reshape, reshape_as

    num_channels = input.type().sizes()[1]

    # One group per channel degenerates to instance normalization.
    if num_groups == num_channels:
        return g.op('InstanceNormalization', input, weight, bias, epsilon_f=eps)

    # Fold groups into the batch axis: [n, g * cg, h, w] -> [1, n * g, cg * h, w].
    grouped = reshape(g, input, [0, num_groups, -1, 0])
    grouped = reshape(g, grouped, [1, -1, 0, 0])
    # Zero-mean / unit-variance over the spatial axes of each pseudo-channel.
    normed = g.op('MeanVarianceNormalization', grouped, axes_i=[2, 3])
    # Restore the original layout.
    normed = reshape_as(g, normed, input)
    # Per-channel affine transform (weight * x + bias).
    scale = reshape(g, weight, [1, num_channels, 1, 1])
    shift = reshape(g, bias, [1, num_channels, 1, 1])
    return add(g, mul(g, normed, scale), shift)
def normal(g, loc, scale, seed):
    """Export ``normal`` sampling via the scale-location transform.

    If z is a sample from N(0, 1), then scale * z + loc is a sample from
    N(loc, scale**2), so a RandomNormalLike (standard normal with loc's
    shape/type) followed by multiply and add suffices.
    """
    standard_sample = g.op("RandomNormalLike", loc)
    scaled = mul(g, scale, standard_sample)
    return add(g, scaled, loc)
def full(g, sizes, value, dtype, layout, device, pin_memory=False):
    """Export ``torch.full``.

    A compile-time-constant fill value becomes a ConstantFill-style node;
    a dynamic (graph) value is realized as zeros(sizes) + value.
    """
    fill_value = symbolic_helper._maybe_get_const(value, "t")
    if not symbolic_helper._is_value(fill_value):
        # Constant fill value: emit it directly.
        dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        return _constant_fill(g, sizes, dtype, fill_value)
    # Dynamic fill value: broadcast-add it onto a zero tensor (alpha = 1).
    base = zeros(g, sizes, dtype, layout, device)
    alpha = g.op("Constant", value_t=torch.tensor(1))
    return opset9.add(g, base, value, alpha)
def add(g, x, y, op_scale, op_zero_point):
    """Export quantized add: dequantize, add in float, then requantize."""
    dequant_x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    dequant_y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
    float_sum = opset9.add(g, dequant_x, dequant_y)
    return symbolic_helper.quantize_helper(g, float_sum, op_scale, op_zero_point)
def full(g, sizes, value, dtype, layout, device, pin_memory=False):
    """Export ``torch.full``.

    Constant fill values are emitted directly; a dynamic value is realized
    as zeros(sizes) + value.
    """
    fill_value = sym_help._maybe_get_const(value, 't')
    if not sym_help._is_value(fill_value):
        # Constant fill value: emit it directly.
        dtype = sym_help._get_const(dtype, 'i', 'dtype')
        return _constant_fill(g, sizes, dtype, fill_value)
    # Dynamic fill value: broadcast-add it onto a zero tensor (alpha = 1).
    base = zeros(g, sizes, dtype, layout, device)
    alpha = g.op("Constant", value_t=torch.tensor(1))
    return sym_opset9.add(g, base, value, alpha)
def binary_cross_entropy_with_logits(g, input, target, weight, pos_weight, reduction):
    """Export ``binary_cross_entropy_with_logits``.

    Emits -[pos_weight *] target*log(sigmoid(x)) - (1-target)*log(1-sigmoid(x)),
    optionally rescaled by ``weight``, then applies the requested reduction
    (0 = none, 1 = mean, 2 = sum) with keepdims disabled so the reduced loss
    is a scalar.
    """
    one = g.op("Constant", value_t=torch.tensor([1]))
    prob = opset9.sigmoid(g, input)
    log_prob = opset9.log(g, prob)
    one_minus_prob = opset9.sub(g, one, prob)
    one_minus_target = opset9.sub(g, one, target)
    log_one_minus_prob = opset9.log(g, one_minus_prob)

    pos_term = opset9.mul(g, target, log_prob)
    # pos_weight (when provided) scales only the positive term.
    if not (pos_weight is None or symbolic_helper._is_none(pos_weight)):
        pos_term = opset9.mul(g, pos_term, pos_weight)
    neg_term = opset9.mul(g, one_minus_target, log_one_minus_prob)
    output = opset9.neg(g, opset9.add(g, pos_term, neg_term))

    if weight is not None and not symbolic_helper._is_none(weight):
        output = opset9.mul(g, weight, output)

    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    if reduction == 0:
        return output
    if reduction == 1:
        return g.op("ReduceMean", output, keepdims_i=0)
    if reduction == 2:
        return g.op("ReduceSum", output, keepdims_i=0)
    return symbolic_helper._onnx_unsupported(
        "binary_cross_entropy_with_logits with reduction other than none, mean, or sum",
        input,
    )
def add(g, self, other, alpha=None):
    """Export ``aten::add``.

    Adding two tensor lists (sequences) is lowered to repeated
    SequenceInsert; plain tensor addition delegates to the opset9 add.
    Only a statically-constructed list (prim::ListConstruct) can be
    appended to another sequence.
    """
    if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self):
        if other.node().kind() != "prim::ListConstruct":
            return symbolic_helper._unimplemented(
                "add", "does not support adding dynamic tensor list to another"
            )
        # Append each element of `other` onto the sequence `self`.
        result = self
        for tensor in symbolic_helper._unpack_list(other):
            result = g.op("SequenceInsert", result, tensor)
        return result
    return opset9.add(g, self, other, alpha)
def addcmul_symbolic(g, self, tensor1, tensor2, value=1, out=None):
    """Export ``aten::addcmul``: self + value * tensor1 * tensor2.

    The scalar multiply is skipped entirely when ``value`` is 1. The ``out``
    parameter is not supported.
    """
    from torch.onnx.symbolic_opset9 import add, mul

    if out is not None:
        sym_help._unimplemented("addcmul", "Out parameter is not supported for addcmul")

    product = mul(g, tensor1, tensor2)
    value = sym_help._maybe_get_scalar(value)
    if sym_help._scalar(value) != 1:
        # Promote the scalar to match the product's type before multiplying.
        value = sym_help._if_scalar_type_as(g, value, product)
        if not sym_help._is_value(value):
            value = g.op("Constant", value_t=torch.tensor(value, dtype=torch.float32))
        product = mul(g, product, value)
    return add(g, self, product)
def index_put(g, self, indices_list_value, values, accumulate=False):
    """Export ``aten::index_put`` via ONNX ``ScatterND``.

    Multi-tensor indices are broadcast, stacked, and scattered with
    ScatterND; a single boolean index is rewritten to masked_fill (scalar
    value) or masked_scatter (tensor value). With ``accumulate`` the
    scattered values are added to ``self`` instead of overwriting.
    """
    if sym_help._is_packed_list(indices_list_value):
        indices_list = sym_help._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    # ATen fallback: emit an opaque ATen op instead of decomposing.
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    # No indices means a full-tensor assignment.
    if len(indices_list) == 0:
        return values

    index = indices_list[0]

    if len(indices_list) > 1:
        # Summing the index tensors broadcasts them to a common shape; only
        # that broadcast shape is used, not the summed values.
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            sym_help._unsqueeze_helper(
                g, expand(g, ind, broadcast_index_shape, None), [-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace index_put node with masked_scatter or masked_fill
        # when inputs to the index_put node contains boolean inputs
        #
        # index_put -> masked_fill
        #  * input index contains single tensor of Bool type (e.g.: %24 <- %23).
        #  * input value contains single element (e.g.: %18).
        #
        # Torch IR
        # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        # %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #   aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        # %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        # %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        # %24 : Tensor?[] = prim::ListConstruct(%23)
        # %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #   aten::index_put(%mask, %24, %18, %30)
        # return (%25)
        #
        #
        # index_put -> masked_scatter
        #  * input index contains single tensor of Bool type (e.g.: %32 <- %31).
        #  * input value contains multiple elements (e.g.: %28).
        #
        # Torch IR
        # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        # %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #   = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        # %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #   = aten::ne(%mask, %some_const)
        # %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #   = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        # %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        # %30 : int[] = prim::Constant[value=[-1]]()
        # %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        # %32 : Tensor?[] = prim::ListConstruct(%31)
        # %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #   = aten::index_put(%mask, %32, %28, %38)
        # return (%33)
        bool_inp = index
        if bool_inp.type() is not None and bool_inp.type().scalarType(
        ) == 'Bool':
            rank = sym_help._get_tensor_rank(values)
            if rank is not None and rank == 0:
                from torch.onnx.symbolic_opset9 import masked_fill
                return masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = sym_help._unsqueeze_helper(g, index, [-1])
    # Shape of the data slice each index selects: dims of `self` beyond the
    # indexed dimensions.
    sub_data_shape = sym_help._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[maxsize])
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    # Check if values is a singular value and expand accordingly
    rank = sym_help._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = expand(g, values, values_shape, None)
    values = g.op("Reshape", values, values_shape)

    # Cast values to the dtype of `self` when they differ.
    # NOTE(review): the index()/lookup lines below would raise if
    # self.type().scalarType() is None — presumably callers always provide a
    # typed tensor; confirm.
    dtype = self.type().scalarType()
    if dtype is not None and dtype != values.type().scalarType():
        values = g.op("Cast", values, to_i=sym_help.cast_pytorch_to_onnx[dtype])
    dtype = sym_help.scalar_type_to_onnx.index(
        sym_help.cast_pytorch_to_onnx[dtype])
    dtype = sym_help.scalar_type_to_pytorch_type[dtype]

    if accumulate:
        # Scatter into zeros, then add to `self`, so existing elements are
        # accumulated rather than replaced.
        zeros = g.op("ConstantOfShape",
                     g.op("Shape", self),
                     value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)
    return result
def index_put(g, self, indices_list_value, values, accumulate=False):
    """Export ``aten::index_put`` via ONNX ``ScatterND``.

    Multiple advanced indices are broadcast to a common shape, stacked along
    a trailing axis, and applied with ScatterND. A single boolean mask index
    is rewritten instead: masked_fill when ``values`` is a scalar,
    masked_scatter otherwise. With ``accumulate`` the scattered values are
    added to ``self`` rather than overwriting.
    """
    indices_list = sym_help._unpack_list(indices_list_value)
    # ATen fallback: keep the op opaque instead of decomposing it.
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    index = indices_list[0]

    if len(indices_list) > 1:
        # Adding the index tensors broadcasts them to their common shape;
        # only that shape is used below, never the summed values.
        for extra_index in indices_list[1:]:
            index = add(g, index, extra_index)
        broadcast_index_shape = g.op("Shape", index)
        # Expand every index to the broadcast shape and stack along a new
        # trailing axis, producing ScatterND-style coordinates.
        indices_list = [
            sym_help._unsqueeze_helper(
                g, expand(g, idx, broadcast_index_shape, None), [-1])
            for idx in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Single index. When it is a view of a Bool mask, index_put is
        # equivalent to masked_fill (scalar value) or masked_scatter
        # (tensor value), both of which export more simply (Where /
        # NonZero + ScatterND respectively).
        mask_source = list(index.node().inputs())[0]
        if mask_source.type() is not None and mask_source.type().scalarType(
        ) == 'Bool':
            values_rank = sym_help._get_tensor_rank(values)
            if values_rank is not None and values_rank == 0:
                from torch.onnx.symbolic_opset9 import masked_fill
                return masked_fill(g, self, mask_source, values)
            return masked_scatter(g, self, mask_source, values)
        broadcast_index_shape = g.op("Shape", index)
        index = sym_help._unsqueeze_helper(g, index, [-1])

    # Trailing dims of `self` that each index selects as a whole slice.
    sub_data_shape = sym_help._slice_helper(
        g, g.op("Shape", self), axes=[0],
        starts=[len(indices_list)], ends=[maxsize])
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape,
                        axis_i=0)
    values = g.op("Reshape", values, values_shape)

    if accumulate:
        # Scatter into a zero tensor first, then add to `self`, so existing
        # elements accumulate instead of being replaced.
        scalar_name = self.type().scalarType()
        onnx_idx = sym_help.scalar_type_to_onnx.index(
            sym_help.cast_pytorch_to_onnx[scalar_name])
        torch_dtype = sym_help.scalar_type_to_pytorch_type[onnx_idx]
        zeros = g.op("ConstantOfShape",
                     g.op("Shape", self),
                     value_t=torch.tensor([0], dtype=torch_dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)
    return result
def multiclass_nms_core_symbolic(g, multi_bboxes, multi_scores, score_thr, nms_cfg, max_num=-1):
    """Export multiclass NMS via ONNX ``NonMaxSuppression``.

    Selects up to ``max_num`` boxes per (batch, class) above ``score_thr``,
    gathers the surviving boxes/scores, and returns a single tensor of
    [x1, y1, x2, y2, score, class] rows, truncated to the ``max_num``
    top-scoring entries. Only ``nms_cfg['type'] == 'nms'`` with an
    ``iou_thr`` in [0, 1] is supported, and ``max_num`` must be positive.
    """
    from torch.onnx.symbolic_opset9 import reshape, squeeze
    from torch.onnx.symbolic_opset10 import _slice

    def cast(x, dtype):
        # Cast helper using PyTorch scalar-type names ('Long', 'Float', ...).
        return g.op('Cast', x, to_i=sym_help.cast_pytorch_to_onnx[dtype])

    def get_size(x, dim):
        # Dynamic size of `x` along `dim`, as a 1-element Long tensor.
        shape = g.op('Shape', x)
        dim = _slice(g, shape, axes=[0], starts=[dim], ends=[dim + 1])
        return cast(dim, 'Long')

    nms_op_type = nms_cfg.get('type', 'nms')
    assert nms_op_type == 'nms'
    assert 'iou_thr' in nms_cfg
    iou_threshold = nms_cfg['iou_thr']
    assert 0 <= iou_threshold <= 1

    # Transpose and reshape input tensors to fit ONNX NonMaxSuppression.
    # NonMaxSuppression expects boxes [num_batches, spatial, 4] and scores
    # [num_batches, num_classes, spatial].
    multi_bboxes = reshape(g, multi_bboxes, [0, -1, 4])
    multi_bboxes = g.op('Transpose', multi_bboxes, perm_i=[1, 0, 2])

    batches_num = get_size(multi_bboxes, 0)
    spatial_num = get_size(multi_bboxes, 1)

    multi_scores = g.op('Transpose', multi_scores, perm_i=[1, 0])
    scores_shape = g.op('Concat',
                        batches_num,
                        g.op('Constant', value_t=torch.LongTensor([-1])),
                        spatial_num,
                        axis_i=0)
    multi_scores = reshape(g, multi_scores, scores_shape)
    classes_num = get_size(multi_scores, 1)

    assert max_num > 0

    # Each selected entry is a [batch_index, class_index, box_index] triple.
    indices = g.op(
        'NonMaxSuppression', multi_bboxes, multi_scores,
        g.op('Constant', value_t=torch.LongTensor([max_num])),
        g.op('Constant', value_t=torch.FloatTensor([iou_threshold])),
        g.op('Constant', value_t=torch.FloatTensor([score_thr])))

    # Flatten bboxes and scores.
    multi_bboxes_flat = reshape(g, multi_bboxes, [-1, 4])
    multi_scores_flat = reshape(g, multi_scores, [
        -1,
    ])

    # Flatten indices.
    batch_indices = _slice(g, indices, axes=[1], starts=[0], ends=[1])
    class_indices = _slice(g, indices, axes=[1], starts=[1], ends=[2])
    box_indices = _slice(g, indices, axes=[1], starts=[2], ends=[3])

    # NOTE: these locals intentionally shadow any module-level add/mul — they
    # emit raw ONNX Add/Mul followed by a cast (default: to Long).
    def add(*args, dtype='Long'):
        x = g.op('Add', args[0], args[1])
        if dtype is not None:
            x = cast(x, dtype)
        return x

    def mul(*args, dtype='Long'):
        x = g.op('Mul', args[0], args[1])
        if dtype is not None:
            x = cast(x, dtype)
        return x

    # Linearize (batch, box) and (batch, class, box) into flat offsets.
    flat_box_indices = add(mul(batch_indices, spatial_num), box_indices)
    flat_score_indices = add(
        mul(add(mul(batch_indices, classes_num), class_indices), spatial_num),
        box_indices)

    # Select bboxes.
    out_bboxes = reshape(
        g, g.op('Gather', multi_bboxes_flat, flat_box_indices, axis_i=0),
        [-1, 4])
    out_scores = reshape(
        g, g.op('Gather', multi_scores_flat, flat_score_indices, axis_i=0),
        [-1, 1])
    # Having either batch size or number of classes here equal to one is the limitation of implementation.
    class_indices = reshape(
        g, cast(add(class_indices, batch_indices), 'Float'), [-1, 1])

    # Combine bboxes, scores and labels into a single tensor.
    # This a workaround for a PyTorch bug (feature?),
    # limiting ONNX operations to output only single tensor.
    out_combined_bboxes = g.op(
        'Concat', out_bboxes, out_scores, class_indices, axis_i=1)

    # Get the top scored bboxes only — k = min(max_num, number selected).
    elements_num = sym_help._size_helper(
        g, out_scores, dim=g.op('Constant', value_t=torch.LongTensor([0])))
    max_num = g.op('Constant', value_t=torch.LongTensor([max_num]))
    if sym_help._export_onnx_opset_version < 12:
        # Min op on Long inputs is only usable from opset 12; emulate with
        # Concat + ReduceMin on older opsets.
        kn = g.op('Concat', max_num, elements_num, axis_i=0)
        kn = g.op('ReduceMin', kn, keepdims_i=0)
    else:
        kn = g.op('Min', max_num, elements_num)
    _, top_indices = sym_help._topk_helper(g, out_scores, kn, dim=0)
    # top_indices = squeeze(g, top_indices, dim=1)
    top_indices = reshape(g, top_indices, [
        -1,
    ])
    out_combined_bboxes = g.op(
        'Gather', out_combined_bboxes, top_indices, axis_i=0)

    return out_combined_bboxes
def index_put(g, self, indices_list_value, values, accumulate=False):
    """Export ``aten::index_put`` via ONNX ``ScatterND``.

    Multi-tensor indices (Bool ones first converted with NonZero) are
    broadcast, stacked, and scattered with ScatterND; a single boolean index
    is rewritten to masked_fill (scalar value) or masked_scatter (tensor
    value). With ``accumulate`` the scattered values are added to ``self``
    instead of overwriting.
    """
    if symbolic_helper._is_packed_list(indices_list_value):
        indices_list = symbolic_helper._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    # Caffe2 ATen fallback: emit an opaque ATen op instead of decomposing.
    if symbolic_helper.is_caffe2_aten_fallback():
        args = [self] + indices_list + [values, accumulate]
        return g.at("index_put", *args)

    accumulate = symbolic_helper._parse_arg(accumulate, "b")

    # No indices means a full-tensor assignment.
    if len(indices_list) == 0:
        return values

    if len(indices_list) > 1:
        # Boolean masks must become coordinate indices before broadcasting.
        for idx_ in range(len(indices_list)):
            if indices_list[idx_].type().scalarType() == "Bool":  # type: ignore[attr-defined]
                # TODO(justinchuby): Remove type ignore after #81112 is checked in.
                indices_list[idx_] = g.op("NonZero", indices_list[idx_])
        index = indices_list[0]

        # Summing the index tensors broadcasts them to a common shape; only
        # that broadcast shape is used, not the summed values.
        for ind in indices_list[1:]:
            index = opset9.add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            symbolic_helper._unsqueeze_helper(
                g, opset9.expand(g, ind, broadcast_index_shape, None), [-1]
            )
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace index_put node with masked_scatter or masked_fill
        # when inputs to the index_put node contains a single boolean input.
        #
        # index_put -> masked_fill
        #  * input index contains single tensor of Bool type (e.g.: %24 <- %23).
        #  * input value contains single element (e.g.: %18).
        #
        # Torch IR
        # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        # %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #   aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        # %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        # %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        # %24 : Tensor?[] = prim::ListConstruct(%23)
        # %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #   aten::index_put(%mask, %24, %18, %30)
        # return (%25)
        #
        #
        # index_put -> masked_scatter
        #  * input index contains single tensor of Bool type (e.g.: %32 <- %31).
        #  * input value contains multiple elements (e.g.: %28).
        #
        # Torch IR
        # %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        # %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #   = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        # %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #   = aten::ne(%mask, %some_const)
        # %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #   = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        # %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        # %30 : int[] = prim::Constant[value=[-1]]()
        # %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        # %32 : Tensor?[] = prim::ListConstruct(%31)
        # %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #   = aten::index_put(%mask, %32, %28, %38)
        # return (%33)
        index = indices_list[0]
        bool_inp = index
        if bool_inp.type() is not None and bool_inp.type().scalarType() == "Bool":  # type: ignore[attr-defined]
            # TODO(justinchuby): Remove type ignore after #81112 is checked in.
            rank = symbolic_helper._get_tensor_rank(values)
            if rank is not None and rank == 0:
                return opset9.masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = symbolic_helper._unsqueeze_helper(g, index, [-1])
    # Shape of the data slice each index selects: dims of `self` beyond the
    # indexed dimensions.
    sub_data_shape = symbolic_helper._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize]
    )
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    # Check if values is a singular value and expand accordingly
    rank = symbolic_helper._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = opset9.expand(g, values, values_shape, None)
    values = symbolic_helper._reshape_helper(g, values, values_shape)

    # Cast values to the dtype of `self` when they differ.
    # NOTE(review): the index()/lookup lines below would raise if
    # self.type().scalarType() is None — presumably callers always provide a
    # typed tensor; confirm.
    dtype = self.type().scalarType()
    if dtype is not None and dtype != values.type().scalarType():
        values = g.op("Cast", values, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype])
    dtype = symbolic_helper.scalar_type_to_onnx.index(
        symbolic_helper.cast_pytorch_to_onnx[dtype]
    )
    dtype = symbolic_helper.scalar_type_to_pytorch_type[dtype]

    if accumulate:
        # Scatter into zeros, then add to `self`, so existing elements are
        # accumulated rather than replaced. `add` here is presumably this
        # module's sequence-aware add, which delegates to opset9.add for
        # plain tensors — confirm against module scope.
        zeros = g.op(
            "ConstantOfShape",
            g.op("Shape", self),
            value_t=torch.tensor([0], dtype=dtype),
        )
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)
    return result