Example #1
def frobenius_norm(g, self, dim=None, keepdim=False):
    dim_val = sym_help._maybe_get_const(dim, "is")
    if not sym_help._is_value(dim_val) and len(dim_val) == 0:
        return g.op("ReduceL2", self, keepdims_i=0)
    sqr = g.op("Mul", self, self)
    sumsqr = sym_help._reducesum_helper(g, sqr, dim, keepdims_i=keepdim)
    return g.op("Sqrt", sumsqr)
def embedding_bag(g,
                  embedding_matrix,
                  indices,
                  offsets,
                  scale_grad_by_freq,
                  mode,
                  sparse,
                  per_sample_weights,
                  include_last_offset,
                  padding_idx):
    if scale_grad_by_freq and sym_help._training_mode:
        return sym_help._onnx_unsupported("embedding_bag with scale_grad_by_freq for training mode")
    if padding_idx is not None and padding_idx >= 0:
        raise RuntimeError("embedding_bag with padding_idx")
    from torch.onnx.symbolic_opset9 import select
    import warnings
    warnings.warn("Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. "
                  "Please use opset 11 or higher to export model for dynamic input shape.")
    offsets_dim_0 = sym_help._get_tensor_dim_size(offsets, 0)
    if offsets_dim_0 is not None:
        if include_last_offset:
            offset_len = offsets_dim_0 - 1
            offsets_extended = offsets
        else:
            offset_len = offsets_dim_0
            offsets_extended = [offsets, g.op("Constant", value_t=torch.tensor([maxsize]))]
            offsets_extended = g.op("Concat", *offsets_extended, axis_i=0)
        list_ = []
        for i in range(offset_len):
            start_ = sym_help._unsqueeze_helper(g, select(g, offsets_extended, torch.tensor(0), torch.tensor(i)), [0])
            end_ = sym_help._unsqueeze_helper(g, select(g, offsets_extended, torch.tensor(0), torch.tensor(i + 1)), [0])
            axes_ = g.op("Constant", value_t=torch.tensor([0]))
            indices_row = g.op("Slice", indices, start_, end_, axes_)

            embeddings = g.op("Gather", embedding_matrix, indices_row)
            if not sym_help._is_none(per_sample_weights):
                per_sample_weights_row = g.op("Slice", per_sample_weights, start_, end_, axes_)
                per_sample_weights_row = sym_help._unsqueeze_helper(g, per_sample_weights_row, [1])
                embeddings = g.op("Mul", embeddings, per_sample_weights_row)
            if mode == 0:
                embeddings = sym_help._reducesum_helper(g, embeddings, axes_i=[0], keepdims_i=0)
            elif mode == 1:
                embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
            else:
                embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)

            embeddings = sym_help._unsqueeze_helper(g, embeddings, [0])
            list_.append(embeddings)

        output = g.op("Concat", *list_, axis_i=0)
        # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
        # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
        return output, None, None, None
    else:
        return sym_help._onnx_unsupported("embedding_bag with unknown shape of offsets for opset 10 is not supported. "
                                          "please use opset 11 or higher.")
Example #3
def linalg_vector_norm(g, self, ord, dim, keepdim, dtype):
    if ord == 0:
        if dim is None:
            self = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)))
            keepdim = 0
        cond_op = g.op("Not", g.op("Equal", self, g.op("Constant", value_t=torch.LongTensor([0]))))
        cond_op = g.op("Cast", cond_op, to_i=sym_help.cast_pytorch_to_onnx["Long"])
        return sym_help._reducesum_helper(g, cond_op, axes_i=dim, keepdims_i=keepdim)
    else:
        return lvn(g, self, ord, dim, keepdim, dtype)
Example #4
def linalg_vector_norm(g, self, ord, dim, keepdim, dtype):
    if ord == 0:
        if dim is None:
            self = symbolic_helper._reshape_helper(
                g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
            )
            keepdim = 0

        cond_op = g.op(
            "Not", g.op("Equal", self, g.op("Constant", value_t=torch.LongTensor([0])))
        )
        cond_op = g.op(
            "Cast",
            cond_op,
            to_i=symbolic_helper.cast_pytorch_to_onnx[self.type().scalarType()],
        )
        return symbolic_helper._reducesum_helper(
            g, cond_op, axes_i=dim, keepdims_i=keepdim
        )
    else:
        return opset9.linalg_vector_norm(g, self, ord, dim, keepdim, dtype)
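Both versions above special-case ord == 0, which counts the non-zero elements of the (optionally flattened) input via Equal/Not/Cast followed by a ReduceSum; in Example #3, lvn is presumably a local alias for the opset 9 fallback that Example #4 calls explicitly. A standalone check of the ord == 0 identity on an arbitrary input:

import torch

x = torch.tensor([0.0, 1.5, 0.0, -2.0])
zero_norm = torch.linalg.vector_norm(x, ord=0)
nonzero_count = (x != 0).sum().to(zero_norm.dtype)
assert torch.equal(zero_norm, nonzero_count)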
def embedding_bag(g, embedding_matrix, indices, offsets, scale_grad_by_freq,
                  mode, sparse, per_sample_weights, include_last_offset,
                  padding_idx):
    if scale_grad_by_freq and sym_help._training_mode:
        return sym_help._onnx_unsupported(
            'embedding_bag with scale_grad_by_freq for training mode')
    if padding_idx is not None and padding_idx >= 0:
        raise RuntimeError('embedding_bag with padding_idx')

    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    loop_condition = g.op("Cast", loop_condition, to_i=9)
    zero = g.op("Constant", value_t=torch.tensor([0]))

    indices_len = sym_help._unsqueeze_helper(
        g,
        sym_help._size_helper(g, indices,
                              g.op("Constant", value_t=torch.tensor(0))), [0])
    if not include_last_offset:
        offsets = [offsets, indices_len]
        offsets = g.op("Concat", *offsets, axis_i=0)

    # Offsets holds the starting index position of each bag. So we create a list of the indices slices (determined by
    # offsets) and gather those indices in indices_row. Then we use this subset of indices to gather from embeddings.
    # The embeddings output is a loop scan output, so we can avoid creating a sequence and inserting elements in.
    offsets_starts = sym_help._slice_helper(g,
                                            offsets,
                                            axes=[0],
                                            starts=[0],
                                            ends=[maxsize],
                                            steps=[1])
    offsets_ends = sym_help._slice_helper(g,
                                          offsets,
                                          axes=[0],
                                          starts=[1],
                                          ends=[maxsize],
                                          steps=[1])

    loop_len = sym_help._size_helper(g, offsets_ends,
                                     g.op("Constant", value_t=torch.tensor(0)))
    loop = g.op("Loop", loop_len, loop_condition)

    loop_block = _add_block(loop.node())
    block_input_iter = _add_input_to_block(loop_block)
    cond = _add_input_to_block(loop_block)

    indices_start = loop_block.op("Gather",
                                  offsets_starts,
                                  block_input_iter,
                                  axis_i=0)
    indices_end = loop_block.op("Gather",
                                offsets_ends,
                                block_input_iter,
                                axis_i=0)
    indices_start = sym_help._unsqueeze_helper(loop_block, indices_start, [0])
    indices_end = sym_help._unsqueeze_helper(loop_block, indices_end, [0])

    indices_row = loop_block.op("Slice", indices, indices_start, indices_end,
                                zero)
    embeddings = loop_block.op("Gather",
                               embedding_matrix,
                               indices_row,
                               axis_i=0)
    if not sym_help._is_none(per_sample_weights):
        per_sample_weights_row = loop_block.op("Slice", per_sample_weights,
                                               indices_start, indices_end,
                                               zero)
        per_sample_weights_row = sym_help._unsqueeze_helper(
            loop_block, per_sample_weights_row, [1])
        embeddings = loop_block.op("Mul", embeddings, per_sample_weights_row)
    if mode == 0:
        embeddings = sym_help._reducesum_helper(loop_block,
                                                embeddings,
                                                axes_i=[0],
                                                keepdims_i=0)
    elif mode == 1:
        embeddings = loop_block.op("ReduceMean",
                                   embeddings,
                                   axes_i=[0],
                                   keepdims_i=0)
    else:
        embeddings = loop_block.op("ReduceMax",
                                   embeddings,
                                   axes_i=[0],
                                   keepdims_i=0)

    cond_out = loop_block.op("Cast", loop_condition, to_i=9)
    _add_output_to_block(loop_block, cond_out)
    _add_output_to_block(loop_block, embeddings)

    # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
    # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
    return loop.node().output(), None, None, None
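When include_last_offset is False, the code above appends the length of indices to offsets, which makes the two calling conventions equivalent before building the Loop. A small standalone sketch of that equivalence, with shapes chosen only for illustration:

import torch
import torch.nn.functional as F

weight = torch.randn(6, 2)
indices = torch.tensor([0, 1, 2, 3, 4, 5])
offsets = torch.tensor([0, 2, 4, 6])  # last entry equals len(indices)

with_last = F.embedding_bag(indices, weight, offsets, mode="mean",
                            include_last_offset=True)
without_last = F.embedding_bag(indices, weight, offsets[:-1], mode="mean")
assert torch.allclose(with_last, without_last)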
def einsum(g, equation, tensor_list):
    tensors = sym_help._unpack_list(tensor_list)
    num_ops = len(tensors)
    assert num_ops > 0

    # The decomposition below doesn't support an implicit output or more than 2 operands for now.
    # It also doesn't support ellipsis ('...') for now, as the sizes of the operands are not easy to get.
    if num_ops != 2 or equation.find("->") == -1 or "." in equation:
        return g.op("Einsum", *tensors, equation_s=equation)

    # Take "ks,ksm->sm" as example. After prcoess inputs,
    # lhs_labels = [k,s], rhs_labels = [k,s,m], result_labels = [s,m].
    lhs_labels, rhs_labels, result_labels = parse_equation(equation)

    # Repeated labels within an operand are not supported for now, as they would require taking an extra diagonal.
    if len(lhs_labels) != len(set(lhs_labels)) or len(rhs_labels) != len(
            set(rhs_labels)):
        return g.op("Einsum", *tensors, equation_s=equation)

    # Add contraction labels (labels not present in the output).
    # After processing the contraction labels, contraction_labels = [k],
    # label_perm_map = {s: 0, m: 1, k: 2}, out_size = 2, perm_size = 3.
    out_size = len(result_labels)
    label_perm_map = {label: idx for idx, label in enumerate(result_labels)}
    perm_size = out_size
    contraction_labels = []
    lhs_reduce_sum_axes = []
    rhs_reduce_sum_axes = []
    for label in lhs_labels + rhs_labels:
        if label not in label_perm_map:
            if label in lhs_labels and label in rhs_labels:
                label_perm_map[label] = perm_size
                contraction_labels.append(label)
                perm_size += 1
            elif label in lhs_labels:
                lhs_reduce_sum_axes.append(lhs_labels.index(label))
            else:
                rhs_reduce_sum_axes.append(rhs_labels.index(label))

    lhs_tensor = tensors[0]
    rhs_tensor = tensors[1]

    # If lhs_reduce_sum_axes/rhs_reduce_sum_axes is not empty, ReduceSum over those axes, update lhs_labels/rhs_labels,
    # and use the result as the new lhs_tensor/rhs_tensor.
    if lhs_reduce_sum_axes:
        lhs_tensor = sym_help._reducesum_helper(g,
                                                lhs_tensor,
                                                lhs_reduce_sum_axes,
                                                keepdims_i=False)
        lhs_labels = [
            lhs_labels[axis] for axis in range(len(lhs_labels))
            if axis not in lhs_reduce_sum_axes
        ]

    if rhs_reduce_sum_axes:
        rhs_tensor = sym_help._reducesum_helper(g,
                                                rhs_tensor,
                                                rhs_reduce_sum_axes,
                                                keepdims_i=False)
        rhs_labels = [
            rhs_labels[axis] for axis in range(len(rhs_labels))
            if axis not in rhs_reduce_sum_axes
        ]

    # Need to unsqueeze and permute the inputs to order of output with contraction labels.
    # lhs_perm = [1,2,0], lhs_unsqueeze_axes = [2].
    # rhs_perm = [1,2,0], rhs_unsqueeze_axes = [].
    lhs_perm, lhs_unsqueeze_axes = map_labels_to_output(
        lhs_labels, label_perm_map)
    rhs_perm, rhs_unsqueeze_axes = map_labels_to_output(
        rhs_labels, label_perm_map)

    # If there are no contraction labels, unsqueeze and permute the inputs and Mul them to get the final result.
    if not contraction_labels:
        lhs_tensor = unsqueeze_and_permute_for_mul(g, lhs_tensor,
                                                   lhs_unsqueeze_axes,
                                                   lhs_perm)
        rhs_tensor = unsqueeze_and_permute_for_mul(g, rhs_tensor,
                                                   rhs_unsqueeze_axes,
                                                   rhs_perm)
        return g.op("Mul", lhs_tensor, rhs_tensor)

    # If contraction_labels is not empty, a batched MatMul is needed.
    # Batched labels are those present in all inputs and the output. The axes below are based on the output.
    # batched_labels = [s], batched_axes = [0] for the example.
    # MatMul output labels are those present in exactly one input and the output.
    # matmul_output_labels = [m], matmul_output_axes = [1] for the example.
    # contraction_labels = [k], contraction_axes = [2] for the example.
    batched_axes = []
    matmul_output_axes = []
    contraction_axes = list(range(out_size, perm_size))
    for axis in range(out_size):
        label = result_labels[axis]
        if label in lhs_labels and label in rhs_labels:
            batched_axes.append(axis)
        else:
            matmul_output_axes.append(axis)

    # On top of the unsqueeze and permute applied to the inputs above, one more permute is needed.
    # For the lhs input, the new permute is batched_axes + matmul_output_axes + contraction_axes: [0, 1, 2],
    # i.e., a.unsqueeze([2]).permute([1,2,0]).permute([0,1,2]) = [s,1,k] for the example.
    # For the rhs input, the new permute is batched_axes + contraction_axes + matmul_output_axes: [0, 2, 1],
    # i.e., b.unsqueeze([]).permute([1,2,0]).permute([0,2,1]) = [s,k,m] for the example.
    lhs_perm = combine_unsqueeze_and_permute_for_matmul(
        lhs_unsqueeze_axes, lhs_perm,
        batched_axes + matmul_output_axes + contraction_axes)
    rhs_perm = combine_unsqueeze_and_permute_for_matmul(
        rhs_unsqueeze_axes, rhs_perm,
        batched_axes + contraction_axes + matmul_output_axes)

    # Reshape the two input tensors before the batched MatMul and Reshape the result to the output shape.
    # Reshape the lhs input to [[batched_shapes], Mul(lhs_matmul_output_shapes), Mul(contraction_shapes)].
    # Reshape the rhs input to [[batched_shapes], Mul(contraction_shapes), Mul(rhs_matmul_output_shapes)].
    # Convert all axes to be based on the inputs.
    # lhs_contraction_axes = [0], rhs_contraction_axes = [0], lhs_matmul_output_axes = [], rhs_matmul_output_axes = [2] for the example.
    lhs_contraction_axes = [
        lhs_labels.index(label) for label in contraction_labels
    ]
    rhs_contraction_axes = [
        rhs_labels.index(label) for label in contraction_labels
    ]
    lhs_matmul_output_axes = [
        lhs_labels.index(result_labels[axis]) for axis in matmul_output_axes
        if result_labels[axis] in lhs_labels
    ]
    rhs_matmul_output_axes = [
        rhs_labels.index(result_labels[axis]) for axis in matmul_output_axes
        if result_labels[axis] in rhs_labels
    ]

    # Caches of input shape tensors to avoid generating duplicated graph.
    lhs_shape_tensor = None
    rhs_shape_tensor = None

    # contraction_numel_tensor should be tensor([size(k)]) for the example, but since length is 1, it's None here.
    contraction_numel_tensor = None
    if len(lhs_contraction_axes) > 1:
        _, contraction_numel_tensor, lhs_shape_tensor = get_shape_tensor_by_axes(
            g, lhs_tensor, lhs_shape_tensor, lhs_contraction_axes, True)

    # Prepare some shape tensors for Reshape if needed.
    # Both lhs_matmul_output_shape_tensor and lhs_matmul_output_numel_tensor are None for the example.
    lhs_matmul_output_shape_tensor = None
    lhs_matmul_output_numel_tensor = None
    if len(lhs_matmul_output_axes) > 1:
        lhs_matmul_output_shape_tensor, lhs_matmul_output_numel_tensor, lhs_shape_tensor = get_shape_tensor_by_axes(
            g, lhs_tensor, lhs_shape_tensor, lhs_matmul_output_axes, True)

    # Both rhs_matmul_output_shape_tensor and rhs_matmul_output_numel_tensor are None for the example.
    rhs_matmul_output_shape_tensor = None
    rhs_matmul_output_numel_tensor = None
    if len(rhs_matmul_output_axes) > 1:
        rhs_matmul_output_shape_tensor, rhs_matmul_output_numel_tensor, rhs_shape_tensor = get_shape_tensor_by_axes(
            g, rhs_tensor, rhs_shape_tensor, rhs_matmul_output_axes, True)

    new_lhs_tensor = lhs_tensor
    # Reshape lhs_tensor if lhs_matmul_output_axes or lhs_contraction_axes has a length other than 1; otherwise permute it directly.
    # The lhs_tensor needs a Reshape for the example; the new shape is [size(s), 1, size(k)].
    if len(lhs_matmul_output_axes) != 1 or len(lhs_contraction_axes) != 1:
        new_lhs_tensor, lhs_shape_tensor = permute_and_reshape_tensor(
            g,
            lhs_tensor,
            True,
            len(lhs_labels),
            lhs_perm,
            lhs_matmul_output_axes,
            lhs_contraction_axes,
            len(batched_axes),
            lhs_matmul_output_numel_tensor,
            contraction_numel_tensor,
            lhs_shape_tensor,
        )
    else:
        if need_permute(lhs_perm):
            new_lhs_tensor = g.op("Transpose", lhs_tensor, perm_i=lhs_perm)

    # Reshape rhs_tensor if rhs_matmul_output_axes or rhs_contraction_axes has a length other than 1; otherwise permute it directly.
    # rhs_tensor's new shape should be [size(s), size(k), size(m)], but no Reshape is needed for the example.
    new_rhs_tensor = rhs_tensor
    if len(rhs_matmul_output_axes) != 1 or len(rhs_contraction_axes) != 1:
        new_rhs_tensor, rhs_shape_tensor = permute_and_reshape_tensor(
            g,
            rhs_tensor,
            False,
            len(rhs_labels),
            rhs_perm,
            rhs_matmul_output_axes,
            rhs_contraction_axes,
            len(batched_axes),
            rhs_matmul_output_numel_tensor,
            contraction_numel_tensor,
            rhs_shape_tensor,
        )
    else:
        if need_permute(rhs_perm):
            new_rhs_tensor = g.op("Transpose", rhs_tensor, perm_i=rhs_perm)

    # Perform final BatchedMatMul. Output is shape [size(s), 1, size(m)] for the example.
    result = g.op("MatMul", new_lhs_tensor, new_rhs_tensor)

    # Reshape the result if lhs_matmul_output_axes or rhs_matmul_output_axes has a length other than 1.
    # The result needs a Reshape for the example; the new shape is [size(s), size(m)].
    if len(lhs_matmul_output_axes) != 1 or len(rhs_matmul_output_axes) != 1:
        shape_tensors = [
            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))
        ] * len(batched_axes)
        last_zero_dim = len(shape_tensors) - 1
        has_neg_one_dim = False
        if lhs_matmul_output_axes:
            if len(lhs_matmul_output_axes) == 1:
                shape_tensors.append(
                    g.op("Constant",
                         value_t=torch.tensor([0], dtype=torch.int64)))
                last_zero_dim = len(shape_tensors) - 1
            else:
                shape_tensors.append(lhs_matmul_output_shape_tensor)
        if rhs_matmul_output_axes:
            if len(rhs_matmul_output_axes) == 1:
                shape_tensors.append(
                    g.op("Constant",
                         value_t=torch.tensor([-1], dtype=torch.int64)))
                has_neg_one_dim = True
            else:
                shape_tensors.append(rhs_matmul_output_shape_tensor)
        if not has_neg_one_dim and last_zero_dim >= 0:
            shape_tensors[last_zero_dim] = g.op("Constant",
                                                value_t=torch.tensor(
                                                    [-1], dtype=torch.int64))
        result = reshape_tensor(g, result, shape_tensors)

    # Now the output axes are ordered as [batched_axes, lhs_matmul_output_axes, rhs_matmul_output_axes];
    # if this is not the same as the requested output order, one more permute is needed.
    labels = ([result_labels[axis] for axis in batched_axes] +
              [lhs_labels[axis] for axis in lhs_matmul_output_axes] +
              [rhs_labels[axis] for axis in rhs_matmul_output_axes])
    assert len(labels) == out_size
    output_perm = [labels.index(label) for label in result_labels]
    assert all(axis in output_perm for axis in range(out_size))
    if need_permute(output_perm):
        result = g.op("Transpose", result, perm_i=output_perm)

    return result
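As a standalone check of the decomposition the comments walk through, the "ks,ksm->sm" example can be reproduced with eager PyTorch ops: permute both operands so the batched label s leads, unsqueeze the missing m axis on the lhs, and contract k with a batched matmul. The tensor sizes below are arbitrary:

import torch

k, s, m = 3, 4, 5
a = torch.randn(k, s)
b = torch.randn(k, s, m)

expected = torch.einsum("ks,ksm->sm", a, b)
lhs = a.permute(1, 0).unsqueeze(1)          # [s, 1, k]
rhs = b.permute(1, 0, 2)                    # [s, k, m]
manual = torch.matmul(lhs, rhs).squeeze(1)  # [s, m]
assert torch.allclose(expected, manual, atol=1e-6)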
def diagonal(g, self, offset, dim1, dim2):
    dim1_size = opset9.size(
        g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim1]))
    )
    dim2_size = opset9.size(
        g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim2]))
    )

    # Create appropriate mask
    mask_shape = g.op("Concat", dim1_size, dim2_size, axis_i=0)
    mask = opset9.zeros(g, mask_shape, None, None, None)
    mask = g.op("EyeLike", mask, k_i=offset)

    # Transpose the input so that dim1 and dim2 become the last two dimensions of the shape
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is not None:
        axes = list(range(rank))
        axes.remove(dim1)
        axes.remove(dim2)
        self = g.op("Transpose", self, perm_i=axes + [dim1, dim2])
    else:
        return symbolic_helper._unimplemented("diagonal", "unknown input rank")

    # Multiply the input and the mask to extract the values along the diagonal
    # The mask contains ones at the positions where diagonal values are to be calculated
    # For example:
    # [[1.1, 1.2, 1.3],   *    [[1, 0, 0]   =   [[1.1, 0, 0],
    #  [2.1, 2.2, 2.3],         [0, 1, 0]        [0, 2.2, 0],
    #  [3.1, 3.2, 3.3]]         [0, 0, 1]]       [0, 0, 3.3]]
    result = g.op("Mul", self, mask)
    result = symbolic_helper._reducesum_helper(g, result, axes_i=[-1], keepdims_i=0)

    # Calculate gather indices based on the offset and dims
    # If the offset is greater than or equal to zero, reset it to zero afterwards, as this aids in
    # the calculation of the selection window
    offset_op = g.op("Constant", value_t=torch.LongTensor([offset]))
    if offset >= 0:
        diag_size = g.op(
            "Max",
            g.op("Min", dim1_size, g.op("Sub", dim2_size, offset_op)),
            g.op("Constant", value_t=torch.LongTensor([0])),
        )
        offset = 0
    else:
        diag_size = g.op(
            "Max",
            g.op("Min", g.op("Add", dim1_size, offset_op), dim2_size),
            g.op("Constant", value_t=torch.LongTensor([0])),
        )
    diag_size = g.op("Concat", diag_size, axis_i=0)

    # Calculate which diagonal values to select
    # For example, in cases with offsets:
    # [[0, 1.1, 0]
    #  [0, 0, 2.2]]
    # we need to select the last two columns, so we create a tensor
    # holding the indices of all columns that are to be selected;
    # in this example, it is [1, 2]
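    # The literal dtype arguments map to entries in the exporter's scalar type table:
    # 4 is torch.int64, and 6 (used for the overrun branch below) is torch.float32.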
    select_window_ones_fill = opset9.ones(g, diag_size, 4, None, None)
    select_window = g.op(
        "CumSum",
        select_window_ones_fill,
        g.op("Constant", value_t=torch.LongTensor([0])),
    )
    select_window = g.op(
        "Add",
        select_window,
        g.op("Constant", value_t=torch.LongTensor([abs(offset) - 1])),
    )

    gather_shape = [
        opset9.size(g, result, dim=g.op("Constant", value_t=torch.LongTensor([axis])))
        for axis in list(range(rank))[:-2]
    ]
    gather_shape.append(diag_size)
    gather_shape = g.op("Concat", *gather_shape, axis_i=0)
    gather_indices = opset9.zeros(g, gather_shape, 4, None, None)

    # There might be cases where the offset value is greater than the number of rows/columns,
    # which causes the diagonal to overrun and, as a result, diag_size to be zero.
    # For example, if
    #       offset = 9, dim1_size = 2 (columns), dim2_size = 4 (rows)
    #       diag_size = max(min(2, (4-9)), 0) = 0, based on the calculation above
    # Cases with diagonal overrun always result in diag_size = max(0, negative value) = 0
    # In cases without diagonal overrun, we select the appropriate rows/columns along which we
    # are calculating diagonal values. In cases with diagonal overrun, we return a tensor which has
    # the dimension of the row/column where the overrun occurred as 0-dim, as we are essentially
    # returning an empty tensor
    overrun_cond = g.op(
        "Not",
        g.op(
            "Equal",
            diag_size,
            g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)),
        ),
    )
    if_op = g.op("If", overrun_cond)
    if_node = if_op.node()

    if_block = utils._add_block(if_node)
    gather_indices_if_block = if_block.op("Add", gather_indices, select_window)
    gather_indices_if_block = symbolic_helper._unsqueeze_helper(
        if_block, gather_indices_if_block, [rank - 1]
    )
    final_non_overrun_ = if_block.op(
        "GatherND", result, gather_indices_if_block, batch_dims_i=rank - 2
    )
    utils._add_output_to_block(if_block, final_non_overrun_)

    else_block = utils._add_block(if_node)
    final_overrun_ = opset9.zeros(else_block, gather_shape, 6, None, None)
    utils._add_output_to_block(else_block, final_overrun_)
    return if_op
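The core of this symbolic is the EyeLike mask followed by Mul and a ReduceSum over the last axis; the CumSum-based selection window and GatherND then pick out only the rows that actually carry diagonal values (relevant for negative offsets and for overruns). A minimal standalone sketch of the mask-and-reduce step for a 3x4 input with offset=1:

import torch

x = torch.arange(12.0).reshape(3, 4)
expected = torch.diagonal(x, offset=1)

# Ones on the offset-1 diagonal, analogous to the EyeLike output above.
mask = torch.zeros(3, 4)
mask[0, 1] = mask[1, 2] = mask[2, 3] = 1.0

manual = (x * mask).sum(dim=-1)
assert torch.allclose(expected, manual)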