def masked_scatter(g, self, mask, source):
    """Build the ONNX graph for ``masked_scatter`` using ``ScatterND``.

    Args:
        g: Graph context; only its ``op`` method is used directly here.
        self: Graph value for the tensor being written into.
        mask: Graph value for the mask; it is expanded to match ``self``
            before its nonzero positions are computed.
        source: Graph value supplying the replacement elements.

    Returns:
        The value produced by the emitted ``ScatterND`` node.
    """
    from torch.onnx.symbolic_opset9 import expand_as, nonzero, size, view

    # Positions to overwrite: nonzero entries of the mask expanded to `self`.
    broadcast_mask = expand_as(g, mask, self)
    index = nonzero(g, broadcast_mask)

    # `source` may carry more elements than are needed and may have an
    # arbitrary shape, neither of which ONNX::ScatterND accepts. Flatten it
    # and keep only as many leading elements as there are index rows.
    flat_source = view(g, source, torch.LongTensor([-1]))
    index_rows = size(g, index, torch.LongTensor([0]))
    trimmed_source = sym_help._slice_helper(
        g,
        flat_source,
        axes=torch.LongTensor([0]),
        starts=torch.LongTensor([0]),
        ends=index_rows,
        dynamic_slice=True,
    )
    return g.op("ScatterND", self, index, trimmed_source)
def scatter(g, self, dim, index, src):
    """Build the ONNX graph for ``scatter`` using ``ScatterElements``.

    Args:
        g: Graph context; only its ``op`` method is used directly here.
        self: Graph value for the tensor being written into.
        dim: Axis to scatter along; forwarded as the ``axis_i`` attribute.
        index: Graph value holding the scatter indices.
        src: Update values, either a tensor value or something that
            ``sym_help._maybe_get_scalar`` reduces to a scalar.

    Returns:
        The value of the emitted ``ATen`` node when the export type is
        ``ONNX_ATEN_FALLBACK``, otherwise the value of the emitted
        ``ScatterElements`` node.
    """
    from torch.onnx.symbolic_opset9 import expand_as

    aten_fallback = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    if sym_help._operator_export_type == aten_fallback:
        return g.op("ATen", self, dim, index, src, operator_s="scatter")

    # Record the type before `src` is possibly unwrapped into a scalar.
    declared_src_type = src.type().scalarType()
    updates = sym_help._maybe_get_scalar(src)

    # Tensor-valued updates go straight into ScatterElements.
    if sym_help._is_value(updates):
        return g.op("ScatterElements", self, index, updates, axis_i=dim)

    # Scalar updates: PyTorch allows a scalar `src` whose type differs from
    # `self` (not allowed when `src` is a tensor), so insert a Cast node when
    # the types disagree.
    self_type = self.type().scalarType()
    if self_type != declared_src_type:
        updates = g.op(
            "Cast", updates, to_i=sym_help.cast_pytorch_to_onnx[self_type]
        )

    # Expand the scalar to the shape of `index` before scattering.
    expanded_updates = expand_as(g, updates, index)
    return g.op("ScatterElements", self, index, expanded_updates, axis_i=dim)
def masked_scatter(g, self, mask, source):
    """Build the ONNX graph for ``masked_scatter`` using ``ScatterND``.

    Args:
        g: Graph context; only its ``op`` method is used directly here.
        self: Graph value for the tensor being written into.
        mask: Graph value for the mask; it is expanded to match ``self``
            before its nonzero positions are computed.
        source: Graph value supplying the replacement elements.

    Returns:
        The value produced by the emitted ``ScatterND`` node.
    """
    # Positions to overwrite: nonzero entries of the mask expanded to `self`.
    expanded_mask = opset9.expand_as(g, mask, self)
    index = opset9.nonzero(g, expanded_mask)

    # `source` may carry more elements than are needed and may have an
    # arbitrary shape, neither of which ONNX::ScatterND accepts. Flatten it
    # and keep only as many leading elements as there are index rows.
    flattened = symbolic_helper._reshape_helper(g, source, torch.LongTensor([-1]))
    row_count = opset9.size(g, index, torch.LongTensor([0]))
    updates = symbolic_helper._slice_helper(
        g,
        flattened,
        axes=torch.LongTensor([0]),
        starts=torch.LongTensor([0]),
        ends=row_count,
        dynamic_slice=True,
    )
    return g.op("ScatterND", self, index, updates)
def scatter(g, self, dim, index, src):
    """Build the ONNX graph for ``scatter`` using ``ScatterElements``.

    Args:
        g: Graph context; its ``op`` and ``at`` methods are used here.
        self: Graph value for the tensor being written into.
        dim: Axis to scatter along; forwarded as the ``axis_i`` attribute.
        index: Graph value holding the scatter indices.
        src: Update values, either a tensor value or something that
            ``symbolic_helper._maybe_get_scalar`` reduces to a scalar.

    Returns:
        The result of ``g.at("scatter", ...)`` when
        ``symbolic_helper.is_caffe2_aten_fallback()`` is true, otherwise the
        value of the emitted ``ScatterElements`` node.
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("scatter", self, dim, index, src, overload_name="src")

    # Record the type before `src` is possibly unwrapped into a scalar.
    incoming_type = src.type().scalarType()
    payload = symbolic_helper._maybe_get_scalar(src)

    # Tensor-valued updates go straight into ScatterElements.
    if symbolic_helper._is_value(payload):
        return g.op("ScatterElements", self, index, payload, axis_i=dim)

    # Scalar updates: PyTorch allows a scalar "src" whose type differs from
    # self (not allowed when src is a tensor), so insert a Cast node when the
    # types disagree.
    target_type = self.type().scalarType()
    if target_type != incoming_type:
        onnx_type = symbolic_helper.cast_pytorch_to_onnx[target_type]
        payload = g.op("Cast", payload, to_i=onnx_type)

    # Expand the scalar to the shape of `index` before scattering.
    broadcast_payload = opset9.expand_as(g, payload, index)
    return g.op("ScatterElements", self, index, broadcast_payload, axis_i=dim)
def masked_select(g, self, mask):
    """Build the ONNX graph for ``masked_select`` using ``GatherND``.

    Args:
        g: Graph context; only its ``op`` method is used directly here.
        self: Graph value for the tensor to select from.
        mask: Graph value for the mask; it is expanded to match ``self``
            before its nonzero positions are computed.

    Returns:
        The value produced by the emitted ``GatherND`` node.
    """
    from torch.onnx.symbolic_opset9 import expand_as, nonzero

    # Expand the mask to `self`, then gather at its nonzero positions.
    broadcast_mask = expand_as(g, mask, self)
    selected_positions = nonzero(g, broadcast_mask)
    return g.op("GatherND", self, selected_positions)
def masked_select(g, self, mask):
    """Build the ONNX graph for ``masked_select`` using ``GatherND``.

    Args:
        g: Graph context; only its ``op`` method is used directly here.
        self: Graph value for the tensor to select from.
        mask: Graph value for the mask; it is expanded to match ``self``
            before its nonzero positions are computed.

    Returns:
        The value produced by the emitted ``GatherND`` node.
    """
    # Expand the mask to `self`, then gather at its nonzero positions.
    expanded_mask = opset9.expand_as(g, mask, self)
    gather_index = opset9.nonzero(g, expanded_mask)
    return g.op("GatherND", self, gather_index)