コード例 #1
def index(g, self, index):
    """Export ``aten::index``.

    Under the ATen fallback export type the op is emitted as an ``ATen`` node.
    A single Bool/Byte mask index is lowered to opset 9 ``nonzero`` followed by
    ``GatherND``; every other case is delegated to the opset 9 implementation.

    Args:
        g: graph being built.
        self: tensor being indexed.
        index: a packed list of index values, or a single index value.

    Returns:
        The output value of the emitted node(s).
    """
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen", self, index, operator_s="index")

    # Normalize to a Python list. Without the ``else`` the unpacked list would
    # always be overwritten by ``[index]``.
    if sym_help._is_packed_list(index):
        indices = sym_help._unpack_list(index)
    else:
        indices = [index]

    # Handle single mask index.
    if len(indices) == 1:
        index = indices[0]
        if not sym_help._is_none(index) and index.type().scalarType() in ("Bool", "Byte"):
            from torch.onnx.symbolic_opset9 import nonzero
            # Turn the mask into explicit coordinates, then gather them.
            index = nonzero(g, index)
            return g.op("GatherND", self, index)
    from torch.onnx.symbolic_opset9 import index as index_opset9
    return index_opset9(g, self, index)
コード例 #2
def index(g, self, index):
    """Export ``aten::index``.

    With the Caffe2 ATen fallback enabled the op is emitted through ``g.at``.
    A single Bool/Byte mask index is lowered to opset 9 ``nonzero`` followed by
    ``GatherND``; every other case is delegated to the opset 9 implementation.

    Args:
        g: graph being built.
        self: tensor being indexed.
        index: a packed list of index values, or a single index value.

    Returns:
        The output value of the emitted node(s).
    """
    if sym_help.is_caffe2_aten_fallback():
        return g.at("index", self, index, overload_name="Tensor")

    # Normalize to a Python list. Without the ``else`` the unpacked list would
    # always be overwritten by ``[index]``.
    if sym_help._is_packed_list(index):
        indices = sym_help._unpack_list(index)
    else:
        indices = [index]

    # Handle single mask index.
    if len(indices) == 1:
        index = indices[0]
        if not sym_help._is_none(index) and index.type().scalarType() in ("Bool", "Byte"):
            from torch.onnx.symbolic_opset9 import nonzero

            # Turn the mask into explicit coordinates, then gather them.
            index = nonzero(g, index)
            return g.op("GatherND", self, index)
    from torch.onnx.symbolic_opset9 import index as index_opset9

    return index_opset9(g, self, index)
コード例 #3
def index_put(g, self, indices_list_value, values, accumulate=False):
    """Export ``aten::index_put`` as ONNX ``ScatterND``.

    The index tensors are broadcast against each other and stacked along a new
    trailing axis, so they address the leading ``len(indices_list)`` dimensions
    of ``self``. ``values`` is reshaped to the broadcast index shape followed by
    the remaining dimensions of ``self``.

    Args:
        g: graph being built.
        self: tensor written into.
        indices_list_value: packed list of index tensors.
        values: values to write.
        accumulate: parsed as a bool. When true, the scattered values are added
            to ``self`` instead of replacing the addressed elements.

    Returns:
        The resulting graph value. Under the ATen fallback export type an
        ``ATen`` node is returned instead. With no index tensors, ``values`` is
        returned unchanged.

    NOTE(review): ONNX ScatterND without a reduction attribute does not sum
    duplicate indices, so ``accumulate`` with repeated indices may differ from
    PyTorch.
    """
    indices_list = sym_help._unpack_list(indices_list_value)
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    # Nothing to index with: return ``values`` rather than raising IndexError
    # on ``indices_list[0]``.
    if len(indices_list) == 0:
        return values

    index = indices_list[0]

    if len(indices_list) > 1:
        # Summing the index tensors yields a value with their broadcast shape.
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        # Expand every index to the broadcast shape and stack them on a new
        # trailing axis to build the ScatterND index tensor.
        indices_list = [
            g.op("Unsqueeze", expand(g, ind, broadcast_index_shape, None), axes_i=[-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        broadcast_index_shape = g.op("Shape", index)
        index = g.op("Unsqueeze", index, axes_i=[-1])

    # Dimensions of ``self`` that are not addressed by the index tensors.
    sub_data_shape = sym_help._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[maxsize])
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    values = g.op("Reshape", values, values_shape)

    if accumulate:
        # Scatter into zeros shaped like ``self``, then add the result to ``self``.
        dtype = self.type().scalarType()
        dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
        dtype = sym_help.scalar_type_to_pytorch_type[dtype]
        zeros = g.op("ConstantOfShape", g.op("Shape", self), value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
コード例 #4
def index_put(g, self, indices_list_value, values, accumulate=False):
    """Export ``aten::index_put`` as ONNX ``ScatterND``.

    A single Bool index is routed to ``masked_fill`` (rank-0 ``values``) or
    ``masked_scatter``. Otherwise, the index tensors are broadcast and stacked
    into a ScatterND index tensor. ``values`` is expanded (if rank 0), reshaped
    and cast to the dtype of ``self``.

    Args:
        g: graph being built.
        self: tensor written into.
        indices_list_value: a packed list of index tensors, or a single one.
        values: values to write.
        accumulate: parsed as a bool. When true, the scattered values are added
            to ``self``.

    Returns:
        The resulting graph value. Under the ATen fallback export type an
        ``ATen`` node is returned instead. With no index tensors, ``values`` is
        returned unchanged.
    """
    if sym_help._is_packed_list(indices_list_value):
        indices_list = sym_help._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    if len(indices_list) == 0:
        return values

    index = indices_list[0]

    if len(indices_list) > 1:
        # Summing the index tensors yields a value with their broadcast shape.
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        # Expand every index to the broadcast shape and stack them on a new
        # trailing axis to build the ScatterND index tensor.
        indices_list = [
            sym_help._unsqueeze_helper(
                g, expand(g, ind, broadcast_index_shape, None), [-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace index_put node with masked_scatter or masked_fill
        # when inputs to the index_put node contains boolean inputs
        # index_put -> masked_fill
        #   * input index contains single tensor of Bool type (e.g.: %24 <- %23).
        #   * input value contains single element (e.g.: %18).
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        # index_put -> masked_scatter
        #   * input index contains single tensor of Bool type (e.g.: %32 <- %31).
        #   * input value contains multiple elements (e.g.: %28).
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        bool_inp = index
        if bool_inp.type() is not None and bool_inp.type().scalarType() == 'Bool':
            rank = sym_help._get_tensor_rank(values)
            if rank is not None and rank == 0:
                from torch.onnx.symbolic_opset9 import masked_fill
                return masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = sym_help._unsqueeze_helper(g, index, [-1])

    # Dimensions of ``self`` that are not addressed by the index tensors.
    sub_data_shape = sym_help._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[maxsize])
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    # Check if values is a singular value and expand accordingly
    rank = sym_help._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = expand(g, values, values_shape, None)
    values = g.op("Reshape", values, values_shape)

    # ScatterND needs data and updates of the same element type.
    dtype = self.type().scalarType()
    if dtype is not None and dtype != values.type().scalarType():
        values = g.op("Cast", values, to_i=sym_help.cast_pytorch_to_onnx[dtype])

    if accumulate:
        # The torch dtype is only needed for the zero fill, so convert it here.
        # Doing this unconditionally would fail for a non-accumulating put
        # when the scalar type of ``self`` is unknown.
        dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
        dtype = sym_help.scalar_type_to_pytorch_type[dtype]
        zeros = g.op("ConstantOfShape",
                     g.op("Shape", self),
                     value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
コード例 #5
def einsum(g, equation, tensor_list):
    """Export ``aten::einsum``, decomposing simple two-operand equations.

    The function emits a single ONNX ``Einsum`` node unless all of these hold:
    exactly two operands, an explicit ``->`` output, no ``.`` in the equation,
    and no label repeated within one operand. Otherwise the contraction is
    rebuilt from ReduceSum / Transpose / Reshape nodes plus either ``Mul``
    (no contraction labels) or a batched ``MatMul``.

    NOTE(review): this copy of the function looks truncated in several places:
    unterminated calls and list comprehensions, and empty ``elif``/``if``
    bodies. For example, ``contraction_labels``, ``batched_axes`` and
    ``matmul_output_axes`` are never appended to in the visible code. As shown
    it will not parse; restore the missing lines from upstream before use.

    Args:
        g: graph being built.
        equation: einsum equation string.
        tensor_list: packed list of operand tensors.

    Returns:
        The graph value holding the einsum result.
    """
    tensors = sym_help._unpack_list(tensor_list)
    num_ops = len(tensors)
    assert num_ops > 0

    # Implicit output (no '->') and anything other than 2 operands are not decomposed for now.
    # Ellipsis ('...') is not supported for now as the operand sizes are not easy to get.
    if num_ops != 2 or equation.find("->") == -1 or "." in equation:
        return g.op("Einsum", *tensors, equation_s=equation)

    # Take "ks,ksm->sm" as example. After processing the inputs,
    # lhs_labels = [k,s], rhs_labels = [k,s,m], result_labels = [s,m].
    lhs_labels, rhs_labels, result_labels = parse_equation(equation)

    # Doesn't support repeated label in operand for now as it needs to take extra diagonal.
    # NOTE(review): the condition below is truncated; presumably it ends with
    # ``set(rhs_labels)):`` — confirm against upstream.
    if len(lhs_labels) != len(set(lhs_labels)) or len(rhs_labels) != len(
        return g.op("Einsum", *tensors, equation_s=equation)

    # Add contraction labels (labels not present in output).
    # After processing contraction labels, contraction_labels = [k],
    # label_perm_map = {(s, 0), (m, 1), (k, 2)}, out_size = 2, perm_size = 3.
    out_size = len(result_labels)
    label_perm_map = dict([(label, idx)
                           for idx, label in enumerate(result_labels)])
    perm_size = out_size
    contraction_labels = []
    lhs_reduce_sum_axes = []
    rhs_reduce_sum_axes = []
    # NOTE(review): the loop below is incomplete. Judging by the variable
    # names, shared non-output labels presumably also go to contraction_labels,
    # and lhs-only / rhs-only labels presumably go to lhs_reduce_sum_axes /
    # rhs_reduce_sum_axes. The ``elif`` body and a final branch are missing —
    # confirm against upstream.
    for label in lhs_labels + rhs_labels:
        if label not in label_perm_map:
            if label in lhs_labels and label in rhs_labels:
                label_perm_map[label] = perm_size
                perm_size += 1
            elif label in lhs_labels:

    lhs_tensor = tensors[0]
    rhs_tensor = tensors[1]

    # If lhs_reduce_sum_axes/rhs_reduce_sum_axes is not empty, ReduceSum on that axes, update lhs_labels/rhs_labels,
    # and use the output as original_lhs_tensor/original_rhs_tensor.
    # NOTE(review): the ``_reducesum_helper`` calls and the label list
    # comprehensions below are unterminated; their remaining arguments are not
    # visible here.
    if lhs_reduce_sum_axes:
        lhs_tensor = sym_help._reducesum_helper(g,
        lhs_labels = [
            lhs_labels[axis] for axis in range(len(lhs_labels))
            if axis not in lhs_reduce_sum_axes

    if rhs_reduce_sum_axes:
        rhs_tensor = sym_help._reducesum_helper(g,
        rhs_labels = [
            rhs_labels[axis] for axis in range(len(rhs_labels))
            if axis not in rhs_reduce_sum_axes

    # Need to unsqueeze and permute the inputs to order of output with contraction labels.
    # lhs_perm = [1,2,0], lhs_unsqueeze_axes = [2].
    # rhs_perm = [1,2,0], rhs_unsqueeze_axes = [].
    lhs_perm, lhs_unsqueeze_axes = map_labels_to_output(
        lhs_labels, label_perm_map)
    rhs_perm, rhs_unsqueeze_axes = map_labels_to_output(
        rhs_labels, label_perm_map)

    # If there is no contraction labels, unsqueeze and permute the inputs and Mul them to get final result.
    # NOTE(review): both ``unsqueeze_and_permute_for_mul`` calls are unterminated here.
    if not contraction_labels:
        lhs_tensor = unsqueeze_and_permute_for_mul(g, lhs_tensor,
        rhs_tensor = unsqueeze_and_permute_for_mul(g, rhs_tensor,
        return g.op("Mul", lhs_tensor, rhs_tensor)

    # If contraction_labels is not empty, need a BatchedMatMul.
    # Batched labels are those in all inputs and output. Below axes are based on output.
    # batched_labels = [s], batched_axes = [0] for the example.
    # Matmul output labels are those in one of inputs and output.
    # matmul_output_labels = [m], matmul_output_axes = [1] for the example.
    # contraction_labels = [k], contraction_axes = [2] for the example.
    # NOTE(review): the body of the ``if`` in the loop below is missing; per the
    # comments above it presumably fills batched_axes / matmul_output_axes.
    batched_axes = []
    matmul_output_axes = []
    contraction_axes = [axis for axis in range(out_size, perm_size)]
    for axis in range(out_size):
        label = result_labels[axis]
        if label in lhs_labels and label in rhs_labels:

    # Based on above unsqueeze and permute on inputs, need to permute again.
    # For lhs input, the new permute is batched_axes + matmul_output_axes + contraction_axes: [0, 1, 2],
    # i.e., a.unsqueeze([2]).permute([1,2,0]).permute([0,1,2]) = [s,1,k] for the example.
    # For rhs input, the new permute is batched_axes + contraction_axes + matmul_output_axes: [0, 2, 1].
    # i.e., b.unsqueeze([]).permute([1,2,0]).permute([0,2,1]) = [s,k,m] for the example.
    lhs_perm = combine_unsqueeze_and_permute_for_matmul(
        lhs_unsqueeze_axes, lhs_perm,
        batched_axes + matmul_output_axes + contraction_axes)
    rhs_perm = combine_unsqueeze_and_permute_for_matmul(
        rhs_unsqueeze_axes, rhs_perm,
        batched_axes + contraction_axes + matmul_output_axes)

    # Need to Reshape two input tensors before the BatchedMatMul and Reshape result to output shape.
    # Reshape lhs input to [[batched_shapes], Mul(lhs_matmul_output_shapes), Mul(contraction_shapes)].
    # Reshape rhs input to [[batched_shapes], Mul(contraction_shapes), Mul(rhs_matmul_output_shapes)]
    # Convert all axes based on inputs.
    # lhs_contraction_axes = [0], rhs_contraction_axes = [0], lhs_matmul_output_axes = [], rhs_matmul_output_axes = [2] for the example.
    # NOTE(review): the four list comprehensions below are missing their closing brackets.
    lhs_contraction_axes = [
        lhs_labels.index(label) for label in contraction_labels
    rhs_contraction_axes = [
        rhs_labels.index(label) for label in contraction_labels
    lhs_matmul_output_axes = [
        lhs_labels.index(result_labels[axis]) for axis in matmul_output_axes
        if result_labels[axis] in lhs_labels
    rhs_matmul_output_axes = [
        rhs_labels.index(result_labels[axis]) for axis in matmul_output_axes
        if result_labels[axis] in rhs_labels

    # Caches of input shape tensors to avoid generating duplicated graph.
    lhs_shape_tensor = None
    rhs_shape_tensor = None

    # contraction_numel_tensor should be tensor([size(k)]) for the example, but since length is 1, it's None here.
    contraction_numel_tensor = None
    if len(lhs_contraction_axes) > 1:
        _, contraction_numel_tensor, lhs_shape_tensor = get_shape_tensor_by_axes(
            g, lhs_tensor, lhs_shape_tensor, lhs_contraction_axes, True)

    # Prepare some shape tensors for Reshape if needed.
    # Both lhs_matmul_output_shape_tensor and lhs_matmul_output_numel_tensor is None for the example.
    lhs_matmul_output_shape_tensor = None
    lhs_matmul_output_numel_tensor = None
    if len(lhs_matmul_output_axes) > 1:
        lhs_matmul_output_shape_tensor, lhs_matmul_output_numel_tensor, lhs_shape_tensor = get_shape_tensor_by_axes(
            g, lhs_tensor, lhs_shape_tensor, lhs_matmul_output_axes, True)

    # Both rhs_matmul_output_shape_tensor and rhs_matmul_output_numel_tensor is None for the example.
    rhs_matmul_output_shape_tensor = None
    rhs_matmul_output_numel_tensor = None
    if len(rhs_matmul_output_axes) > 1:
        rhs_matmul_output_shape_tensor, rhs_matmul_output_numel_tensor, rhs_shape_tensor = get_shape_tensor_by_axes(
            g, rhs_tensor, rhs_shape_tensor, rhs_matmul_output_axes, True)

    new_lhs_tensor = lhs_tensor
    # Need to Reshape lhs_tensor if lhs_matmul_output_axes or lhs_contraction_axes is not 1, otherwise permute it directly.
    # Need to Reshape the lhs_tensor for the example, the new shape is [size(s), 1, size(k)].
    # NOTE(review): the ``permute_and_reshape_tensor`` call is unterminated, and
    # the ``if need_permute`` branch presumably belonged to a missing ``else:``.
    if len(lhs_matmul_output_axes) != 1 or len(lhs_contraction_axes) != 1:
        new_lhs_tensor, lhs_shape_tensor = permute_and_reshape_tensor(
        if need_permute(lhs_perm):
            new_lhs_tensor = g.op("Transpose", lhs_tensor, perm_i=lhs_perm)

    # Need to Reshape rhs_tensor if rhs_matmul_output_axes or rhs_contraction_axes is not 1, otherwise permute it directly.
    # rhs_tensor's new shape should be [size(s), size(k), size(m)], but doesn't need to Reshape for the example.
    # NOTE(review): same truncation as the lhs case above.
    new_rhs_tensor = rhs_tensor
    if len(rhs_matmul_output_axes) != 1 or len(rhs_contraction_axes) != 1:
        new_rhs_tensor, rhs_shape_tensor = permute_and_reshape_tensor(
        if need_permute(rhs_perm):
            new_rhs_tensor = g.op("Transpose", rhs_tensor, perm_i=rhs_perm)

    # Perform final BatchedMatMul. Output is shape [size(s), 1, size(m)] for the example.
    result = g.op("MatMul", new_lhs_tensor, new_rhs_tensor)

    # Need to Reshape the result if lhs_matmul_output_axes or rhs_matmul_output_axes is not 1.
    # Need to Reshape the result for the example, the new shape is [size(s), size(m)].
    # In the Reshape target, 0 keeps the corresponding input dim and -1 is inferred.
    # NOTE(review): the ``value_t=...`` continuation lines below have lost the
    # start of their statements (presumably ``shape_tensors.append(g.op("Constant",``
    # and ``value_t=torch.tensor(``), and the multi-axis ``else`` branches are missing.
    if len(lhs_matmul_output_axes) != 1 or len(rhs_matmul_output_axes) != 1:
        shape_tensors = [
            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))
        ] * len(batched_axes)
        last_zero_dim = len(shape_tensors) - 1
        has_neg_one_dim = False
        if lhs_matmul_output_axes:
            if len(lhs_matmul_output_axes) == 1:
                         value_t=torch.tensor([0], dtype=torch.int64)))
                last_zero_dim = len(shape_tensors) - 1
        if rhs_matmul_output_axes:
            if len(rhs_matmul_output_axes) == 1:
                         value_t=torch.tensor([-1], dtype=torch.int64)))
                has_neg_one_dim = True
        if not has_neg_one_dim and last_zero_dim >= 0:
            shape_tensors[last_zero_dim] = g.op("Constant",
                                                    [-1], dtype=torch.int64))
        result = reshape_tensor(g, result, shape_tensors)

    # Now output axes is ordered by [batched_axes, lhs_matmul_output_axes, rhs_matmul_output_axes],
    # if this is not same as output, need one permute.
    labels = ([result_labels[axis] for axis in batched_axes] +
              [lhs_labels[axis] for axis in lhs_matmul_output_axes] +
              [rhs_labels[axis] for axis in rhs_matmul_output_axes])
    assert len(labels) == out_size
    output_perm = [labels.index(label) for label in result_labels]
    assert all(axis in output_perm for axis in range(out_size))
    if need_permute(output_perm):
        result = g.op("Transpose", result, perm_i=output_perm)

    return result
コード例 #6
def einsum(g, equation, tensor_list):
    """Export ``aten::einsum`` as one ONNX ``Einsum`` node.

    Args:
        g: graph being built.
        equation: einsum equation string, forwarded as the ``equation``
            attribute of the node.
        tensor_list: packed list of operand tensors.

    Returns:
        The output of the emitted ``Einsum`` node.
    """
    operands = sym_help._unpack_list(tensor_list)
    einsum_out = g.op("Einsum", *operands, equation_s=equation)
    return einsum_out
コード例 #7
def index_put(g, self, indices_list_value, values, accumulate=False):
    """Export ``aten::index_put`` as ONNX ``ScatterND``.

    With a single index whose producing node's first input is Bool, the op is
    routed to ``masked_fill`` (rank-0 ``values``) or ``masked_scatter``.
    Otherwise, the index tensors are broadcast and stacked into a ScatterND
    index tensor. ``values`` is expanded (if rank 0), reshaped and cast to the
    dtype of ``self``.

    Args:
        g: graph being built.
        self: tensor written into.
        indices_list_value: packed list of index tensors.
        values: values to write.
        accumulate: parsed as a bool. When true, the scattered values are added
            to ``self``.

    Returns:
        The resulting graph value. Under the ATen fallback export type an
        ``ATen`` node is returned instead.
    """
    indices_list = sym_help._unpack_list(indices_list_value)
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    index = indices_list[0]

    if len(indices_list) > 1:
        # Summing the index tensors yields a value with their broadcast shape.
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        # Expand every index to the broadcast shape and stack them on a new
        # trailing axis to build the ScatterND index tensor.
        indices_list = [
            sym_help._unsqueeze_helper(
                g, expand(g, ind, broadcast_index_shape, None), [-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace index_put node with masked_scatter or masked_fill
        # when inputs to the index_put node contains boolean inputs
        # index_put -> masked_fill
        # before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %6 : None = prim::Constant()
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %8 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::ne(%mask, %some_const)
        #   %26 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]()
        #   %27 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %11 : Device = prim::Constant[value="cpu"]()
        #   %12 : None = prim::Constant()
        #   %28 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %29 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %15 : None = prim::Constant()
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %30 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %22 : int[] = prim::Constant[value=[-1]]()
        #   %23 : Tensor = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %3 : Tensor = onnx::Equal(%0, %some_const)
        #   %4 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%3)
        #   %12 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%4)
        #   %19 : Tensor = onnx::Cast[to=9](%12)
        #   %20 : Tensor = onnx::Constant[value={1}]()
        #   %21 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = onnx::Where(%19, %20, %0)
        #   return (%21)
        # index_put -> masked_scatter
        # before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %6 : None = prim::Constant()
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %34 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]()
        #   %35 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %18 : Device = prim::Constant[value="cpu"]()
        #   %19 : None = prim::Constant()
        #   %36 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %37 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %22 : None = prim::Constant()
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Tensor = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %3 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #               = onnx::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %4 : Tensor = onnx::Equal(%0, %some_const)
        #   %5 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%4)
        #   %13 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%5)
        #   %19 : Tensor = onnx::Shape(%0)
        #   %20 : Tensor = onnx::Expand(%13, %19)
        #   %21 : Tensor = onnx::NonZero(%20)
        #   %22 : Tensor = onnx::Transpose[perm=[1, 0]](%21)
        #   %23 : Tensor = onnx::Constant[value={-1}]()
        #   %24 : Tensor = onnx::Reshape(%3, %23)
        #   %25 : Tensor = onnx::Shape(%22)
        #   %27 : Tensor = onnx::Constant[value={0}]()
        #   %28 : Tensor = onnx::Gather[axis=0](%25, %27)
        #   %29 : Tensor = onnx::Constant[value={0}]()
        #   %30 : Tensor = onnx::Unsqueeze[axes=[0]](%29)
        #   %31 : Tensor = onnx::Unsqueeze[axes=[0]](%28)
        #   %32 : Tensor = onnx::Constant[value={0}]()
        #   %33 : Tensor = onnx::Unsqueeze[axes=[0]](%32)
        #   %34 : Tensor = onnx::Slice(%24, %30, %31, %33)
        #   %35 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = onnx::ScatterND(%0, %22, %34)
        #   return (%35)

        # As in the IR above, the index may be an untyped view of the Bool
        # mask, so inspect the first input of the node that produced it.
        # Fall back to ``index`` itself when that node has no inputs
        # (presumably a graph input) instead of raising IndexError.
        producer_inputs = list(index.node().inputs())
        bool_inp = producer_inputs[0] if producer_inputs else index
        if bool_inp.type() is not None and bool_inp.type().scalarType() == 'Bool':
            rank = sym_help._get_tensor_rank(values)
            if rank is not None and rank == 0:
                from torch.onnx.symbolic_opset9 import masked_fill
                return masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = sym_help._unsqueeze_helper(g, index, [-1])

    # Dimensions of ``self`` that are not addressed by the index tensors.
    sub_data_shape = sym_help._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[maxsize])
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    # A rank-0 ``values`` must be expanded before it can be reshaped to a
    # multi-element target shape.
    rank = sym_help._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = expand(g, values, values_shape, None)
    values = g.op("Reshape", values, values_shape)

    # ScatterND needs data and updates of the same element type.
    dtype = self.type().scalarType()
    if dtype is not None and dtype != values.type().scalarType():
        values = g.op("Cast", values, to_i=sym_help.cast_pytorch_to_onnx[dtype])

    if accumulate:
        # Scatter into zeros shaped like ``self``, then add the result to ``self``.
        dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
        dtype = sym_help.scalar_type_to_pytorch_type[dtype]
        zeros = g.op("ConstantOfShape",
                     g.op("Shape", self),
                     value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
コード例 #8
def einsum(g, equation, tensor_list):
    """Export ``aten::einsum`` by delegating to ``einsum_helper``.

    The packed operand list is unpacked first. The graph, the equation and the
    resulting operand list are then passed to ``einsum_helper``, which is
    defined elsewhere in the module.
    """
    operands = sym_help._unpack_list(tensor_list)
    exported = einsum_helper(g, equation, operands)
    return exported
コード例 #9
def index_put(g, self, indices_list_value, values, accumulate=False):
    """Export ``aten::index_put`` as ONNX ``ScatterND``.

    - In a multi-index, Bool index tensors are first converted with ``NonZero``.
    - A single Bool index is routed to ``masked_fill`` (rank-0 ``values``) or
      ``masked_scatter``.
    - Otherwise, the index tensors are broadcast and stacked into a ScatterND
      index tensor. ``values`` is expanded (if rank 0), reshaped and cast to
      the dtype of ``self``.

    Args:
        g: graph being built.
        self: tensor written into.
        indices_list_value: a packed list of index tensors, or a single one.
        values: values to write.
        accumulate: parsed as a bool. When true, the scattered values are added
            to ``self``.

    Returns:
        The resulting graph value. With the Caffe2 ATen fallback enabled, the
        op is emitted through ``g.at`` instead. With no index tensors,
        ``values`` is returned unchanged.
    """
    if symbolic_helper._is_packed_list(indices_list_value):
        indices_list = symbolic_helper._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if symbolic_helper.is_caffe2_aten_fallback():
        args = [self] + indices_list + [values, accumulate]
        return g.at("index_put", *args)

    accumulate = symbolic_helper._parse_arg(accumulate, "b")

    if len(indices_list) == 0:
        return values

    if len(indices_list) > 1:
        for idx_ in range(len(indices_list)):
            if indices_list[idx_].type().scalarType() == "Bool":  # type: ignore[attr-defined]
                # TODO(justinchuby): Remove type ignore after #81112 is checked in.
                indices_list[idx_] = g.op("NonZero", indices_list[idx_])
        index = indices_list[0]

        # Summing the index tensors yields a value with their broadcast shape.
        for ind in indices_list[1:]:
            index = opset9.add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        # Expand every index to the broadcast shape and stack them on a new
        # trailing axis to build the ScatterND index tensor.
        indices_list = [
            symbolic_helper._unsqueeze_helper(
                g, opset9.expand(g, ind, broadcast_index_shape, None), [-1]
            )
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace index_put node with masked_scatter or masked_fill
        # when inputs to the index_put node contains a single boolean input.
        # index_put -> masked_fill
        #   * input index contains single tensor of Bool type (e.g.: %24 <- %23).
        #   * input value contains single element (e.g.: %18).
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        # index_put -> masked_scatter
        #   * input index contains single tensor of Bool type (e.g.: %32 <- %31).
        #   * input value contains multiple elements (e.g.: %28).
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        index = indices_list[0]
        bool_inp = index
        if bool_inp.type() is not None and bool_inp.type().scalarType() == "Bool":  # type: ignore[attr-defined]
            # TODO(justinchuby): Remove type ignore after #81112 is checked in.
            rank = symbolic_helper._get_tensor_rank(values)
            if rank is not None and rank == 0:
                return opset9.masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = symbolic_helper._unsqueeze_helper(g, index, [-1])

    # Dimensions of ``self`` that are not addressed by the index tensors.
    sub_data_shape = symbolic_helper._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize]
    )
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    # Check if values is a singular value and expand accordingly
    rank = symbolic_helper._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = opset9.expand(g, values, values_shape, None)
    values = symbolic_helper._reshape_helper(g, values, values_shape)

    # ScatterND needs data and updates of the same element type.
    dtype = self.type().scalarType()
    if dtype is not None and dtype != values.type().scalarType():
        values = g.op("Cast", values, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype])

    if accumulate:
        # The torch dtype is only needed for the zero fill, so convert it here.
        # Doing this unconditionally would fail for a non-accumulating put
        # when the scalar type of ``self`` is unknown.
        dtype = symbolic_helper.scalar_type_to_onnx.index(
            symbolic_helper.cast_pytorch_to_onnx[dtype]
        )
        dtype = symbolic_helper.scalar_type_to_pytorch_type[dtype]
        zeros = g.op(
            "ConstantOfShape",
            g.op("Shape", self),
            value_t=torch.tensor([0], dtype=dtype),
        )
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
コード例 #10
ファイル: symbolic_opset12.py プロジェクト: huaxz1986/pytorch
def einsum(g, equation, tensor_list):
    """Export ``aten::einsum`` by delegating to ``_einsum_helper``.

    The packed operand list is unpacked first. The graph, the equation and the
    resulting operand list are then passed to ``_einsum_helper``, which is
    defined elsewhere in the module.
    """
    operands = symbolic_helper._unpack_list(tensor_list)
    exported = _einsum_helper(g, equation, operands)
    return exported