def _index_fill_reshape_helper(g, self, dim, index):
    """Build the expanded index (and its shape) used to lower ``index_fill``.

    The overall lowering is:

    1. reshape index => [1, ..., 1, dim, 1, ..., 1]
    2. expand index => [..., dim, ...], same shape as self except for dim.
    3. expand value as well.
    4. apply onnx::scatter.

    This helper performs steps 1 and 2; steps 3 and 4 are left to the
    caller, which receives the shape needed to expand the value.

    Args:
        g: graph context used to emit ONNX ops.
        self: the value being filled; its rank must be available.
        dim: the dimension to index along; parsed as an int constant.
        index: indices along ``dim`` (presumably 1-D -- TODO confirm).

    Returns:
        A tuple ``(expanded_index_shape, expanded_index)``, or whatever
        ``_unimplemented`` returns when the rank of ``self`` is unknown.
    """
    from torch.onnx.symbolic_opset9 import expand
    # The scatter symbolic lives in a different opset module depending on the
    # export target, so exactly one of the two must be imported.
    if _export_onnx_opset_version <= 10:
        from torch.onnx.symbolic_opset9 import scatter
    else:
        from torch.onnx.symbolic_opset11 import scatter

    self_dim = self.type().dim()
    if self_dim is None:
        return _unimplemented("index_fill", "input rank not accesible")
    dim_value = _parse_arg(dim, 'i')
    # NOTE(review): a negative ``dim_value`` never matches an axis below, so
    # every axis would be unsqueezed -- confirm callers pass a non-negative dim.
    # Insert a singleton axis for every dimension of ``self`` except ``dim``.
    unsqueezed_index = g.op(
        "Unsqueeze", index,
        axes_i=[i for i in range(self_dim) if i != dim_value])
    # Shape of ``self`` with the entry at ``dim`` replaced by the index shape.
    expanded_index_shape = scatter(g, g.op("Shape", self), 0,
                                   g.op("Unsqueeze", dim, axes_i=[0]),
                                   g.op("Shape", index))
    expanded_index = expand(g, unsqueezed_index, expanded_index_shape, None)
    return expanded_index_shape, expanded_index
def index_fill(g, self, dim, index, value):
    """Lower ``index_fill`` to an ONNX scatter (or an ATen op in fallback mode).

    ``dim`` is parsed as an int constant first. In ATen-fallback export mode
    an ``ATen`` node tagged ``index_fill`` is emitted; otherwise the index and
    the fill value are both expanded to a common shape and scattered into
    ``self`` along ``dim``.
    """
    dim_value = sym_help._parse_arg(dim, 'i')

    fallback_mode = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    if sym_help._operator_export_type == fallback_mode:
        return g.op(
            "ATen",
            self,
            index,
            value,
            dim_i=dim_value,
            operator_s="index_fill",
        )

    # Index broadcast to the scatter shape, plus that shape for the value.
    target_shape, scatter_index = sym_help._index_fill_reshape_helper(
        g, self, dim, index)

    # Unwrap a scalar fill value and give it the same scalar type as ``self``.
    fill = sym_help._if_scalar_type_as(
        g, sym_help._maybe_get_scalar(value), self)
    scatter_src = expand(g, fill, target_shape, None)
    return scatter(g, self, dim, scatter_index, scatter_src)
def chunk(g, self, chunks, dim):
    """Lower ``chunk`` with a dynamic number of chunks to a ``split``.

    The per-chunk size is ``(dim_size + chunks - 1) / chunks``. The split
    vector is that size expanded to shape ``chunks - 1``, followed by the
    remainder ``dim_size - size * (chunks - 1)``.
    """
    # Size of ``self`` along ``dim``.
    self_shape = g.op("Shape", self)
    dim_size = g.op("Gather", self_shape, dim, axis_i=0)

    # chunks - 1: used both for rounding up and as the count of full chunks.
    one = g.op("Constant", value_t=torch.tensor([1], dtype=torch.long))
    chunks_minus_one = g.op("Sub", chunks, one)
    rounded_up = g.op("Add", dim_size, chunks_minus_one)
    per_chunk = g.op("Div", rounded_up, chunks)

    # Splits vector: the full-size chunks, then whatever is left over.
    full_chunks = expand(g, per_chunk, chunks_minus_one, None)
    covered = g.op("Mul", per_chunk, chunks_minus_one)
    leftover = g.op("Sub", dim_size, covered)
    split_sizes = g.op("Concat", full_chunks, leftover, axis_i=0)
    return split(g, self, split_sizes, dim)
def index_put(g, self, indices_list_value, values, accumulate=False):
    """Lower ``index_put`` to ONNX ``ScatterND``.

    Args:
        g: graph context used to emit ONNX ops.
        self: the tensor being written into.
        indices_list_value: packed list of index tensors.
        values: the values to write.
        accumulate: when true, the values are added to ``self`` instead of
            overwriting it.

    Returns:
        The value of the emitted ``ScatterND`` (or of the addition in
        accumulate mode), or an ``ATen`` node in ATen-fallback export mode.
    """
    import sys

    indices_list = sym_help._unpack_list(indices_list_value)
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    index = indices_list[0]

    if len(indices_list) > 1:
        # Adding all index tensors together yields a tensor whose shape is
        # their common broadcast shape.
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        # Expand each index to that shape and stack them on a new last axis,
        # which is the coordinate layout ScatterND consumes.
        indices_list = [
            g.op("Unsqueeze",
                 expand(g, ind, broadcast_index_shape, None),
                 axes_i=[-1]) for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        broadcast_index_shape = g.op("Shape", index)
        index = g.op("Unsqueeze", index, axes_i=[-1])

    # Trailing dimensions of ``self`` that are not addressed by an index.
    sub_data_shape = sym_help._slice_helper(g,
                                            g.op("Shape", self),
                                            axes=[0],
                                            starts=[len(indices_list)],
                                            ends=[sys.maxsize])
    values_shape = g.op("Concat",
                        broadcast_index_shape,
                        sub_data_shape,
                        axis_i=0)
    values = g.op("Reshape", values, values_shape)

    if accumulate:
        # Scatter into zeros of the dtype of ``self`` and add the result to
        # ``self`` so that existing entries are kept.
        dtype = self.type().scalarType()
        dtype = sym_help.scalar_type_to_onnx.index(
            sym_help.cast_pytorch_to_onnx[dtype])
        dtype = sym_help.scalar_type_to_pytorch_type[dtype]
        zeros = g.op("ConstantOfShape",
                     g.op("Shape", self),
                     value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
def index_fill(g, self, dim, index, value):
    """Lower ``index_fill`` to an ONNX scatter (or an ATen op for Caffe2).

    ``dim`` is parsed as an int constant. When the Caffe2 ATen fallback is
    active an ATen ``index_fill`` node is emitted; otherwise the index and
    the fill value are expanded to a common shape and scattered into
    ``self`` along ``dim``.
    """
    dim_value = sym_help._parse_arg(dim, "i")
    if sym_help.is_caffe2_aten_fallback():
        # NOTE(review): the overload name selects the (dim, index, Scalar)
        # variant of aten::index_fill -- confirm against the ATen schema.
        return g.at("index_fill",
                    self,
                    index,
                    value,
                    overload_name="int_Scalar",
                    dim_i=dim_value)

    expanded_index_shape, expanded_index = sym_help._index_fill_reshape_helper(
        g, self, dim, index)
    # Unwrap a scalar fill value and give it the same scalar type as ``self``.
    value = sym_help._maybe_get_scalar(value)
    value = sym_help._if_scalar_type_as(g, value, self)
    expanded_value = expand(g, value, expanded_index_shape, None)
    return scatter(g, self, dim, expanded_index, expanded_value)
def index_put(g, self, indices_list_value, values, accumulate=False):
    """Lower ``index_put`` to ONNX ``ScatterND`` (or a masked op for bool masks).

    Args:
        g: graph context used to emit ONNX ops.
        self: the tensor being written into.
        indices_list_value: a packed list of index tensors, or a single
            index value that is not a packed list.
        values: the values to write; a rank-0 value is expanded to the
            required shape and values are cast to the scalar type of ``self``.
        accumulate: when true, the values are added to ``self`` instead of
            overwriting it.

    Returns:
        ``values`` unchanged when there are no indices; the result of
        ``masked_fill``/``masked_scatter`` for a single Bool index; otherwise
        the ``ScatterND`` result (added to ``self`` in accumulate mode). In
        ATen-fallback export mode an ``ATen`` node is returned instead.
    """
    import sys

    if sym_help._is_packed_list(indices_list_value):
        indices_list = sym_help._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    if len(indices_list) == 0:
        return values

    index = indices_list[0]

    if len(indices_list) > 1:
        # Adding all index tensors together yields a tensor whose shape is
        # their common broadcast shape.
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        # Expand each index to that shape and stack them on a new last axis.
        indices_list = [
            sym_help._unsqueeze_helper(
                g, expand(g, ind, broadcast_index_shape, None), [-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace index_put node with masked_scatter or masked_fill
        # when inputs to the index_put node contains boolean inputs
        # index_put -> masked_fill
        #   * input index contains single tensor of Bool type (e.g.: %24 <- %23).
        #   * input value contains single element (e.g.: %18).
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        # index_put -> masked_scatter
        #   * input index contains single tensor of Bool type (e.g.: %32 <- %31).
        #   * input value contains multiple elements (e.g.: %28).
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        bool_inp = index
        if bool_inp.type() is not None and bool_inp.type().scalarType(
        ) == 'Bool':
            rank = sym_help._get_tensor_rank(values)
            if rank is not None and rank == 0:
                from torch.onnx.symbolic_opset9 import masked_fill
                return masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = sym_help._unsqueeze_helper(g, index, [-1])

    # Trailing dimensions of ``self`` that are not addressed by an index.
    sub_data_shape = sym_help._slice_helper(g,
                                            g.op("Shape", self),
                                            axes=[0],
                                            starts=[len(indices_list)],
                                            ends=[sys.maxsize])
    values_shape = g.op("Concat",
                        broadcast_index_shape,
                        sub_data_shape,
                        axis_i=0)
    # Check if values is a singular value and expand accordingly
    rank = sym_help._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = expand(g, values, values_shape, None)
    values = g.op("Reshape", values, values_shape)

    # ScatterND needs data and updates of the same type, so cast the values
    # to the scalar type of ``self`` when they differ.
    dtype = self.type().scalarType()
    if dtype is not None and dtype != values.type().scalarType():
        values = g.op("Cast",
                      values,
                      to_i=sym_help.cast_pytorch_to_onnx[dtype])
    dtype = sym_help.scalar_type_to_onnx.index(
        sym_help.cast_pytorch_to_onnx[dtype])
    dtype = sym_help.scalar_type_to_pytorch_type[dtype]

    if accumulate:
        # Scatter into zeros and add the result to ``self`` so that existing
        # entries are kept.
        zeros = g.op("ConstantOfShape",
                     g.op("Shape", self),
                     value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
def repeat_interleave(g, self, repeats, dim=None):
    """Lower ``repeat_interleave`` with an ONNX ``Loop`` for dynamic sizes.

    The loop-based lowering is used only when ``repeats`` is a single value
    (rank 0, or rank 1 with one element) and the size of the input along
    ``dim`` is unknown. Every other case is delegated to the opset 9
    implementation.

    Args:
        g: graph context used to emit ONNX ops.
        self: the input tensor.
        repeats: the repeat count(s).
        dim: dimension to repeat along; when it is None the input is
            flattened and dimension 0 is used.

    Raises:
        RuntimeError: if the rank or sizes of ``repeats``, or the sizes of
            the input, are not available.
    """
    from torch.onnx.symbolic_opset9 import reshape
    inp = self
    final_dim = dim
    # if dim is None flatten
    # By default, use the flattened input array, and return a flat output array
    if sym_help._is_none(dim):
        inp = reshape(g, self, g.op("Constant", value_t=torch.tensor([-1])))
        dim = 0
    else:
        dim = sym_help._maybe_get_scalar(dim)

    repeats_dim = sym_help._get_tensor_rank(repeats)
    repeats_sizes = sym_help._get_tensor_sizes(repeats)
    input_sizes = sym_help._get_tensor_sizes(inp)
    if repeats_dim is None:
        raise RuntimeError(
            'Unsupported: ONNX export of repeat_interleave for unknown '
            'repeats rank.')
    if repeats_sizes is None:
        raise RuntimeError(
            'Unsupported: ONNX export of repeat_interleave for unknown '
            'repeats size.')
    if input_sizes is None:
        raise RuntimeError(
            'Unsupported: ONNX export of repeat_interleave for unknown '
            'input size.')
    # Handle cases where dim is negative
    if dim < 0:
        dim += len(input_sizes)

    # Unknown sizes become 0 in the output shape and -1 in the input shape.
    output_sizes = input_sizes.copy()
    for idx, input_size in enumerate(input_sizes):
        if input_size is None:
            output_sizes[idx], input_sizes[idx] = 0, -1
    # Permutation over rank + 1 axes that swaps axis 0 with axis ``dim``.
    perm_i = list(range(len(input_sizes) + 1))
    perm_i[0], perm_i[dim] = perm_i[dim], perm_i[0]

    # Cases when repeats is a single value tensor and dim has unknown input size
    if (repeats_dim == 0 or
        (repeats_dim == 1
         and repeats_sizes[0] == 1)) and output_sizes[dim] == 0:
        if not sym_help._is_tensor(repeats):
            repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
        reps = sym_help._size_helper(g, inp, dim)
        reps = unsqueeze(g, reps, 0)
        repeats = g.op("Expand", repeats, reps)
    # There are cases when the repeats are 1-d tensor with multiple repeats, but dim
    # provided along one of the dynamic axes provided. A simple example would be
    # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2
    # Now, repeat interleaving can be performed in pytorch when the value of * matches
    # with the number of elements in repeat, for example if * -> 2, number of repeats
    # should be 2 as well.
    else:
        return torch.onnx.symbolic_opset9.repeat_interleave(
            g, self, repeats, final_dim)

    # Split both ``repeats`` and the input into pieces of length one.
    reps_like = g.op("ConstantOfShape",
                     g.op("Shape", repeats),
                     value_t=torch.tensor([1], dtype=torch.long))
    r_splits = split(g, repeats, reps_like, 0)
    i_splits = split(g, inp, reps_like, dim)

    output_sizes[dim], input_sizes[dim] = -1, 1

    # Create a loop to iterate over each value along the dimension
    # and perform individual interleaving using the repeats tensor
    # Loop is of the following pattern
    # input (trip_count, cond)
    #   int trip_count = ...;
    #   bool cond = ...;
    #   for (int i=0; i < trip_count && cond; ++i) {
    #     cond = ...;
    #   }

    # Loop conditions
    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    loop_condition = g.op("Cast", loop_condition, to_i=9)
    loop_len = reps
    loop = g.op("Loop", loop_len, loop_condition)

    # Loop inputs
    loop_block = _add_block(loop.node())
    block_input_iter = _add_input_to_block(loop_block)
    # Second block input (bound to ``cond`` originally); the body never reads it.
    _add_input_to_block(loop_block)

    r_split = loop_block.op("SequenceAt", r_splits, block_input_iter)
    i_split = loop_block.op("SequenceAt", i_splits, block_input_iter)

    # Add an axis after ``dim``, expand it to the repeat count, then fold it
    # back into ``dim`` with the reshape.
    i_split = unsqueeze(loop_block, i_split, dim + 1)
    r_concat = [
        loop_block.op("Constant",
                      value_t=torch.LongTensor(input_sizes[:dim + 1])),
        r_split,
        loop_block.op("Constant",
                      value_t=torch.LongTensor(input_sizes[dim + 1:]))
    ]
    r_concat = loop_block.op("Concat", *r_concat, axis_i=0)
    i_split = expand(loop_block, i_split, r_concat, None)
    i_split = reshape(loop_block, i_split,
                      g.op("Constant", value_t=torch.LongTensor(output_sizes)))

    # Loop outputs
    cond_out = loop_block.op("Cast", loop_condition, to_i=9)
    _add_output_to_block(loop_block, cond_out)
    _add_output_to_block(loop_block, i_split)
    loop_out = loop.node().output()

    # In this loop, the outputs are scan outputs and are concatenated along
    # the zero'th dimension (by default). In order to avoid this and concatenate
    # along the dimension provided, some post-processing is required
    loop_out = g.op("Transpose", loop_out, perm_i=perm_i)
    return reshape(g, loop_out,
                   g.op("Constant", value_t=torch.LongTensor(output_sizes)))
def repeat_interleave(g, self, repeats, dim=None, output_size=None):
    """Lower ``repeat_interleave`` with an ONNX ``Loop`` for dynamic sizes.

    The loop-based lowering is used when the input size along ``dim`` is
    unknown or when ``repeats`` is a 1-D tensor of unknown length. Every
    other case is delegated to the opset 9 implementation.

    Args:
        g: graph context used to emit ONNX ops.
        self: the input tensor.
        repeats: the repeat count(s).
        dim: dimension to repeat along; when it is None the input is
            flattened and dimension 0 is used.
        output_size: accepted for signature compatibility; not used here.

    Raises:
        RuntimeError: if the rank or sizes of ``repeats``, or the sizes of
            the input, are not available.
    """
    inp = self
    final_dim = dim
    # if dim is None flatten
    # By default, use the flattened input array, and return a flat output array
    if sym_help._is_none(dim):
        inp = sym_help._reshape_helper(
            g, self, g.op("Constant", value_t=torch.tensor([-1])))
        dim = 0
    else:
        dim = sym_help._maybe_get_scalar(dim)

    repeats_dim = sym_help._get_tensor_rank(repeats)
    repeats_sizes = sym_help._get_tensor_sizes(repeats)
    input_sizes = sym_help._get_tensor_sizes(inp)
    if repeats_dim is None:
        raise RuntimeError(
            "Unsupported: ONNX export of repeat_interleave for unknown "
            "repeats rank.")
    if repeats_sizes is None:
        raise RuntimeError(
            "Unsupported: ONNX export of repeat_interleave for unknown "
            "repeats size.")
    if input_sizes is None:
        raise RuntimeError(
            "Unsupported: ONNX export of repeat_interleave for unknown "
            "input size.")
    # Handle cases where dim is negative
    if dim < 0:
        dim += len(input_sizes)

    # Unknown sizes become 0 in the output shape and -1 in the input shape.
    output_sizes = input_sizes.copy()
    for idx, input_size in enumerate(input_sizes):
        if input_size is None:
            output_sizes[idx], input_sizes[idx] = 0, -1
    cond_dynamic_repeats = (repeats_dim == 1 and repeats_sizes[0] is None)
    # If input size is dynamic or repeats vector is dynamic
    if output_sizes[dim] == 0 or cond_dynamic_repeats:
        reps = sym_help._size_helper(g, inp, dim)
        reps = unsqueeze(g, reps, 0)
        # Check if repeats vector is a single integer value
        # or a single dimension tensor with non-dynamic values
        if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
            if not sym_help._is_tensor(repeats):
                repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
            repeats = g.op("Expand", repeats, reps)
        # Check if repeats is dynamic
        # As repeats is dynamic, we use a where node as a substitute for the if statement
        # If repeats_dim = 1, expand repeats otherwise use original tensor
        elif cond_dynamic_repeats:
            repeat_dim = sym_help._size_helper(
                g, repeats, g.op("Constant", value_t=torch.LongTensor([0])))
            repeat_cond = g.op("Equal", repeat_dim,
                               g.op("Constant", value_t=torch.LongTensor([1])))
            repeats = where(g, repeat_cond, g.op("Expand", repeats, reps),
                            repeats)
    # There are cases when the repeats are 1-d tensor with multiple repeats, but dim
    # provided along one of the dynamic axes provided. A simple example would be
    # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2
    # Now, repeat interleaving can be performed in pytorch when the value of * matches
    # with the number of elements in repeat, for example if * -> 2, number of repeats
    # should be 2 as well.
    else:
        return torch.onnx.symbolic_opset9.repeat_interleave(
            g, self, repeats, final_dim)

    # Split both ``repeats`` and the input into pieces of length one.
    reps_like = g.op("ConstantOfShape",
                     g.op("Shape", repeats),
                     value_t=torch.tensor([1], dtype=torch.long))
    r_splits = split(g, repeats, reps_like, 0)
    i_splits = split(g, inp, reps_like, dim)

    output_sizes[dim], input_sizes[dim] = -1, 1

    # Create a loop to iterate over each value along the dimension
    # and perform individual interleaving using the repeats tensor
    # Loop is of the following pattern
    # input (trip_count, cond)
    #   int trip_count = ...;
    #   bool cond = ...;
    #   for (int i=0; i < trip_count && cond; ++i) {
    #     cond = ...;
    #   }

    # Loop conditions
    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    loop_condition = g.op("Cast", loop_condition, to_i=9)
    loop_len = reps

    # Create an empty sequence to store final expansions
    final_splits = g.op("SequenceEmpty")
    loop = g.op("Loop", loop_len, loop_condition, final_splits)

    # Loop inputs
    loop_block = _add_block(loop.node())
    block_input_iter = _add_input_to_block(loop_block)
    # Second block input (bound to ``cond`` originally); the body never reads it.
    _add_input_to_block(loop_block)
    final_splits = _add_input_to_block(loop_block)

    r_split = loop_block.op("SequenceAt", r_splits, block_input_iter)
    i_split = loop_block.op("SequenceAt", i_splits, block_input_iter)

    # Add an axis after ``dim``, expand it to the repeat count, then fold it
    # back into ``dim`` with the reshape.
    i_split = unsqueeze(loop_block, i_split, dim + 1)
    r_concat = [
        loop_block.op("Constant",
                      value_t=torch.LongTensor(input_sizes[:dim + 1])),
        r_split,
        loop_block.op("Constant",
                      value_t=torch.LongTensor(input_sizes[dim + 1:]))
    ]
    r_concat = loop_block.op("Concat", *r_concat, axis_i=0)
    i_split = expand(loop_block, i_split, r_concat, None)
    i_split = sym_help._reshape_helper(
        loop_block, i_split,
        g.op("Constant", value_t=torch.LongTensor(output_sizes)))
    final_splits = loop_block.op("SequenceInsert", final_splits, i_split)

    # Loop outputs
    cond_out = loop_block.op("Cast", loop_condition, to_i=9)
    _add_output_to_block(loop_block, cond_out)
    _add_output_to_block(loop_block, final_splits)

    # Join the per-slice results along the requested dimension.
    loop_out = loop.node().output()
    loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim)
    return loop_out
def index_put(g, self, indices_list_value, values, accumulate=False):
    """Lower ``index_put`` to ONNX ``ScatterND`` (or a masked op for bool masks).

    Args:
        g: graph context used to emit ONNX ops.
        self: the tensor being written into.
        indices_list_value: packed list of index tensors.
        values: the values to write.
        accumulate: when true, the values are added to ``self`` instead of
            overwriting it.

    Returns:
        The result of ``masked_fill``/``masked_scatter`` when the single
        index is derived from a Bool tensor; otherwise the ``ScatterND``
        result (added to ``self`` in accumulate mode). In ATen-fallback
        export mode an ``ATen`` node is returned instead.
    """
    import sys

    indices_list = sym_help._unpack_list(indices_list_value)
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    index = indices_list[0]

    if len(indices_list) > 1:
        # Adding all index tensors together yields a tensor whose shape is
        # their common broadcast shape.
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        # Expand each index to that shape and stack them on a new last axis.
        indices_list = [
            sym_help._unsqueeze_helper(
                g, expand(g, ind, broadcast_index_shape, None), [-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace index_put node with masked_scatter or masked_fill
        # when inputs to the index_put node contains boolean inputs
        # index_put -> masked_fill
        # before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %6 : None = prim::Constant()
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %8 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::ne(%mask, %some_const)
        #   %26 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]()
        #   %27 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %11 : Device = prim::Constant[value="cpu"]()
        #   %12 : None = prim::Constant()
        #   %28 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %29 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %15 : None = prim::Constant()
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %30 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %22 : int[] = prim::Constant[value=[-1]]()
        #   %23 : Tensor = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %3 : Tensor = onnx::Equal(%0, %some_const)
        #   %4 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%3)
        #   %12 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%4)
        #   %19 : Tensor = onnx::Cast[to=9](%12)
        #   %20 : Tensor = onnx::Constant[value={1}]()
        #   %21 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = onnx::Where(%19, %20, %0)
        #   return (%21)
        # index_put -> masked_scatter
        # before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %6 : None = prim::Constant()
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %34 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]()
        #   %35 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %18 : Device = prim::Constant[value="cpu"]()
        #   %19 : None = prim::Constant()
        #   %36 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %37 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %22 : None = prim::Constant()
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Tensor = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %3 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #               = onnx::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %4 : Tensor = onnx::Equal(%0, %some_const)
        #   %5 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%4)
        #   %13 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%5)
        #   %19 : Tensor = onnx::Shape(%0)
        #   %20 : Tensor = onnx::Expand(%13, %19)
        #   %21 : Tensor = onnx::NonZero(%20)
        #   %22 : Tensor = onnx::Transpose[perm=[1, 0]](%21)
        #   %23 : Tensor = onnx::Constant[value={-1}]()
        #   %24 : Tensor = onnx::Reshape(%3, %23)
        #   %25 : Tensor = onnx::Shape(%22)
        #   %27 : Tensor = onnx::Constant[value={0}]()
        #   %28 : Tensor = onnx::Gather[axis=0](%25, %27)
        #   %29 : Tensor = onnx::Constant[value={0}]()
        #   %30 : Tensor = onnx::Unsqueeze[axes=[0]](%29)
        #   %31 : Tensor = onnx::Unsqueeze[axes=[0]](%28)
        #   %32 : Tensor = onnx::Constant[value={0}]()
        #   %33 : Tensor = onnx::Unsqueeze[axes=[0]](%32)
        #   %34 : Tensor = onnx::Slice(%24, %30, %31, %33)
        #   %35 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = onnx::ScatterND(%0, %22, %34)
        #   return (%35)

        # Inspect the first input of the node that produces ``index``
        # (presumably because the view output itself is untyped, as in the
        # ``%23 : Tensor = aten::view`` lines above). A producing node with
        # no inputs falls back to ``index`` itself instead of raising
        # IndexError.
        index_inputs = list(index.node().inputs())
        bool_inp = index_inputs[0] if index_inputs else index
        if bool_inp.type() is not None and bool_inp.type().scalarType(
        ) == 'Bool':
            rank = sym_help._get_tensor_rank(values)
            if rank is not None and rank == 0:
                from torch.onnx.symbolic_opset9 import masked_fill
                return masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = sym_help._unsqueeze_helper(g, index, [-1])

    # Trailing dimensions of ``self`` that are not addressed by an index.
    sub_data_shape = sym_help._slice_helper(g,
                                            g.op("Shape", self),
                                            axes=[0],
                                            starts=[len(indices_list)],
                                            ends=[sys.maxsize])
    values_shape = g.op("Concat",
                        broadcast_index_shape,
                        sub_data_shape,
                        axis_i=0)
    values = g.op("Reshape", values, values_shape)

    if accumulate:
        # Scatter into zeros of the dtype of ``self`` and add the result to
        # ``self`` so that existing entries are kept.
        dtype = self.type().scalarType()
        dtype = sym_help.scalar_type_to_onnx.index(
            sym_help.cast_pytorch_to_onnx[dtype])
        dtype = sym_help.scalar_type_to_pytorch_type[dtype]
        zeros = g.op("ConstantOfShape",
                     g.op("Shape", self),
                     value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
def index_put(g, self, indices_list_value, values, accumulate=False):
    """Lower ``index_put`` to ONNX ``ScatterND`` (or a masked op for bool masks).

    Args:
        g: graph context used to emit ONNX ops.
        self: the tensor being written into.
        indices_list_value: a packed list of index tensors, or a single
            index value that is not a packed list.
        values: the values to write; a rank-0 value is expanded to the
            required shape and values are cast to the scalar type of ``self``.
        accumulate: when true, the values are added to ``self`` instead of
            overwriting it.

    Returns:
        ``values`` unchanged when there are no indices; the result of
        ``masked_fill``/``masked_scatter`` for a single Bool index; otherwise
        the ``ScatterND`` result (added to ``self`` in accumulate mode). With
        the Caffe2 ATen fallback an ATen ``index_put`` node is returned.
    """
    if symbolic_helper._is_packed_list(indices_list_value):
        indices_list = symbolic_helper._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if symbolic_helper.is_caffe2_aten_fallback():
        args = [self] + indices_list + [values, accumulate]
        return g.at("index_put", *args)

    accumulate = symbolic_helper._parse_arg(accumulate, "b")

    if len(indices_list) == 0:
        return values

    if len(indices_list) > 1:
        # With several indices, Bool masks are first converted to integer
        # positions via NonZero.
        for idx_, ind in enumerate(indices_list):
            if ind.type().scalarType() == "Bool":  # type: ignore[attr-defined]
                # TODO(justinchuby): Remove type ignore after #81112 is checked in.
                indices_list[idx_] = g.op("NonZero", ind)
        index = indices_list[0]

        # Adding all index tensors together yields a tensor whose shape is
        # their common broadcast shape.
        for ind in indices_list[1:]:
            index = opset9.add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        # Expand each index to that shape and stack them on a new last axis.
        indices_list = [
            symbolic_helper._unsqueeze_helper(
                g, opset9.expand(g, ind, broadcast_index_shape, None), [-1]
            )
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace index_put node with masked_scatter or masked_fill
        # when inputs to the index_put node contains a single boolean input.
        # index_put -> masked_fill
        #   * input index contains single tensor of Bool type (e.g.: %24 <- %23).
        #   * input value contains single element (e.g.: %18).
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        # index_put -> masked_scatter
        #   * input index contains single tensor of Bool type (e.g.: %32 <- %31).
        #   * input value contains multiple elements (e.g.: %28).
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        index = indices_list[0]
        bool_inp = index
        if bool_inp.type() is not None and bool_inp.type().scalarType() == "Bool":  # type: ignore[attr-defined]
            # TODO(justinchuby): Remove type ignore after #81112 is checked in.
            rank = symbolic_helper._get_tensor_rank(values)
            if rank is not None and rank == 0:
                return opset9.masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = symbolic_helper._unsqueeze_helper(g, index, [-1])

    # Trailing dimensions of ``self`` that are not addressed by an index.
    sub_data_shape = symbolic_helper._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize]
    )
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    # Check if values is a singular value and expand accordingly
    rank = symbolic_helper._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = opset9.expand(g, values, values_shape, None)
    values = symbolic_helper._reshape_helper(g, values, values_shape)

    # ScatterND needs data and updates of the same type, so cast the values
    # to the scalar type of ``self`` when they differ.
    dtype = self.type().scalarType()
    if dtype is not None and dtype != values.type().scalarType():
        values = g.op("Cast", values, to_i=symbolic_helper.cast_pytorch_to_onnx[dtype])
    dtype = symbolic_helper.scalar_type_to_onnx.index(
        symbolic_helper.cast_pytorch_to_onnx[dtype]
    )
    dtype = symbolic_helper.scalar_type_to_pytorch_type[dtype]

    if accumulate:
        # Scatter into zeros and add the result to ``self`` so that existing
        # entries are kept.
        zeros = g.op(
            "ConstantOfShape",
            g.op("Shape", self),
            value_t=torch.tensor([0], dtype=dtype),
        )
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result