def index(g, self, index):
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", self, index, operator_s="index")

    if sym_help._is_packed_list(index):
        indices = sym_help._unpack_list(index)
    else:
        indices = [index]

    # Handle single mask index.
    if len(indices) == 1:
        index = indices[0]
        if not sym_help._is_none(index) and (index.type().scalarType() == "Bool" or
                                             index.type().scalarType() == "Byte"):
            from torch.onnx.symbolic_opset9 import nonzero
            index = nonzero(g, index)
            return g.op("GatherND", self, index)
    from torch.onnx.symbolic_opset9 import index as index_opset9
    return index_opset9(g, self, index)

def index(g, self, index):
    if sym_help.is_caffe2_aten_fallback():
        return g.at("index", self, index, overload_name="Tensor")

    if sym_help._is_packed_list(index):
        indices = sym_help._unpack_list(index)
    else:
        indices = [index]

    # Handle single mask index.
    if len(indices) == 1:
        index = indices[0]
        if not sym_help._is_none(index) and (index.type().scalarType() == "Bool" or
                                             index.type().scalarType() == "Byte"):
            from torch.onnx.symbolic_opset9 import nonzero
            index = nonzero(g, index)
            return g.op("GatherND", self, index)
    from torch.onnx.symbolic_opset9 import index as index_opset9
    return index_opset9(g, self, index)

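# --- Illustrative usage (added; not part of the original source) ---
# A minimal sketch, assuming a torch build with opset 11 support, of a model
# whose single Bool index reaches the nonzero + GatherND branch above.
# `MaskIndex` and `_demo_bool_mask_index` are hypothetical names.
def _demo_bool_mask_index():
    import io

    import torch

    class MaskIndex(torch.nn.Module):
        def forward(self, x, mask):
            # A single boolean mask index; the exporter rewrites it to
            # onnx::NonZero feeding onnx::GatherND.
            return x[mask]

    x = torch.randn(3, 4)
    buf = io.BytesIO()
    torch.onnx.export(MaskIndex(), (x, x > 0), buf, opset_version=11)
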
def index_put(g, self, indices_list_value, values, accumulate=False):
    indices_list = sym_help._unpack_list(indices_list_value)
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    index = indices_list[0]

    if len(indices_list) > 1:
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            g.op("Unsqueeze", expand(g, ind, broadcast_index_shape, None), axes_i=[-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        broadcast_index_shape = g.op("Shape", index)
        index = g.op("Unsqueeze", index, axes_i=[-1])
    sub_data_shape = sym_help._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[maxsize])
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    values = g.op("Reshape", values, values_shape)

    if accumulate:
        dtype = self.type().scalarType()
        dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
        dtype = sym_help.scalar_type_to_pytorch_type[dtype]
        zeros = g.op("ConstantOfShape", g.op("Shape", self),
                     value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result

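# --- Illustrative check (added; not part of the original source) ---
# A hedged eager-mode check of the accumulate trick used above: for unique
# indices, scattering `values` into a zeros tensor and then adding `self`
# reproduces index_put with accumulate=True, which is exactly the
# ScatterND-on-zeros + Add lowering in the `accumulate` branch.
def _demo_accumulate_trick():
    import torch

    x = torch.arange(5.0)
    idx = (torch.tensor([1, 3]),)
    vals = torch.tensor([10.0, 30.0])

    accumulated = x.clone().index_put_(idx, vals, accumulate=True)
    scatter_then_add = x + torch.zeros_like(x).index_put_(idx, vals)
    assert torch.equal(accumulated, scatter_then_add)
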
def index_put(g, self, indices_list_value, values, accumulate=False):
    if sym_help._is_packed_list(indices_list_value):
        indices_list = sym_help._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    if len(indices_list) == 0:
        return values

    index = indices_list[0]

    if len(indices_list) > 1:
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            sym_help._unsqueeze_helper(g, expand(g, ind, broadcast_index_shape, None), [-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace the index_put node with masked_scatter or masked_fill
        # when the inputs to the index_put node contain boolean inputs.
        #
        # index_put -> masked_fill
        #   * input index contains a single tensor of Bool type (e.g.: %24 <- %23).
        #   * input value contains a single element (e.g.: %18).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        #
        # index_put -> masked_scatter
        #   * input index contains a single tensor of Bool type (e.g.: %32 <- %31).
        #   * input value contains multiple elements (e.g.: %28).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        bool_inp = index
        if bool_inp.type() is not None and bool_inp.type().scalarType() == 'Bool':
            rank = sym_help._get_tensor_rank(values)
            if rank is not None and rank == 0:
                from torch.onnx.symbolic_opset9 import masked_fill
                return masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = sym_help._unsqueeze_helper(g, index, [-1])
    sub_data_shape = sym_help._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[maxsize])
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    # Check if values is a singular value and expand accordingly.
    rank = sym_help._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = expand(g, values, values_shape, None)
    values = g.op("Reshape", values, values_shape)

    dtype = self.type().scalarType()
    if dtype is not None and dtype != values.type().scalarType():
        values = g.op("Cast", values, to_i=sym_help.cast_pytorch_to_onnx[dtype])
    dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
    dtype = sym_help.scalar_type_to_pytorch_type[dtype]

    if accumulate:
        zeros = g.op("ConstantOfShape", g.op("Shape", self),
                     value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result

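# --- Illustrative usage (added; not part of the original source) ---
# A hedged sketch of the two eager-mode patterns that hit the Bool branch
# above: a scalar value maps to the masked_fill rewrite, a multi-element
# value to the masked_scatter rewrite. Names are illustrative.
def _demo_bool_index_put_variants():
    import torch

    mask_src = torch.randn(2, 2, 2)
    mask = mask_src > 0

    # Single-element value -> masked_fill path (onnx::Where after export).
    filled = mask_src.clone()
    filled[mask] = torch.tensor(1.0)

    # Multi-element value -> masked_scatter path (NonZero + ScatterND).
    scattered = mask_src.clone()
    scattered[mask] = torch.ones(int(mask.sum()))
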
def einsum(g, equation, tensor_list):
    tensors = sym_help._unpack_list(tensor_list)
    num_ops = len(tensors)
    assert num_ops > 0

    # Doesn't support implicit output (no "->" in the equation) or more than 2 operands for now.
    # Doesn't support ellipsis ("...") for now, as it's not easy to get the sizes of the operands.
    if num_ops != 2 or equation.find("->") == -1 or "." in equation:
        return g.op("Einsum", *tensors, equation_s=equation)

    # Take "ks,ksm->sm" as example. After processing the inputs,
    # lhs_labels = [k,s], rhs_labels = [k,s,m], result_labels = [s,m].
    lhs_labels, rhs_labels, result_labels = parse_equation(equation)

    # Doesn't support repeated labels in an operand for now, as that requires taking an extra diagonal.
    if len(lhs_labels) != len(set(lhs_labels)) or len(rhs_labels) != len(set(rhs_labels)):
        return g.op("Einsum", *tensors, equation_s=equation)

    # Add contraction labels (labels not present in the output).
    # After processing contraction labels, contraction_labels = [k],
    # label_perm_map = {s: 0, m: 1, k: 2}, out_size = 2, perm_size = 3.
    out_size = len(result_labels)
    label_perm_map = dict([(label, idx) for idx, label in enumerate(result_labels)])
    perm_size = out_size
    contraction_labels = []
    lhs_reduce_sum_axes = []
    rhs_reduce_sum_axes = []
    for label in lhs_labels + rhs_labels:
        if label not in label_perm_map:
            if label in lhs_labels and label in rhs_labels:
                label_perm_map[label] = perm_size
                contraction_labels.append(label)
                perm_size += 1
            elif label in lhs_labels:
                lhs_reduce_sum_axes.append(lhs_labels.index(label))
            else:
                rhs_reduce_sum_axes.append(rhs_labels.index(label))

    lhs_tensor = tensors[0]
    rhs_tensor = tensors[1]

    # If lhs_reduce_sum_axes/rhs_reduce_sum_axes is not empty, ReduceSum on those axes,
    # update lhs_labels/rhs_labels, and use the result as the new lhs_tensor/rhs_tensor.
    if lhs_reduce_sum_axes:
        lhs_tensor = sym_help._reducesum_helper(g, lhs_tensor, lhs_reduce_sum_axes, keepdims_i=False)
        lhs_labels = [lhs_labels[axis] for axis in range(len(lhs_labels))
                      if axis not in lhs_reduce_sum_axes]

    if rhs_reduce_sum_axes:
        rhs_tensor = sym_help._reducesum_helper(g, rhs_tensor, rhs_reduce_sum_axes, keepdims_i=False)
        rhs_labels = [rhs_labels[axis] for axis in range(len(rhs_labels))
                      if axis not in rhs_reduce_sum_axes]

    # Need to unsqueeze and permute the inputs to the order of the output with contraction labels.
    # lhs_perm = [1,2,0], lhs_unsqueeze_axes = [2].
    # rhs_perm = [1,2,0], rhs_unsqueeze_axes = [].
    lhs_perm, lhs_unsqueeze_axes = map_labels_to_output(lhs_labels, label_perm_map)
    rhs_perm, rhs_unsqueeze_axes = map_labels_to_output(rhs_labels, label_perm_map)

    # If there are no contraction labels, unsqueeze and permute the inputs
    # and Mul them to get the final result.
    if not contraction_labels:
        lhs_tensor = unsqueeze_and_permute_for_mul(g, lhs_tensor, lhs_unsqueeze_axes, lhs_perm)
        rhs_tensor = unsqueeze_and_permute_for_mul(g, rhs_tensor, rhs_unsqueeze_axes, rhs_perm)
        return g.op("Mul", lhs_tensor, rhs_tensor)

    # If contraction_labels is not empty, a batched MatMul is needed.
    # Batched labels are those in all inputs and the output. The axes below are based on the output.
    # batched_labels = [s], batched_axes = [0] for the example.
    # MatMul output labels are those in one of the inputs and the output.
    # matmul_output_labels = [m], matmul_output_axes = [1] for the example.
    # contraction_labels = [k], contraction_axes = [2] for the example.
    batched_axes = []
    matmul_output_axes = []
    contraction_axes = [axis for axis in range(out_size, perm_size)]
    for axis in range(out_size):
        label = result_labels[axis]
        if label in lhs_labels and label in rhs_labels:
            batched_axes.append(axis)
        else:
            matmul_output_axes.append(axis)

    # Based on the unsqueeze and permute on the inputs above, we need to permute again.
    # For the lhs input, the new permute is batched_axes + matmul_output_axes + contraction_axes: [0, 1, 2],
    # i.e., a.unsqueeze([2]).permute([1,2,0]).permute([0,1,2]) = [s,1,k] for the example.
    # For the rhs input, the new permute is batched_axes + contraction_axes + matmul_output_axes: [0, 2, 1],
    # i.e., b.unsqueeze([]).permute([1,2,0]).permute([0,2,1]) = [s,k,m] for the example.
    lhs_perm = combine_unsqueeze_and_permute_for_matmul(
        lhs_unsqueeze_axes, lhs_perm, batched_axes + matmul_output_axes + contraction_axes)
    rhs_perm = combine_unsqueeze_and_permute_for_matmul(
        rhs_unsqueeze_axes, rhs_perm, batched_axes + contraction_axes + matmul_output_axes)

    # Need to Reshape the two input tensors before the batched MatMul and Reshape the result
    # to the output shape.
    # Reshape the lhs input to [[batched_shapes], Mul(lhs_matmul_output_shapes), Mul(contraction_shapes)].
    # Reshape the rhs input to [[batched_shapes], Mul(contraction_shapes), Mul(rhs_matmul_output_shapes)].
    # Convert all axes based on the inputs.
    # lhs_contraction_axes = [0], rhs_contraction_axes = [0],
    # lhs_matmul_output_axes = [], rhs_matmul_output_axes = [2] for the example.
    lhs_contraction_axes = [lhs_labels.index(label) for label in contraction_labels]
    rhs_contraction_axes = [rhs_labels.index(label) for label in contraction_labels]
    lhs_matmul_output_axes = [
        lhs_labels.index(result_labels[axis]) for axis in matmul_output_axes
        if result_labels[axis] in lhs_labels
    ]
    rhs_matmul_output_axes = [
        rhs_labels.index(result_labels[axis]) for axis in matmul_output_axes
        if result_labels[axis] in rhs_labels
    ]

    # Caches of input shape tensors, to avoid generating duplicated graph nodes.
    lhs_shape_tensor = None
    rhs_shape_tensor = None

    # contraction_numel_tensor should be tensor([size(k)]) for the example,
    # but since its length is 1, it's None here.
    contraction_numel_tensor = None
    if len(lhs_contraction_axes) > 1:
        _, contraction_numel_tensor, lhs_shape_tensor = get_shape_tensor_by_axes(
            g, lhs_tensor, lhs_shape_tensor, lhs_contraction_axes, True)

    # Prepare some shape tensors for Reshape if needed.
    # Both lhs_matmul_output_shape_tensor and lhs_matmul_output_numel_tensor are None for the example.
    lhs_matmul_output_shape_tensor = None
    lhs_matmul_output_numel_tensor = None
    if len(lhs_matmul_output_axes) > 1:
        lhs_matmul_output_shape_tensor, lhs_matmul_output_numel_tensor, lhs_shape_tensor = get_shape_tensor_by_axes(
            g, lhs_tensor, lhs_shape_tensor, lhs_matmul_output_axes, True)

    # Both rhs_matmul_output_shape_tensor and rhs_matmul_output_numel_tensor are None for the example.
    rhs_matmul_output_shape_tensor = None
    rhs_matmul_output_numel_tensor = None
    if len(rhs_matmul_output_axes) > 1:
        rhs_matmul_output_shape_tensor, rhs_matmul_output_numel_tensor, rhs_shape_tensor = get_shape_tensor_by_axes(
            g, rhs_tensor, rhs_shape_tensor, rhs_matmul_output_axes, True)

    new_lhs_tensor = lhs_tensor

    # Need to Reshape lhs_tensor if lhs_matmul_output_axes or lhs_contraction_axes is not 1,
    # otherwise permute it directly.
    # Need to Reshape the lhs_tensor for the example; the new shape is [size(s), 1, size(k)].
    if len(lhs_matmul_output_axes) != 1 or len(lhs_contraction_axes) != 1:
        new_lhs_tensor, lhs_shape_tensor = permute_and_reshape_tensor(
            g,
            lhs_tensor,
            True,
            len(lhs_labels),
            lhs_perm,
            lhs_matmul_output_axes,
            lhs_contraction_axes,
            len(batched_axes),
            lhs_matmul_output_numel_tensor,
            contraction_numel_tensor,
            lhs_shape_tensor,
        )
    else:
        if need_permute(lhs_perm):
            new_lhs_tensor = g.op("Transpose", lhs_tensor, perm_i=lhs_perm)

    # Need to Reshape rhs_tensor if rhs_matmul_output_axes or rhs_contraction_axes is not 1,
    # otherwise permute it directly.
    # rhs_tensor's new shape should be [size(s), size(k), size(m)],
    # but it doesn't need a Reshape for the example.
    new_rhs_tensor = rhs_tensor
    if len(rhs_matmul_output_axes) != 1 or len(rhs_contraction_axes) != 1:
        new_rhs_tensor, rhs_shape_tensor = permute_and_reshape_tensor(
            g,
            rhs_tensor,
            False,
            len(rhs_labels),
            rhs_perm,
            rhs_matmul_output_axes,
            rhs_contraction_axes,
            len(batched_axes),
            rhs_matmul_output_numel_tensor,
            contraction_numel_tensor,
            rhs_shape_tensor,
        )
    else:
        if need_permute(rhs_perm):
            new_rhs_tensor = g.op("Transpose", rhs_tensor, perm_i=rhs_perm)

    # Perform the final batched MatMul. The output has shape [size(s), 1, size(m)] for the example.
    result = g.op("MatMul", new_lhs_tensor, new_rhs_tensor)

    # Need to Reshape the result if lhs_matmul_output_axes or rhs_matmul_output_axes is not 1.
    # Need to Reshape the result for the example; the new shape is [size(s), size(m)].
    if len(lhs_matmul_output_axes) != 1 or len(rhs_matmul_output_axes) != 1:
        shape_tensors = [
            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))
        ] * len(batched_axes)
        last_zero_dim = len(shape_tensors) - 1
        has_neg_one_dim = False
        if lhs_matmul_output_axes:
            if len(lhs_matmul_output_axes) == 1:
                shape_tensors.append(g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)))
                last_zero_dim = len(shape_tensors) - 1
            else:
                shape_tensors.append(lhs_matmul_output_shape_tensor)
        if rhs_matmul_output_axes:
            if len(rhs_matmul_output_axes) == 1:
                shape_tensors.append(g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
                has_neg_one_dim = True
            else:
                shape_tensors.append(rhs_matmul_output_shape_tensor)
        if not has_neg_one_dim and last_zero_dim >= 0:
            shape_tensors[last_zero_dim] = g.op(
                "Constant", value_t=torch.tensor([-1], dtype=torch.int64))
        result = reshape_tensor(g, result, shape_tensors)

    # Now the output axes are ordered as [batched_axes, lhs_matmul_output_axes, rhs_matmul_output_axes];
    # if this is not the same as the output, one more permute is needed.
    labels = ([result_labels[axis] for axis in batched_axes] +
              [lhs_labels[axis] for axis in lhs_matmul_output_axes] +
              [rhs_labels[axis] for axis in rhs_matmul_output_axes])
    assert len(labels) == out_size
    output_perm = [labels.index(label) for label in result_labels]
    assert all(axis in output_perm for axis in range(out_size))
    if need_permute(output_perm):
        result = g.op("Transpose", result, perm_i=output_perm)

    return result

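# --- Illustrative numeric check (added; not part of the original source) ---
# A hedged eager-mode verification of the "ks,ksm->sm" walkthrough in the
# comments above: permuting lhs to [s, 1, k] and rhs to [s, k, m], then doing
# a batched MatMul over s (contracting k), matches torch.einsum.
def _demo_einsum_decomposition():
    import torch

    a = torch.randn(4, 3)     # labels [k, s]
    b = torch.randn(4, 3, 5)  # labels [k, s, m]

    expected = torch.einsum("ks,ksm->sm", a, b)

    lhs = a.permute(1, 0).unsqueeze(1)          # [s, 1, k]
    rhs = b.permute(1, 0, 2)                    # [s, k, m]
    manual = torch.matmul(lhs, rhs).squeeze(1)  # [s, m]
    assert torch.allclose(expected, manual, atol=1e-5)
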
def einsum(g, equation, tensor_list):
    tensors = sym_help._unpack_list(tensor_list)
    return g.op("Einsum", *tensors, equation_s=equation)

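# --- Illustrative usage (added; not part of the original source) ---
# A minimal sketch, assuming opset >= 12 (where ONNX gained a native Einsum
# op), showing the direct mapping above: the equation string is forwarded
# unchanged to onnx::Einsum.
def _demo_einsum_export():
    import io

    import torch

    class BatchedMatMul(torch.nn.Module):
        def forward(self, a, b):
            return torch.einsum("bij,bjk->bik", a, b)

    buf = io.BytesIO()
    torch.onnx.export(
        BatchedMatMul(),
        (torch.randn(2, 3, 4), torch.randn(2, 4, 5)),
        buf,
        opset_version=12,
    )
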
def index_put(g, self, indices_list_value, values, accumulate=False):
    indices_list = sym_help._unpack_list(indices_list_value)
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    index = indices_list[0]

    if len(indices_list) > 1:
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            sym_help._unsqueeze_helper(g, expand(g, ind, broadcast_index_shape, None), [-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace the index_put node with masked_scatter or masked_fill
        # when the inputs to the index_put node contain boolean inputs.
        #
        # index_put -> masked_fill
        #
        # before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %6 : None = prim::Constant()
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %8 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::ne(%mask, %some_const)
        #   %26 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]()
        #   %27 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %11 : Device = prim::Constant[value="cpu"]()
        #   %12 : None = prim::Constant()
        #   %28 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %29 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %15 : None = prim::Constant()
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %30 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %22 : int[] = prim::Constant[value=[-1]]()
        #   %23 : Tensor = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %3 : Tensor = onnx::Equal(%0, %some_const)
        #   %4 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%3)
        #   %12 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%4)
        #   %19 : Tensor = onnx::Cast[to=9](%12)
        #   %20 : Tensor = onnx::Constant[value={1}]()
        #   %21 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = onnx::Where(%19, %20, %0)
        #   return (%21)
        #
        # index_put -> masked_scatter
        #
        # before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %6 : None = prim::Constant()
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %34 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]()
        #   %35 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %18 : Device = prim::Constant[value="cpu"]()
        #   %19 : None = prim::Constant()
        #   %36 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %37 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %22 : None = prim::Constant()
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Tensor = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        #
        # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %3 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = onnx::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %4 : Tensor = onnx::Equal(%0, %some_const)
        #   %5 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%4)
        #   %13 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%5)
        #   %19 : Tensor = onnx::Shape(%0)
        #   %20 : Tensor = onnx::Expand(%13, %19)
        #   %21 : Tensor = onnx::NonZero(%20)
        #   %22 : Tensor = onnx::Transpose[perm=[1, 0]](%21)
        #   %23 : Tensor = onnx::Constant[value={-1}]()
        #   %24 : Tensor = onnx::Reshape(%3, %23)
        #   %25 : Tensor = onnx::Shape(%22)
        #   %27 : Tensor = onnx::Constant[value={0}]()
        #   %28 : Tensor = onnx::Gather[axis=0](%25, %27)
        #   %29 : Tensor = onnx::Constant[value={0}]()
        #   %30 : Tensor = onnx::Unsqueeze[axes=[0]](%29)
        #   %31 : Tensor = onnx::Unsqueeze[axes=[0]](%28)
        #   %32 : Tensor = onnx::Constant[value={0}]()
        #   %33 : Tensor = onnx::Unsqueeze[axes=[0]](%32)
        #   %34 : Tensor = onnx::Slice(%24, %30, %31, %33)
        #   %35 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = onnx::ScatterND(%0, %22, %34)
        #   return (%35)
        bool_inp = list(index.node().inputs())[0]
        if bool_inp.type() is not None and bool_inp.type().scalarType() == 'Bool':
            rank = sym_help._get_tensor_rank(values)
            if rank is not None and rank == 0:
                from torch.onnx.symbolic_opset9 import masked_fill
                return masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = sym_help._unsqueeze_helper(g, index, [-1])
    sub_data_shape = sym_help._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[maxsize])
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    values = g.op("Reshape", values, values_shape)

    if accumulate:
        dtype = self.type().scalarType()
        dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
        dtype = sym_help.scalar_type_to_pytorch_type[dtype]
        zeros = g.op("ConstantOfShape", g.op("Shape", self),
                     value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result

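# --- Illustrative usage (added; not part of the original source) ---
# A hedged end-to-end sketch of the "before graph" pattern in the comments
# above: a Bool index feeding aten::index_put with a scalar value, which the
# branch above rewrites to masked_fill (onnx::Where). Names are illustrative.
def _demo_index_put_masked_fill_export():
    import io

    import torch

    class FillWhereNotEqual(torch.nn.Module):
        def forward(self, x, some_const):
            y = x.clone()
            y[y != some_const] = 1.0  # traced as aten::index_put
            return y

    buf = io.BytesIO()
    torch.onnx.export(
        FillWhereNotEqual(),
        (torch.randn(2, 2, 2), torch.tensor(0.5)),
        buf,
        opset_version=11,
    )
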
def einsum(g, equation, tensor_list):
    tensors = sym_help._unpack_list(tensor_list)
    return einsum_helper(g, equation, tensors)

def index_put(g, self, indices_list_value, values, accumulate=False):
    if symbolic_helper._is_packed_list(indices_list_value):
        indices_list = symbolic_helper._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if symbolic_helper.is_caffe2_aten_fallback():
        args = [self] + indices_list + [values, accumulate]
        return g.at("index_put", *args)

    accumulate = symbolic_helper._parse_arg(accumulate, "b")

    if len(indices_list) == 0:
        return values

    if len(indices_list) > 1:
        for idx_ in range(len(indices_list)):
            if indices_list[idx_].type().scalarType() == "Bool":  # type: ignore[attr-defined]
                # TODO(justinchuby): Remove type ignore after #81112 is checked in.
                indices_list[idx_] = g.op("NonZero", indices_list[idx_])
        index = indices_list[0]

        for ind in indices_list[1:]:
            index = opset9.add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            symbolic_helper._unsqueeze_helper(
                g, opset9.expand(g, ind, broadcast_index_shape, None), [-1]
            )
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace the index_put node with masked_scatter or masked_fill
        # when the inputs to the index_put node contain a single boolean input.
        #
        # index_put -> masked_fill
        #   * input index contains a single tensor of Bool type (e.g.: %24 <- %23).
        #   * input value contains a single element (e.g.: %18).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        #
        # index_put -> masked_scatter
        #   * input index contains a single tensor of Bool type (e.g.: %32 <- %31).
        #   * input value contains multiple elements (e.g.: %28).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        index = indices_list[0]
        bool_inp = index
        if bool_inp.type() is not None and bool_inp.type().scalarType() == "Bool":  # type: ignore[attr-defined]
            # TODO(justinchuby): Remove type ignore after #81112 is checked in.
            rank = symbolic_helper._get_tensor_rank(values)
            if rank is not None and rank == 0:
                return opset9.masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = symbolic_helper._unsqueeze_helper(g, index, [-1])
    sub_data_shape = symbolic_helper._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize]
    )
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    # Check if values is a singular value and expand accordingly
    rank = symbolic_helper._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = opset9.expand(g, values, values_shape, None)
    values = symbolic_helper._reshape_helper(g, values, values_shape)

    dtype = self.type().scalarType()
    if dtype is not None and dtype != values.type().scalarType():
        values = g.op("Cast", values, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype])
    dtype = symbolic_helper.scalar_type_to_onnx.index(
        symbolic_helper.cast_pytorch_to_onnx[dtype]
    )
    dtype = symbolic_helper.scalar_type_to_pytorch_type[dtype]

    if accumulate:
        zeros = g.op(
            "ConstantOfShape",
            g.op("Shape", self),
            value_t=torch.tensor([0], dtype=dtype),
        )
        result = g.op("ScatterND", zeros, index, values)
        result = opset9.add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result

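# --- Illustrative check (added; not part of the original source) ---
# A hedged eager-mode sketch of the multi-index (len(indices_list) > 1)
# branch above: each index is broadcast, unsqueezed, and concatenated into
# ScatterND coordinates, so y[rows[i], cols[i]] = vals[i].
def _demo_multi_index_put():
    import torch

    x = torch.zeros(3, 4)
    rows = torch.tensor([0, 2])
    cols = torch.tensor([1, 3])
    vals = torch.tensor([5.0, 7.0])

    y = x.index_put((rows, cols), vals)
    assert y[0, 1] == 5.0 and y[2, 3] == 7.0
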
def einsum(g, equation, tensor_list):
    tensors = symbolic_helper._unpack_list(tensor_list)
    return _einsum_helper(g, equation, tensors)