Exemple #1
0
def flatten(g, input, start_dim, end_dim):
    """Emit the ONNX graph for flattening ``input`` over ``start_dim..end_dim``.

    When the requested range yields a 2-D result, the native ONNX ``Flatten``
    operator is emitted directly. Every other case needs the rank of
    ``input`` at export time and is delegated to
    ``sym_help._flatten_helper``.

    Args:
        g: graph context used to emit operators.
        input: value to flatten.
        start_dim: first dimension of the flattened range.
        end_dim: last dimension of the flattened range (may be negative).

    Returns:
        The flattened value, or the result of ``_unimplemented`` when the
        rank is unknown and the 2-D shortcut does not apply.
    """
    rank = sym_help._get_tensor_rank(input)
    rank_known = rank is not None

    # 2-D shortcut: keep dim 0 and collapse everything after it.
    if start_dim == 1 and (end_dim == -1
                           or (rank_known and end_dim == rank - 1)):
        return g.op("Flatten", input, axis_i=start_dim)

    # 2-D shortcut: collapse every dimension except the last one.
    if start_dim == 0 and (end_dim == -2
                           or (rank_known and end_dim == rank - 2)):
        return g.op("Flatten", input, axis_i=end_dim + 1)

    if not rank_known:
        return _unimplemented("dim",
                              "ONNX and PyTorch use different strategies to split the input. "
                              "Input rank must be known at export time.")

    # Normalize a negative end_dim against the known rank.
    if end_dim < 0:
        end_dim = rank + end_dim

    return sym_help._flatten_helper(g, input, start_dim, end_dim, rank)
Exemple #2
0
def squeeze(g, self, dim=None):
    """Emit the ONNX graph for ``squeeze``.

    Three situations are handled:

    * ``dim`` is ``None``: a plain ``Squeeze`` over all axes.
    * ``dim`` is not a constant: forwarded to ``sym_help._squeeze_helper``.
    * ``dim`` is a constant: if the size of that axis is known the decision
      is made at export time, otherwise an ``If`` node is emitted that
      squeezes only when the runtime size equals 1.

    Args:
        g: graph context used to emit operators.
        self: value to squeeze.
        dim: optional axis to squeeze.

    Returns:
        The squeezed value, ``self`` unchanged when the static axis size is
        greater than 1, or the output of the emitted ``If`` node.
    """
    if dim is None:
        return g.op("Squeeze", self)

    # The axis is itself a graph value; hand it to the helper untouched.
    if not sym_help._is_constant(dim):
        return sym_help._squeeze_helper(g, self, [dim])

    dim = sym_help._get_const(dim, "i", "dim")

    rank = sym_help._get_tensor_rank(self)
    negative_axis = dim < 0
    # A negative axis can only be wrapped when the rank is available.
    if rank is not None and negative_axis:
        wrapped_dim = dim + rank
    else:
        wrapped_dim = dim
    axis_size = sym_help._get_tensor_dim_size(self, wrapped_dim)

    statically_known = (axis_size is not None
                        and not (negative_axis and rank is None))
    if statically_known:
        # For static input shape
        if axis_size > 1:
            warnings.warn(
                "This model contains a squeeze operation on dimension "
                + str(wrapped_dim)
                + ". The size of this dimension in the given input is "
                + str(axis_size)
                + ". The model will be exported without the squeeze node."
                + " If the model is intended to be used with dynamic "
                + "input shapes, please export with dynamic_axes argument.")
            return self
        return sym_help._squeeze_helper(g, self, [wrapped_dim])

    # If onnx shape inference is not on, export always as dynamic.
    # Because we cannot tell if observed static shape is also static at runtime.
    # Build the condition shape[dim] == 1.
    axis_const = g.op("Constant", value_t=torch.tensor([dim]))
    runtime_size = sym_help._size_helper(g, self, axis_const)
    one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))
    is_unit = g.op("Equal", runtime_size, one)

    # "If" node: the first block squeezes, the second passes the input through.
    if_output = g.op("If", is_unit)
    if_node = if_output.node()

    then_block = torch.onnx.utils._add_block(if_node)
    squeezed = sym_help._squeeze_helper(then_block, self, [dim])
    torch.onnx.utils._add_output_to_block(then_block, squeezed)

    else_block = torch.onnx.utils._add_block(if_node)
    passthrough = else_block.op("Identity", self)
    torch.onnx.utils._add_output_to_block(else_block, passthrough)

    return if_output
Exemple #3
0
def _prepare_onnx_paddings(g, input, pad):
    """Convert a PyTorch-style ``pad`` value into the ONNX ``Pad`` layout.

    ONNX expects ``dim_0_begin, dim_1_begin, ..., dim_0_end, ..., dim_n_end``
    where ``n`` is the rank of ``input``. The incoming ``pad`` is ordered
    ``dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ...`` and may cover
    fewer dimensions, so it is zero-extended, regrouped into (begin, end)
    pairs, reversed and transposed.

    Args:
        g: graph context used to emit operators.
        input: value being padded; only its rank is used.
        pad: graph value holding the PyTorch-ordered paddings.

    Returns:
        A graph value with the int64 paddings in ONNX order.
    """
    # Number of entries supplied in `pad`.
    zero_axis = g.op("Constant", value_t=torch.tensor([0]))
    pad_len = torch.onnx.symbolic_opset9.size(g, pad, zero_axis)

    # Rank as a graph value: constant when known, computed otherwise.
    static_rank = sym_help._get_tensor_rank(input)
    if static_rank is not None:
        rank = g.op("Constant",
                    value_t=torch.tensor(static_rank, dtype=torch.int64))
    else:
        rank = g.op("Size", g.op("Shape", input))

    # extension = rank * 2 - len(pad): how many zeros must be appended.
    two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))
    doubled_rank = g.op("Mul", rank, two)
    extension = g.op("Sub", doubled_rank, pad_len)

    # Currently ONNX only supports int64 type for Pad
    pad = g.op("Cast", pad, to_i=sym_help.cast_pytorch_to_onnx["Long"])
    # paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ...]
    zero_fill = g.op("ConstantOfShape",
                     extension,
                     value_t=torch.tensor([0], dtype=torch.int64))
    paddings = g.op("Concat", pad, zero_fill, axis_i=0)

    # Group into (begin, end) rows, reverse the row order, then transpose so
    # that all begins come first and all ends second:
    # [[..., 0, dim_n-1_begin, dim_n_begin],
    #  [..., 0, dim_n-1_end, dim_n_end]]
    pair_shape = g.op("Constant", value_t=torch.tensor([-1, 2]))
    paddings = sym_help._reshape_helper(g, paddings, pair_shape)
    reversed_pairs = torch.onnx.symbolic_opset10.flip(g, paddings, [0])
    paddings = g.op("Transpose", reversed_pairs, perm_i=[1, 0])

    # Back to 1-D: [..., dim_n-1_begin, dim_n_begin, ..., dim_n-1_end, dim_n_end]
    flat_shape = g.op("Constant", value_t=torch.tensor([-1]))
    paddings = sym_help._reshape_helper(g, paddings, flat_shape)

    return g.op("Cast",
                paddings,
                to_i=sym_help.cast_pytorch_to_onnx["Long"])
def constant_pad_nd(g, input, padding, value=None):
    """Emit an ONNX ``Pad`` node in ``"constant"`` mode.

    Args:
        g: graph context used to emit operators.
        input: value to pad.
        padding: PyTorch-ordered paddings, converted by
            ``_prepare_onnx_paddings``.
        value: optional fill value; it is passed through
            ``sym_help._maybe_get_scalar`` and
            ``sym_help._if_scalar_type_as`` relative to ``input``.

    Returns:
        The output of the emitted ``Pad`` node.
    """
    mode = "constant"
    value = sym_help._maybe_get_scalar(value)
    value = sym_help._if_scalar_type_as(g, value, input)
    # _prepare_onnx_paddings takes the tensor itself and derives the rank
    # internally (falling back to Shape/Size when it is unknown), so the
    # value must be passed here, not a precomputed integer rank.
    pad = _prepare_onnx_paddings(g, input, padding)
    return g.op("Pad", input, pad, value, mode_s=mode)
def pixel_shuffle(g, self, upscale_factor):
    """Emit ONNX ``DepthToSpace`` (``CRD`` mode) for ``pixel_shuffle``.

    Args:
        g: graph context used to emit operators.
        self: input value; rejected when its rank is known and is not 4.
        upscale_factor: used as the ``blocksize`` attribute.

    Returns:
        The ``DepthToSpace`` output, or the result of ``_unimplemented``
        for a known non-4D input.
    """
    input_rank = sym_help._get_tensor_rank(self)
    # An unknown rank is let through; only a known non-4D rank is refused.
    supported = input_rank is None or input_rank == 4
    if not supported:
        return _unimplemented("pixel_shuffle", "only support 4d input")
    return g.op("DepthToSpace",
                self,
                blocksize_i=upscale_factor,
                mode_s="CRD")
def repeat_interleave(g, self, repeats, dim=None):
    """Emit the ONNX graph for ``repeat_interleave`` using a ``Loop`` node.

    The loop-based path is taken only when ``repeats`` has rank 0 (or rank 1
    with a single element) AND the size of ``input`` along ``dim`` is not
    statically known. Every other case is delegated to the opset 9
    implementation.

    Args:
        g: graph context used to emit operators.
        self: value whose elements are repeated.
        repeats: number of repetitions (see the condition above).
        dim: axis to repeat along; when it is none, the input is flattened
            first and axis 0 is used.

    Returns:
        The repeated value.

    Raises:
        RuntimeError: if the rank or sizes of ``repeats``, or the sizes of
            the input, cannot be determined at export time.
    """
    from torch.onnx.symbolic_opset9 import reshape
    input = self
    # Keep the caller's original dim for the opset 9 fallback below.
    final_dim = dim
    # if dim is None flatten
    # By default, use the flattened input array, and return a flat output array
    if sym_help._is_none(dim):
        input = reshape(g, self, g.op("Constant", value_t=torch.tensor([-1])))
        dim = 0
    else:
        dim = sym_help._maybe_get_scalar(dim)

    repeats_dim = sym_help._get_tensor_rank(repeats)
    repeats_sizes = sym_help._get_tensor_sizes(repeats)
    input_sizes = sym_help._get_tensor_sizes(input)
    if repeats_dim is None:
        raise RuntimeError(
            'Unsupported: ONNX export of repeat_interleave for unknown '
            'repeats rank.')
    if repeats_sizes is None:
        raise RuntimeError(
            'Unsupported: ONNX export of repeat_interleave for unknown '
            'repeats size.')
    if input_sizes is None:
        raise RuntimeError(
            'Unsupported: ONNX export of repeat_interleave for unknown '
            'input size.')
    # Handle cases where dim is negative
    if dim < 0:
        dim += len(input_sizes)

    output_sizes = input_sizes.copy()
    # perm_i starts as [0, 1, ..., rank]: one entry more than the input rank,
    # because the Loop scan output gains a leading axis (see the comment above
    # the final Transpose).
    perm_i = [0]
    for idx, input_size in enumerate(input_sizes):
        perm_i.append(idx + 1)
        # Unknown sizes are marked 0 in output_sizes and -1 in input_sizes.
        if input_size is None:
            output_sizes[idx], input_sizes[idx] = 0, -1
    # Swap the leading scan axis with position `dim`.
    perm_i[0], perm_i[dim] = perm_i[dim], perm_i[0]

    # Cases when repeats is a single value tensor and dim has unknown input size
    if (repeats_dim == 0 or
        (repeats_dim == 1
         and repeats_sizes[0] == 1)) and output_sizes[dim] == 0:
        if not sym_help._is_tensor(repeats):
            repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
        # reps: runtime size of the input along `dim`, as a 1-element tensor.
        reps = sym_help._size_helper(g, input, dim)
        reps = unsqueeze(g, reps, 0)
        # Broadcast the single repeat count to one entry per slice along `dim`.
        repeats = g.op("Expand", repeats, reps)
    # There are cases when the repeats are 1-d tensor with multiple repeats, but dim
    # provided along one of the dynamic axes provided. A simple example would be
    # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2
    # Now, repeat interleaving can be performed in pytorch when the value of * matches
    # with the number of elements in repeat, for example if * -> 2, number of repeats
    # should be 2 as well.
    else:
        return torch.onnx.symbolic_opset9.repeat_interleave(
            g, self, repeats, final_dim)

    # Split both `repeats` and `input` into size-1 pieces (split sizes are a
    # tensor of ones shaped like `repeats`).
    reps_like = g.op("ConstantOfShape",
                     g.op("Shape", repeats),
                     value_t=torch.tensor([1], dtype=torch.long))
    r_splits = split(g, repeats, reps_like, 0)
    i_splits = split(g, input, reps_like, dim)

    # Each piece has size 1 along `dim`; the repeated size is left to Reshape
    # to infer (-1).
    output_sizes[dim], input_sizes[dim] = -1, 1

    # Create a loop to iterate over each value along the dimension
    # and perform individual interleaving using the repeats tensor
    # Loop is of the following pattern
    # input (trip_count, cond)
    #   int trip_count = ...;
    #   bool cond = ...;
    #   for (int i=0; i < trip_count && cond; ++i) {
    #     cond = ...;
    #   }

    # Loop conditions
    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    # to_i=9 is the ONNX BOOL element type.
    loop_condition = g.op("Cast", loop_condition, to_i=9)
    loop_len = reps
    loop = g.op("Loop", loop_len, loop_condition)

    # Loop inputs
    loop_block = _add_block(loop.node())
    block_input_iter = _add_input_to_block(loop_block)
    # Registered as the block's condition input; the returned value is not
    # referenced again in this function.
    cond = _add_input_to_block(loop_block)

    r_split = loop_block.op("SequenceAt", r_splits, block_input_iter)
    i_split = loop_block.op("SequenceAt", i_splits, block_input_iter)

    # Insert a new axis right after `dim`, expand it to the repeat count, then
    # reshape so the new axis is merged into `dim`.
    i_split = unsqueeze(loop_block, i_split, dim + 1)
    r_concat = [
        loop_block.op("Constant",
                      value_t=torch.LongTensor(input_sizes[:dim + 1])),
        r_split,
        loop_block.op("Constant",
                      value_t=torch.LongTensor(input_sizes[dim + 1:]))
    ]
    r_concat = loop_block.op("Concat", *r_concat, axis_i=0)
    i_split = expand(loop_block, i_split, r_concat, None)
    # NOTE(review): this Constant is emitted on `g`, not on `loop_block` —
    # confirm that is intended.
    i_split = reshape(loop_block, i_split,
                      g.op("Constant", value_t=torch.LongTensor(output_sizes)))

    # Loop outputs
    cond_out = loop_block.op("Cast", loop_condition, to_i=9)
    _add_output_to_block(loop_block, cond_out)
    _add_output_to_block(loop_block, i_split)
    loop_out = loop.node().output()

    # In this loop, the outputs are scan outputs and are concatenated along
    # the zero'th dimension (by default). In order to avoid this and concatenate
    # along the dimension provided, some post-processing is required
    loop_out = g.op("Transpose", loop_out, perm_i=perm_i)
    return reshape(g, loop_out,
                   g.op("Constant", value_t=torch.LongTensor(output_sizes)))
def tensordot(g, input_a, input_b, dims_a, dims_b, out=None):
    """Emit the ONNX graph for ``tensordot`` as a 2-D ``einsum``.

    Both operands are permuted so the contracted dimensions are adjacent
    (trailing for ``input_a``, leading for ``input_b``), reshaped to 2-D,
    multiplied with ``einsum('ij,jk->ik')`` and finally reshaped to the
    concatenation of the non-contracted sizes of both operands.

    Args:
        g: graph context used to emit operators.
        input_a: left operand; its rank must be known.
        input_b: right operand; its rank must be known.
        dims_a: dimensions of ``input_a`` to contract (may be negative).
        dims_b: dimensions of ``input_b`` to contract (may be negative).
        out: not supported; must be ``None``.

    Returns:
        The contracted value, or the result of ``_unimplemented`` when
        ``out`` is given.

    Raises:
        RuntimeError: if the rank of either operand is unknown.
    """
    if out is not None:
        # Return the "unimplemented" result instead of silently exporting a
        # graph that ignores `out`, matching the other symbolics in this file.
        return _unimplemented("Tensordot",
                              "Out parameter is not supported for tensordot.")

    dim_count_a = sym_help._get_tensor_rank(input_a)
    if dim_count_a is None:
        raise RuntimeError(
            'Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.'
        )

    dim_count_b = sym_help._get_tensor_rank(input_b)
    if dim_count_b is None:
        raise RuntimeError(
            'Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.'
        )

    # Normalize negative contraction dims against each operand's rank.
    dims_a = [d + dim_count_a if d < 0 else d for d in dims_a]
    dims_b = [d + dim_count_b if d < 0 else d for d in dims_b]

    # Dimensions that survive the contraction, in their original order.
    left_dims_a = [i for i in range(dim_count_a) if i not in dims_a]
    left_dims_b = [i for i in range(dim_count_b) if i not in dims_b]

    # a: kept dims first, contracted dims last; b: contracted dims first.
    new_input_a = permute(g, input_a, left_dims_a + dims_a)
    new_input_b = permute(g, input_b, dims_b + left_dims_b)

    # Collapse the contracted dims of a into a single trailing axis ...
    input_shape = g.op("Shape", new_input_a)
    left_sizes_a = sym_help._slice_helper(g,
                                          input_shape,
                                          axes=[0],
                                          starts=[0],
                                          ends=[len(left_dims_a)])
    shape_sizes = [
        left_sizes_a,
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long))
    ]
    output_a = _reshape_from_tensor(g, new_input_a, shape_sizes)

    # ... then read that trailing size back and reshape a to [-1, contracted].
    input_shape = g.op("Shape", output_a)
    slices = sym_help._slice_helper(g,
                                    input_shape,
                                    axes=[0],
                                    starts=[-1],
                                    ends=[maxsize])
    shape_sizes = [
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), slices
    ]
    output_a = _reshape_from_tensor(g, new_input_a, shape_sizes)

    # Same two-step collapse for b, keeping the contracted dims leading.
    input_shape = g.op("Shape", new_input_b)
    left_sizes_b = sym_help._slice_helper(g,
                                          input_shape,
                                          axes=[0],
                                          starts=[len(dims_b)],
                                          ends=[maxsize])
    slices = sym_help._slice_helper(g,
                                    input_shape,
                                    axes=[0],
                                    starts=[0],
                                    ends=[len(dims_b)])
    shape_sizes = [
        slices,
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long))
    ]
    output_b = _reshape_from_tensor(g, new_input_b, shape_sizes)

    input_shape = g.op("Shape", output_b)
    slices = sym_help._slice_helper(g,
                                    input_shape,
                                    axes=[0],
                                    starts=[-1],
                                    ends=[maxsize])
    shape_sizes = [
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), slices
    ]
    output_b = _reshape_from_tensor(g, new_input_b, shape_sizes)

    output = einsum(g, 'ij,jk->ik',
                    g.op("prim::ListConstruct", *[output_a, output_b]))

    # Restore the non-contracted sizes of a followed by those of b.
    shape_sizes = [left_sizes_a, left_sizes_b]
    return _reshape_from_tensor(g, output, shape_sizes)
def unfold(g, input, dimension, size, step):
    """Emit the ONNX graph for ``unfold`` when ``size``/``step`` may be dynamic.

    Dispatch order:

    1. If neither ``size`` nor ``step`` is a graph value after constant
       extraction, delegate to the opset 9 implementation.
    2. If the export type is ``ONNX_ATEN_FALLBACK``, emit an ``ATen`` op.
    3. Otherwise build a ``Loop`` that slices one window per iteration; this
       requires the size of ``input`` along ``dimension`` to be known.

    Args:
        g: graph context used to emit operators.
        input: value to unfold.
        dimension: axis along which windows are taken.
        size: window size (constant or graph value).
        step: stride between windows (constant or graph value).

    Returns:
        The unfolded value, or the result of ``_unimplemented`` when the size
        of ``input`` along ``dimension`` is not available.
    """
    const_size = sym_help._maybe_get_const(size, 'i')
    const_step = sym_help._maybe_get_const(step, 'i')
    # Both arguments resolved to constants: the static opset 9 path suffices.
    if not sym_help._is_value(const_size) and not sym_help._is_value(
            const_step):
        from torch.onnx.symbolic_opset9 import unfold as _unfold
        return _unfold(g, input, dimension, const_size, const_step)
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        return g.op("ATen",
                    input,
                    operator_s="unfold",
                    dimension_i=dimension,
                    size_i=size,
                    step_i=step)

    sizedim = sym_help._get_tensor_dim_size(input, dimension)
    if sizedim is not None:
        # low_indices: candidate window starts, Range(0, sizedim, step).
        # hi_indices: candidate window ends, Range(size, sizedim + 1, step).
        low_start = g.op("Constant", value_t=torch.tensor(0))
        low_end = g.op("Constant", value_t=torch.tensor(sizedim))
        hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1))
        low_indices = g.op("Range", low_start, low_end, step)
        hi_indices = g.op("Range", size, hi_end, step)

        low_size = sym_help._size_helper(
            g, low_indices, g.op("Constant", value_t=torch.tensor(0)))
        hi_size = sym_help._size_helper(
            g, hi_indices, g.op("Constant", value_t=torch.tensor(0)))

        # Permutation that moves `dimension` to the last position.
        # NOTE(review): range() below requires ndim to be an int; an unknown
        # rank is not handled here — confirm callers guarantee it.
        ndim = sym_help._get_tensor_rank(input)
        perm = list(range(0, ndim))
        perm.append(perm.pop(dimension))

        unsqueeze_list = []
        loop_condition = g.op("Constant", value_t=torch.tensor(1))
        # to_i=9 is the ONNX BOOL element type.
        loop_condition = g.op("Cast", loop_condition, to_i=9)
        # Iterate once per (start, end) pair available in both ranges.
        loop_len = g.op("Min", low_size, hi_size)
        loop = g.op("Loop", loop_len, loop_condition)

        loop_block = _add_block(loop.node())
        block_input_iter = _add_input_to_block(loop_block)
        # Registered as the block's condition input; the returned value is not
        # referenced again in this function.
        cond = _add_input_to_block(loop_block)

        starts = loop_block.op("Gather", low_indices, block_input_iter)
        ends = loop_block.op("Gather", hi_indices, block_input_iter)
        # NOTE(review): the Slice axis is the constant 2 rather than
        # `dimension` — looks like a fixed-layout assumption; confirm.
        axes = loop_block.op("Constant", value_t=torch.tensor([2]))
        starts = sym_help._unsqueeze_helper(loop_block, starts, [0])
        ends = sym_help._unsqueeze_helper(loop_block, ends, [0])
        stack = loop_block.op("Slice", input, starts, ends, axes)

        unsqueeze = sym_help._unsqueeze_helper(
            loop_block, loop_block.op("Transpose", stack, perm_i=perm),
            [dimension])
        unsqueeze_list.append(unsqueeze)
        concat = loop_block.op("Concat", *unsqueeze_list, axis_i=0)

        cond_out = loop_block.op("Cast", loop_condition, to_i=9)
        _add_output_to_block(loop_block, cond_out)
        _add_output_to_block(loop_block, concat)

        loop_output = loop.node().output()
        # NOTE(review): this permutation has a fixed length of 5, independent
        # of ndim — presumably tied to the same layout assumption as the
        # Slice axis above; confirm.
        perm = [0, 1, 2, 3, 4]
        perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0]
        transpose = g.op("Transpose", loop_output, perm_i=perm)
        squeeze = sym_help._squeeze_helper(g, transpose, [0])

        return squeeze
    else:
        return _unimplemented("Unfold", "input size not accessible")
Exemple #9
0
def repeat_interleave(g, self, repeats, dim=None, output_size=None):
    """Emit the ONNX graph for ``repeat_interleave`` using a sequence ``Loop``.

    The loop-based path is taken when the size of the input along ``dim`` is
    not statically known, or when ``repeats`` is a rank-1 tensor of unknown
    length. Every other case is delegated to the opset 9 implementation.

    Args:
        g: graph context used to emit operators.
        self: value whose elements are repeated.
        repeats: number of repetitions, scalar or 1-D.
        dim: axis to repeat along; when it is none, the input is flattened
            first and axis 0 is used.
        output_size: accepted for signature compatibility; not used here.

    Returns:
        The repeated value.

    Raises:
        RuntimeError: if the rank or sizes of ``repeats``, or the sizes of
            the input, cannot be determined at export time.
    """
    input = self
    # Keep the caller's original dim for the opset 9 fallback below.
    final_dim = dim
    # if dim is None flatten
    # By default, use the flattened input array, and return a flat output array
    if sym_help._is_none(dim):
        input = sym_help._reshape_helper(
            g, self, g.op("Constant", value_t=torch.tensor([-1])))
        dim = 0
    else:
        dim = sym_help._maybe_get_scalar(dim)

    repeats_dim = sym_help._get_tensor_rank(repeats)
    repeats_sizes = sym_help._get_tensor_sizes(repeats)
    input_sizes = sym_help._get_tensor_sizes(input)
    if repeats_dim is None:
        raise RuntimeError(
            "Unsupported: ONNX export of repeat_interleave for unknown "
            "repeats rank.")
    if repeats_sizes is None:
        raise RuntimeError(
            "Unsupported: ONNX export of repeat_interleave for unknown "
            "repeats size.")
    if input_sizes is None:
        raise RuntimeError(
            "Unsupported: ONNX export of repeat_interleave for unknown "
            "input size.")
    # Handle cases where dim is negative
    if dim < 0:
        dim += len(input_sizes)

    # Unknown sizes are marked 0 in output_sizes and -1 in input_sizes.
    output_sizes = input_sizes.copy()
    for idx, input_size in enumerate(input_sizes):
        if input_size is None:
            output_sizes[idx], input_sizes[idx] = 0, -1

    cond_dynamic_repeats = (repeats_dim == 1 and repeats_sizes[0] is None)
    # If input size is dynamic or repeats vector is dynamic
    if output_sizes[dim] == 0 or cond_dynamic_repeats:
        # reps: runtime size of the input along `dim`, as a 1-element tensor.
        reps = sym_help._size_helper(g, input, dim)
        reps = unsqueeze(g, reps, 0)
        # Check if repeats vector is a single integer value
        # or a single dimension tensor with non-dynamic values
        if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
            if not sym_help._is_tensor(repeats):
                repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
            repeats = g.op("Expand", repeats, reps)
        # Check if repeats is dynamic
        # As repeats is dynamic, we use a where node as a substitute for the if statement
        # If repeats has a single element, expand it; otherwise use the original tensor
        elif cond_dynamic_repeats:
            repeat_dim = sym_help._size_helper(
                g, repeats, g.op("Constant", value_t=torch.LongTensor([0])))
            repeat_cond = g.op("Equal", repeat_dim,
                               g.op("Constant", value_t=torch.LongTensor([1])))
            repeats = where(g, repeat_cond, g.op("Expand", repeats, reps),
                            repeats)
    # There are cases when the repeats are 1-d tensor with multiple repeats, but dim
    # provided along one of the dynamic axes provided. A simple example would be
    # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2
    # Now, repeat interleaving can be performed in pytorch when the value of * matches
    # with the number of elements in repeat, for example if * -> 2, number of repeats
    # should be 2 as well.
    else:
        return torch.onnx.symbolic_opset9.repeat_interleave(
            g, self, repeats, final_dim)

    # Split both `repeats` and `input` into size-1 pieces (split sizes are a
    # tensor of ones shaped like `repeats`).
    reps_like = g.op("ConstantOfShape",
                     g.op("Shape", repeats),
                     value_t=torch.tensor([1], dtype=torch.long))
    r_splits = split(g, repeats, reps_like, 0)
    i_splits = split(g, input, reps_like, dim)

    # Each piece has size 1 along `dim`; the repeated size is left to Reshape
    # to infer (-1).
    output_sizes[dim], input_sizes[dim] = -1, 1

    # Create a loop to iterate over each value along the dimension
    # and perform individual interleaving using the repeats tensor
    # Loop is of the following pattern
    # input (trip_count, cond)
    #   int trip_count = ...;
    #   bool cond = ...;
    #   for (int i=0; i < trip_count && cond; ++i) {
    #     cond = ...;
    #   }

    # Loop conditions
    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    # to_i=9 is the ONNX BOOL element type.
    loop_condition = g.op("Cast", loop_condition, to_i=9)
    loop_len = reps

    # Create an empty sequence to store final expansions
    final_splits = g.op("SequenceEmpty")
    loop = g.op("Loop", loop_len, loop_condition, final_splits)

    # Loop inputs
    loop_block = _add_block(loop.node())
    block_input_iter = _add_input_to_block(loop_block)
    # Registered as the block's condition input; the returned value is not
    # referenced again, but the call itself must stay.
    cond = _add_input_to_block(loop_block)
    final_splits = _add_input_to_block(loop_block)

    r_split = loop_block.op("SequenceAt", r_splits, block_input_iter)
    i_split = loop_block.op("SequenceAt", i_splits, block_input_iter)

    # Insert a new axis right after `dim`, expand it to the repeat count, then
    # reshape so the new axis is merged into `dim`.
    i_split = unsqueeze(loop_block, i_split, dim + 1)
    r_concat = [
        loop_block.op("Constant",
                      value_t=torch.LongTensor(input_sizes[:dim + 1])),
        r_split,
        loop_block.op("Constant",
                      value_t=torch.LongTensor(input_sizes[dim + 1:]))
    ]
    r_concat = loop_block.op("Concat", *r_concat, axis_i=0)
    i_split = expand(loop_block, i_split, r_concat, None)
    i_split = sym_help._reshape_helper(
        loop_block, i_split,
        g.op("Constant", value_t=torch.LongTensor(output_sizes)))
    final_splits = loop_block.op("SequenceInsert", final_splits, i_split)

    # Loop outputs
    cond_out = loop_block.op("Cast", loop_condition, to_i=9)
    _add_output_to_block(loop_block, cond_out)
    _add_output_to_block(loop_block, final_splits)

    # Join the per-slice results along the requested axis.
    loop_out = loop.node().output()
    loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim)
    return loop_out
Exemple #10
0
def pixel_unshuffle(g, self, downscale_factor):
    """Emit ONNX ``SpaceToDepth`` for ``pixel_unshuffle``.

    Args:
        g: graph context used to emit operators.
        self: input value; rejected when its rank is known and is not 4.
        downscale_factor: used as the ``blocksize`` attribute.

    Returns:
        The ``SpaceToDepth`` output, or the result of ``_unimplemented``
        for a known non-4D input.
    """
    input_rank = sym_help._get_tensor_rank(self)
    # An unknown rank is let through; only a known non-4D rank is refused.
    supported = input_rank is None or input_rank == 4
    if not supported:
        return _unimplemented("pixel_unshuffle", "only support 4d input")
    return g.op("SpaceToDepth", self, blocksize_i=downscale_factor)
Exemple #11
0
def tensor_split(g, self, indices_or_sections, dim, _outputs=None):
    """Emit the ONNX graph for ``tensor_split``.

    Three strategies are used, depending on what is known at export time:

    * Static split (``_is_split_static``): either explicit ``Slice`` nodes
      between consecutive indices, or a single ``Split`` with computed sizes
      when ``indices_or_sections`` is a section count.
    * Rank-1 tensor of indices: a ``Loop`` that slices between consecutive
      indices and collects the pieces in a sequence.
    * Otherwise (treated as a scalar section count): split sizes are computed
      in-graph with ``Div``/``Mod``/``Tile``.

    Args:
        g: graph context used to emit operators.
        self: value to split.
        indices_or_sections: split indices or number of sections.
        dim: axis to split along.
        _outputs: number of outputs, when known.

    Returns:
        A list of slices, a ``Split``/``SplitToSequence`` output, or a
        sequence value, depending on the branch taken.

    Raises:
        errors.SymbolicValueError: in the static section-count branch, if the
            size along ``dim`` is unknown and ``_outputs`` is ``None``.
    """
    # `axis` is the split axis as a 1-element tensor, for use as Slice `axes`.
    axis = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
    axis = opset11.unsqueeze(g, axis, 0)
    const_1 = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long))

    if symbolic_helper._is_split_static(indices_or_sections, _outputs):
        split_val = symbolic_helper._node_get(indices_or_sections.node(), "value")

        # Non-scalar constant: it is a list of split indices.
        if split_val.dim() > 0:
            start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
            res = []
            assert _outputs is not None
            # One slice per index; each slice ends where the next one starts.
            for i in range(_outputs - 1):
                end = g.op(
                    "Gather",
                    indices_or_sections,
                    g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)),
                    axis_i=0,
                )
                res.append(g.op("Slice", self, start, end, axis))
                start = end

            # Final slice runs from the last index to the end of the axis.
            end = symbolic_helper._size_helper(g, self, axis)
            res.append(g.op("Slice", self, start, end, axis))
            return res

        # Scalar constant: it is the number of sections.
        split_size = symbolic_helper._get_const(
            indices_or_sections, "i", "indices_or_sections"
        )

        size = symbolic_helper._get_tensor_dim_size(self, dim)
        if size is None:
            # Unknown axis size: substitute split_size * _outputs when the
            # output count is available.
            if _outputs is not None:
                size = split_size * _outputs
            else:
                raise errors.SymbolicValueError(
                    "Unknown dimension size not supported", self
                )

        # The first (size % split_size) sections get one extra element.
        min_split_size = size // split_size
        num_splits_one_extra = size % split_size

        splits = num_splits_one_extra * [min_split_size + 1]
        leftover = (split_size - num_splits_one_extra) * [min_split_size]

        splits = g.op(
            "Constant", value_t=torch.tensor(splits + leftover, dtype=torch.long)
        )
        return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)

    if (
        symbolic_helper._is_tensor(indices_or_sections)
        and symbolic_helper._get_tensor_rank(indices_or_sections) == 1
    ):
        # One loop iteration per supplied index.
        loop_len = symbolic_helper._size_helper(
            g, indices_or_sections, g.op("Constant", value_t=torch.tensor(0))
        )
        loop_len = opset11.unsqueeze(g, loop_len, 0)
        loop_condition = g.op("Cast", const_1, to_i=_C_onnx.TensorProtoDataType.BOOL)

        # To make the first slice in the below loop work,
        # we pad a zero to the first position so that it will be the initial start of slice.
        padding_0 = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
        indices_or_sections = g.op("Concat", padding_0, indices_or_sections, axis_i=0)

        final_splits = g.op("SequenceEmpty")
        loop = g.op("Loop", loop_len, loop_condition, final_splits)

        # Loop inputs
        loop_block = utils._add_block(loop.node())
        block_input_iter = utils._add_input_to_block(loop_block)
        # Registered as the block's condition input; the returned value is not
        # referenced again in this function.
        cond = utils._add_input_to_block(loop_block)
        final_splits = utils._add_input_to_block(loop_block)

        # Iteration i slices [indices[i], indices[i + 1]) of the padded indices.
        start = loop_block.op("Gather", indices_or_sections, block_input_iter, axis_i=0)
        end = loop_block.op(
            "Gather",
            indices_or_sections,
            loop_block.op("Add", block_input_iter, const_1),
            axis_i=0,
        )

        slice = loop_block.op("Slice", self, start, end, axis)
        final_splits = loop_block.op("SequenceInsert", final_splits, slice)

        # Loop outputs
        cond_out = loop_block.op("Identity", loop_condition)
        utils._add_output_to_block(loop_block, cond_out)
        utils._add_output_to_block(loop_block, final_splits)

        # Append the trailing slice from the last index to the end of the axis.
        loop_out = loop.node().output()
        start = g.op(
            "Gather",
            indices_or_sections,
            g.op("Constant", value_t=torch.tensor(-1, dtype=torch.long)),
            axis_i=0,
        )
        start = opset11.unsqueeze(g, start, 0)
        end = symbolic_helper._size_helper(g, self, axis)

        last_slice = g.op("Slice", self, start, end, axis)

        return g.op("SequenceInsert", loop_out, last_slice)

    else:  # scalar tensor
        # Same size computation as the static section-count branch, but done
        # with graph ops because the section count is not a constant.
        dim_size = symbolic_helper._size_helper(g, self, axis)
        min_split_size = g.op("Div", dim_size, indices_or_sections)
        min_split_size_plus_1 = g.op(
            "Add",
            min_split_size,
            const_1,
        )
        num_splits_one_extra = g.op("Mod", dim_size, indices_or_sections)
        splits = g.op("Tile", min_split_size_plus_1, num_splits_one_extra)
        leftover = g.op(
            "Tile",
            min_split_size,
            g.op(
                "Sub",
                opset11.unsqueeze(g, indices_or_sections, 0),
                num_splits_one_extra,
            ),
        )

        splits = g.op("Concat", splits, leftover, axis_i=0)
        # Without a known output count, return a sequence instead of a tuple.
        if _outputs is None:
            return g.op("SplitToSequence", self, splits, axis_i=dim)
        return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
def index_put(g, self, indices_list_value, values, accumulate=False):
    """Emit the ONNX graph for ``index_put`` using ``ScatterND``.

    Under ``ONNX_ATEN_FALLBACK`` an ``ATen`` op is emitted instead. With a
    single index whose producing node takes a Bool input first, the call is
    rewritten to ``masked_fill`` (rank-0 ``values``) or ``masked_scatter``.

    Args:
        g: graph context used to emit operators.
        self: value being written into.
        indices_list_value: packed list of index tensors.
        values: values to write.
        accumulate: when true, values are added to ``self`` instead of
            replacing the addressed elements.

    Returns:
        The updated value.
    """
    indices_list = sym_help._unpack_list(indices_list_value)
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    index = indices_list[0]

    if len(indices_list) > 1:
        # Summing the indices is used only to obtain their broadcast shape.
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        # Expand each index to the broadcast shape and add a trailing axis so
        # they can be concatenated into ScatterND coordinate tuples.
        indices_list = [
            sym_help._unsqueeze_helper(
                g, expand(g, ind, broadcast_index_shape, None), [-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace index_put node with masked_scatter or masked_fill
        # when inputs to the index_put node contains boolean inputs
        #
        # index_put -> masked_fill
        #
        # before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %6 : None = prim::Constant()
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %8 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::ne(%mask, %some_const)
        #   %26 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]()
        #   %27 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %11 : Device = prim::Constant[value="cpu"]()
        #   %12 : None = prim::Constant()
        #   %28 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %29 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %15 : None = prim::Constant()
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %30 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %22 : int[] = prim::Constant[value=[-1]]()
        #   %23 : Tensor = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %3 : Tensor = onnx::Equal(%0, %some_const)
        #   %4 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%3)
        #   %12 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%4)
        #   %19 : Tensor = onnx::Cast[to=9](%12)
        #   %20 : Tensor = onnx::Constant[value={1}]()
        #   %21 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = onnx::Where(%19, %20, %0)
        #   return (%21)
        #
        # index_put -> masked_scatter
        #
        # before graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=1, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %6 : None = prim::Constant()
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %34 : Long(requires_grad=0, device=cpu) = prim::Constant[value={11}]()
        #   %35 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %18 : Device = prim::Constant[value="cpu"]()
        #   %19 : None = prim::Constant()
        #   %36 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %37 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %22 : None = prim::Constant()
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Tensor = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        #
        # after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu),
        #       %some_const : Float(requires_grad=0, device=cpu)):
        #   %3 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #               = onnx::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %4 : Tensor = onnx::Equal(%0, %some_const)
        #   %5 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%4)
        #   %13 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%5)
        #   %19 : Tensor = onnx::Shape(%0)
        #   %20 : Tensor = onnx::Expand(%13, %19)
        #   %21 : Tensor = onnx::NonZero(%20)
        #   %22 : Tensor = onnx::Transpose[perm=[1, 0]](%21)
        #   %23 : Tensor = onnx::Constant[value={-1}]()
        #   %24 : Tensor = onnx::Reshape(%3, %23)
        #   %25 : Tensor = onnx::Shape(%22)
        #   %27 : Tensor = onnx::Constant[value={0}]()
        #   %28 : Tensor = onnx::Gather[axis=0](%25, %27)
        #   %29 : Tensor = onnx::Constant[value={0}]()
        #   %30 : Tensor = onnx::Unsqueeze[axes=[0]](%29)
        #   %31 : Tensor = onnx::Unsqueeze[axes=[0]](%28)
        #   %32 : Tensor = onnx::Constant[value={0}]()
        #   %33 : Tensor = onnx::Unsqueeze[axes=[0]](%32)
        #   %34 : Tensor = onnx::Slice(%24, %30, %31, %33)
        #   %35 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = onnx::ScatterND(%0, %22, %34)
        #   return (%35)

        # Inspect the first input of the node that produced the index.
        bool_inp = list(index.node().inputs())[0]
        if bool_inp.type() is not None and bool_inp.type().scalarType(
        ) == 'Bool':
            rank = sym_help._get_tensor_rank(values)
            # Rank-0 values fill every masked position with the same value.
            if rank is not None and rank == 0:
                from torch.onnx.symbolic_opset9 import masked_fill
                return masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = sym_help._unsqueeze_helper(g, index, [-1])
    # Shape of self past the indexed leading dimensions.
    sub_data_shape = sym_help._slice_helper(g,
                                            g.op("Shape", self),
                                            axes=[0],
                                            starts=[len(indices_list)],
                                            ends=[maxsize])
    # Reshape values to (broadcast index shape) + (remaining data shape).
    values_shape = g.op("Concat",
                        broadcast_index_shape,
                        sub_data_shape,
                        axis_i=0)
    values = g.op("Reshape", values, values_shape)

    if accumulate:
        # Scatter into a zero tensor of self's shape and dtype, then add the
        # result to self.
        dtype = self.type().scalarType()
        dtype = sym_help.scalar_type_to_onnx.index(
            sym_help.cast_pytorch_to_onnx[dtype])
        dtype = sym_help.scalar_type_to_pytorch_type[dtype]
        zeros = g.op("ConstantOfShape",
                     g.op("Shape", self),
                     value_t=torch.tensor([0], dtype=dtype))
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
Exemple #13
0
def unfold(g, input, dimension, size, step):
    """Symbolic for ``unfold`` that also handles non-constant ``size``/``step``.

    If both ``size`` and ``step`` resolve to constants the export is delegated
    to ``opset9.unfold``; under the Caffe2 ATen fallback an ``unfold`` ATen op
    is emitted. Otherwise the sliding windows are produced with an ONNX
    ``Loop`` whose body slices one window per iteration.

    The slice axis and the final transpose are derived from ``dimension`` and
    the input rank, so the dynamic path is not limited to a 3-D input unfolded
    along axis 2; a negative ``dimension`` is normalized against the rank.

    Args:
        g: Graph the nodes are added to.
        input: Value to unfold.
        dimension: Python int selecting the axis to slide along.
        size: Window length; constant or graph value.
        step: Distance between window starts; constant or graph value.

    Returns:
        The graph value holding the unfolded tensor, or the result of
        ``symbolic_helper._unimplemented`` when the size of ``input`` along
        ``dimension`` (or its rank) is not available at export time.
    """
    const_size = symbolic_helper._maybe_get_const(size, "i")
    const_step = symbolic_helper._maybe_get_const(step, "i")
    if not symbolic_helper._is_value(
            const_size) and not symbolic_helper._is_value(const_step):
        # Fully static case: the opset 9 implementation is sufficient.
        return opset9.unfold(g, input, dimension, const_size, const_step)
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("unfold",
                    input,
                    dimension_i=dimension,
                    size_i=size,
                    step_i=step)

    sizedim = symbolic_helper._get_tensor_dim_size(input, dimension)
    ndim = symbolic_helper._get_tensor_rank(input)
    if sizedim is None or ndim is None:
        return symbolic_helper._unimplemented("Unfold",
                                              "input size not accessible")

    # The axis arithmetic below (unsqueeze position, dimension + 1) needs a
    # non-negative axis.
    if dimension < 0:
        dimension += ndim

    # Window start offsets are Range(0, sizedim, step) and the matching end
    # offsets are Range(size, sizedim + 1, step).
    low_start = g.op("Constant", value_t=torch.tensor(0))
    low_end = g.op("Constant", value_t=torch.tensor(sizedim))
    hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1))
    low_indices = g.op("Range", low_start, low_end, step)
    hi_indices = g.op("Range", size, hi_end, step)

    low_size = symbolic_helper._size_helper(
        g, low_indices, g.op("Constant", value_t=torch.tensor(0)))
    hi_size = symbolic_helper._size_helper(
        g, hi_indices, g.op("Constant", value_t=torch.tensor(0)))

    # Permutation that moves the unfolded axis to the last position.
    window_last_perm = list(range(0, ndim))
    window_last_perm.append(window_last_perm.pop(dimension))

    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    # 9 is the ONNX TensorProto data type code for BOOL.
    loop_condition = g.op("Cast", loop_condition, to_i=9)
    # Only windows that have both a start and an end offset are produced.
    loop_len = g.op("Min", low_size, hi_size)
    loop = g.op("Loop", loop_len, loop_condition)

    loop_block = utils._add_block(loop.node())
    block_input_iter = utils._add_input_to_block(loop_block)
    # Second block input (the loop condition); it must be declared on the
    # block but its value is not read inside the body.
    utils._add_input_to_block(loop_block)

    starts = loop_block.op("Gather", low_indices, block_input_iter)
    ends = loop_block.op("Gather", hi_indices, block_input_iter)
    # Slice along the requested axis rather than a fixed axis 2.
    axes = loop_block.op("Constant", value_t=torch.tensor([dimension]))
    starts = symbolic_helper._unsqueeze_helper(loop_block, starts, [0])
    ends = symbolic_helper._unsqueeze_helper(loop_block, ends, [0])
    stack = loop_block.op("Slice", input, starts, ends, axes)

    unsqueeze = symbolic_helper._unsqueeze_helper(
        loop_block, loop_block.op("Transpose", stack, perm_i=window_last_perm),
        [dimension])
    concat = loop_block.op("Concat", unsqueeze, axis_i=0)

    cond_out = loop_block.op("Cast", loop_condition, to_i=9)
    utils._add_output_to_block(loop_block, cond_out)
    utils._add_output_to_block(loop_block, concat)

    loop_output = loop.node().output()
    # Each iteration yields a rank ndim + 1 tensor with a unit axis at
    # ``dimension``; the Loop stacks them on a new leading axis, giving rank
    # ndim + 2 with the unit axis at ``dimension + 1``. Swapping that axis
    # with axis 0 puts the window count at ``dimension`` once the leading
    # unit axis is squeezed away.
    perm = list(range(ndim + 2))
    perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0]
    transpose = g.op("Transpose", loop_output, perm_i=perm)
    return symbolic_helper._squeeze_helper(g, transpose, [0])
Exemple #14
0
def __interpolate(g, input, size, scale_factor, mode, align_corners, recompute_scale_factor):
    """Emit an ONNX ``Resize`` node for an interpolate call.

    When ``size`` is given, the first two entries of the input shape are kept
    and ``size`` supplies the remaining ones; otherwise the scales obtained
    from ``scale_factor`` drive the resize. ``recompute_scale_factor`` is
    accepted but not read here.

    Returns the ``Resize`` value, or the result of ``sym_help._unimplemented``
    when the input rank is required but unknown.
    """
    mode = sym_help._maybe_get_const(mode, 's')
    # Map any mode name containing 'linear' / 'cubic' onto the plain name.
    for plain_mode in ('linear', 'cubic'):
        if plain_mode in mode:
            mode = plain_mode

    align_corners = sym_help._maybe_get_const(align_corners, 'b')
    if not isinstance(align_corners, bool):
        align_corners = False

    if mode == "nearest":
        transformation_mode = "asymmetric"
    elif align_corners:
        transformation_mode = "align_corners"
    else:
        transformation_mode = "pytorch_half_pixel"

    # roi only takes effect with coordinate_transformation_mode="tf_crop_and_resize"
    roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))

    # Attributes shared by both Resize variants.
    resize_attrs = dict(
        coordinate_transformation_mode_s=transformation_mode,
        cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
        mode_s=mode,  # nearest, linear, or cubic
        nearest_mode_s="floor",  # only valid when mode="nearest"
    )

    if sym_help._is_none(size):
        # No explicit output size: resize by scale factors.
        rank = sym_help._get_tensor_rank(input)
        if rank is None:
            return sym_help._unimplemented("interpolate (with scales)", "missing input shape")
        scales = sym_help._interpolate_get_scales(g, scale_factor, rank)
        return g.op("Resize", input, roi, scales, **resize_attrs)

    leading_dims = sym_help._slice_helper(g, g.op("Shape", input), axes=[0], ends=[2], starts=[0])

    # In some cases size is not a packed list but a scalar. Whether it is a
    # 0-dim constant cannot always be checked; when the check itself fails,
    # fall back to the packed-list test alone.
    not_packed = not sym_help._is_packed_list(size)
    try:
        is_scalar = not_packed and (sym_help._maybe_get_const(size, 't').dim() == 0)
    except AttributeError:
        is_scalar = not_packed
        if not is_scalar:
            warnings.warn("Cannot verify if the output_size is a scalar "
                          "while exporting interpolate. Assuming that it is not a scalar.")

    if is_scalar:
        rank = sym_help._get_tensor_rank(input)
        if rank is None:
            return sym_help._unimplemented("interpolate (with a scalar output_size)",
                                           "missing input shape (try giving an array of output_size values)")
        # Repeat the scalar once per dimension after the first two.
        size = unsqueeze(g, size, 0)
        size = g.op("Concat", *([size] * (rank - 2)), axis_i=0)

    size = g.op("Cast", size, to_i=sym_help.cast_pytorch_to_onnx['Long'])
    size = g.op("Concat", leading_dims, size, axis_i=0)
    # Empty scales: the explicit sizes input is what Resize receives.
    empty_scales = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
    return g.op("Resize", input, roi, empty_scales, size, **resize_attrs)
def replication_pad(g, input, padding):
    """Emit an ONNX ``Pad`` node in ``edge`` mode for ``input``.

    The pads input is built by ``_prepare_onnx_paddings`` from the rank of
    ``input`` and ``padding``.
    """
    input_rank = sym_help._get_tensor_rank(input)
    onnx_pads = _prepare_onnx_paddings(g, input_rank, padding)
    return g.op("Pad", input, onnx_pads, mode_s="edge")
Exemple #16
0
def diagonal(g, self, offset, dim1, dim2):
    """Export ``diagonal`` using an ``EyeLike`` mask followed by ``GatherND``.

    The input is transposed so that ``dim1`` and ``dim2`` are the last two
    axes, multiplied by a mask with ones on the requested diagonal, reduced
    over the last axis, and the surviving entries are then gathered. An ``If``
    node returns an empty result when the diagonal lies outside the matrix.

    Negative ``dim1`` / ``dim2`` are converted to their non-negative
    equivalents once the input rank is known; without this the axis list
    manipulation below fails for inputs such as ``dim1=-2, dim2=-1``.

    Args:
        g: Graph the nodes are added to.
        self: Input value.
        offset: Python int, diagonal offset passed to ``EyeLike`` as ``k``.
        dim1: Python int, first axis of the 2-D planes.
        dim2: Python int, second axis of the 2-D planes.

    Returns:
        The output of the ``If`` node, or the result of
        ``symbolic_helper._unimplemented`` when the input rank is unknown.
    """
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is not None:
        # Replace negative indexing so dims can be matched against range(rank).
        if dim1 < 0:
            dim1 += rank
        if dim2 < 0:
            dim2 += rank

    dim1_size = opset9.size(
        g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim1]))
    )
    dim2_size = opset9.size(
        g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim2]))
    )

    # Create appropriate mask
    mask_shape = g.op("Concat", dim1_size, dim2_size, axis_i=0)
    mask = opset9.zeros(g, mask_shape, None, None, None)
    mask = g.op("EyeLike", mask, k_i=offset)

    if rank is None:
        return symbolic_helper._unimplemented("diagonal", "unknown input rank")

    # dim1 and dim2 appended as a dimension at the end of the shape
    axes = list(range(rank))
    axes.remove(dim1)
    axes.remove(dim2)
    self = g.op("Transpose", self, perm_i=axes + [dim1, dim2])

    # Multiply input and mask to calculate values along diagonal
    # The mask consists of one values where diagonal values are to be calculated
    # For example:
    # [[1.1, 1.2, 1.3],   *    [[1, 0, 0]   =   [[1.1, 0, 0],
    #  [2.1, 2.2, 2.3],         [0, 1, 0]        [0, 2.2, 0],
    #  [3.1, 3.2, 3.3]]         [0, 0, 1]]       [0, 0, 3.3]]
    result = g.op("Mul", self, mask)
    result = symbolic_helper._reducesum_helper(g, result, axes_i=[-1], keepdims_i=0)

    # Calculate gather indices based on offset and dims
    # If offset is greater than zero, set offset to zero as this aids in
    # calculation of selection window
    offset_op = g.op("Constant", value_t=torch.LongTensor([offset]))
    if offset >= 0:
        diag_size = g.op(
            "Max",
            g.op("Min", dim1_size, g.op("Sub", dim2_size, offset_op)),
            g.op("Constant", value_t=torch.LongTensor([0])),
        )
        offset = 0
    else:
        diag_size = g.op(
            "Max",
            g.op("Min", g.op("Add", dim1_size, offset_op), dim2_size),
            g.op("Constant", value_t=torch.LongTensor([0])),
        )
    diag_size = g.op("Concat", diag_size, axis_i=0)

    # Calculate which diagonal values to select
    # For example, in cases with offsets:
    # [[0, 1.1, 0]
    #  [0, 0, 2.2]]
    # we need to select the last two columns, so we create a tensor
    # with all columns that are to be selected
    # So in this example, it is [1, 2]
    # NOTE(review): the literal 4 is passed as the dtype argument; presumably
    # the Long scalar type code - confirm against opset9.ones.
    select_window_ones_fill = opset9.ones(g, diag_size, 4, None, None)
    select_window = g.op(
        "CumSum",
        select_window_ones_fill,
        g.op("Constant", value_t=torch.LongTensor([0])),
    )
    # CumSum of ones starts at 1, so shift by abs(offset) - 1 to obtain
    # zero-based positions starting at abs(offset).
    select_window = g.op(
        "Add",
        select_window,
        g.op("Constant", value_t=torch.LongTensor([abs(offset) - 1])),
    )

    gather_shape = [
        opset9.size(g, result, dim=g.op("Constant", value_t=torch.LongTensor([axis])))
        for axis in list(range(rank))[:-2]
    ]
    gather_shape.append(diag_size)
    gather_shape = g.op("Concat", *gather_shape, axis_i=0)
    gather_indices = opset9.zeros(g, gather_shape, 4, None, None)

    # There might be cases where offset value is greater than number of rows/columns
    # and might cause the diagonal to overrun and as a result of this, diag_size would be zero.
    # For example, if
    #       offset = 9, dim1_size = 2 (columns), dim2_size = 4 (rows)
    #       diag_size = max(min(2, (4-9)), 0) = 0, based on calculation above
    # Cases with diagonal overrun always result in diag_size = max(0, -ve value) = 0
    # In cases without diagonal overrun, we select the appropriate rows/columns along which we
    # are calculating diagonal values. In cases with diagonal overrun, we return a tensor which has
    # the dimension of the row/column where overrun occurred as 0-dim, as we are essentially
    # returning an empty tensor
    overrun_cond = g.op(
        "Not",
        g.op(
            "Equal",
            diag_size,
            g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)),
        ),
    )
    if_op = g.op("If", overrun_cond)
    if_node = if_op.node()

    # "then" block: diagonal is non-empty, gather the selected entries.
    if_block = utils._add_block(if_node)
    gather_indices_if_block = if_block.op("Add", gather_indices, select_window)
    gather_indices_if_block = symbolic_helper._unsqueeze_helper(
        if_block, gather_indices_if_block, [rank - 1]
    )
    final_non_overrun_ = if_block.op(
        "GatherND", result, gather_indices_if_block, batch_dims_i=rank - 2
    )
    utils._add_output_to_block(if_block, final_non_overrun_)

    # "else" block: diagonal overruns the matrix, emit zeros of gather_shape
    # (whose last entry is 0 in that case).
    else_block = utils._add_block(if_node)
    final_overrun_ = opset9.zeros(else_block, gather_shape, 6, None, None)
    utils._add_output_to_block(else_block, final_overrun_)
    return if_op
def index_put(g, self, indices_list_value, values, accumulate=False):
    """Export ``index_put`` as ``ScatterND`` (or a masked op for a Bool index).

    Several index tensors are broadcast against each other and concatenated
    along a new last axis to form the ``ScatterND`` indices. A single Bool
    index is routed to ``masked_fill`` (rank-0 ``values``) or
    ``masked_scatter``. With ``accumulate`` the values are scattered into a
    zero tensor shaped like ``self`` and the result is added to ``self``.

    The scalar type of ``self`` is only converted to a torch dtype on the
    accumulate path, where it is needed for the zero tensor; an unknown scalar
    type therefore no longer breaks the plain (non-accumulating) export.

    Args:
        g: Graph the nodes are added to.
        self: Value being written into.
        indices_list_value: Packed list of index values, or a single index.
        values: Value(s) to write.
        accumulate: Whether to add to the existing content instead of
            overwriting it.

    Returns:
        The resulting graph value; ``values`` itself when no indices are
        given; or the result of ``sym_help._unimplemented`` when
        ``accumulate`` is set and the scalar type of ``self`` is unknown.
    """
    if sym_help._is_packed_list(indices_list_value):
        indices_list = sym_help._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
        args = [self] + indices_list + [values, accumulate]
        return g.op("ATen", *args, operator_s='index_put')

    from torch.onnx.symbolic_opset9 import add, expand
    accumulate = sym_help._parse_arg(accumulate, 'b')

    if not indices_list:
        return values

    index = indices_list[0]

    if len(indices_list) > 1:
        # Adding all indices together yields a value with the broadcast shape.
        for ind in indices_list[1:]:
            index = add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            sym_help._unsqueeze_helper(
                g, expand(g, ind, broadcast_index_shape, None), [-1])
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace index_put node with masked_scatter or masked_fill
        # when inputs to the index_put node contains boolean inputs
        #
        # index_put -> masked_fill
        #   * input index contains single tensor of Bool type (e.g.: %24 <- %23).
        #   * input value contains single element (e.g.: %18).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        #
        # index_put -> masked_scatter
        #   * input index contains single tensor of Bool type (e.g.: %32 <- %31).
        #   * input value contains multiple elements (e.g.: %28).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        bool_inp = index
        if bool_inp.type() is not None and bool_inp.type().scalarType(
        ) == 'Bool':
            rank = sym_help._get_tensor_rank(values)
            if rank is not None and rank == 0:
                from torch.onnx.symbolic_opset9 import masked_fill
                return masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = sym_help._unsqueeze_helper(g, index, [-1])

    # Dimensions of self that are not addressed by an index tensor.
    sub_data_shape = sym_help._slice_helper(g,
                                            g.op("Shape", self),
                                            axes=[0],
                                            starts=[len(indices_list)],
                                            ends=[maxsize])
    values_shape = g.op("Concat",
                        broadcast_index_shape,
                        sub_data_shape,
                        axis_i=0)
    # Check if values is a singular value and expand accordingly
    rank = sym_help._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = expand(g, values, values_shape, None)
    values = g.op("Reshape", values, values_shape)

    # Make values match the scalar type of self when that type is known.
    self_dtype = self.type().scalarType()
    if self_dtype is not None and self_dtype != values.type().scalarType():
        values = g.op("Cast",
                      values,
                      to_i=sym_help.cast_pytorch_to_onnx[self_dtype])

    if not accumulate:
        return g.op("ScatterND", self, index, values)

    if self_dtype is None:
        # The zero tensor below needs a concrete dtype.
        return sym_help._unimplemented(
            "index_put", "accumulate=True requires a known scalar type for self")
    onnx_type_index = sym_help.scalar_type_to_onnx.index(
        sym_help.cast_pytorch_to_onnx[self_dtype])
    torch_dtype = sym_help.scalar_type_to_pytorch_type[onnx_type_index]
    zeros = g.op("ConstantOfShape",
                 g.op("Shape", self),
                 value_t=torch.tensor([0], dtype=torch_dtype))
    scattered = g.op("ScatterND", zeros, index, values)
    return add(g, self, scattered)
Exemple #18
0
def index_put(g, self, indices_list_value, values, accumulate=False):
    """Export ``index_put`` as ``ScatterND`` (or a masked op for a Bool index).

    With several indices, Bool entries are first converted with ``NonZero``;
    all indices are then broadcast against each other and concatenated along
    a new last axis to form the ``ScatterND`` indices. A single Bool index is
    routed to ``masked_fill`` (rank-0 ``values``) or ``masked_scatter``. With
    ``accumulate`` the values are scattered into a zero tensor shaped like
    ``self`` and the result is added to ``self``.

    The scalar type of ``self`` is only converted to a torch dtype on the
    accumulate path, where it is needed for the zero tensor; an unknown scalar
    type therefore no longer breaks the plain (non-accumulating) export.

    Args:
        g: Graph the nodes are added to.
        self: Value being written into.
        indices_list_value: Packed list of index values, or a single index.
        values: Value(s) to write.
        accumulate: Whether to add to the existing content instead of
            overwriting it.

    Returns:
        The resulting graph value; ``values`` itself when no indices are
        given; or the result of ``symbolic_helper._unimplemented`` when
        ``accumulate`` is set and the scalar type of ``self`` is unknown.
    """
    if symbolic_helper._is_packed_list(indices_list_value):
        indices_list = symbolic_helper._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if symbolic_helper.is_caffe2_aten_fallback():
        args = [self] + indices_list + [values, accumulate]
        return g.at("index_put", *args)

    accumulate = symbolic_helper._parse_arg(accumulate, "b")

    if not indices_list:
        return values

    if len(indices_list) > 1:
        for idx_, ind in enumerate(indices_list):
            if ind.type().scalarType() == "Bool":  # type: ignore[attr-defined]
                # TODO(justinchuby): Remove type ignore after #81112 is checked in.
                indices_list[idx_] = g.op("NonZero", ind)
        index = indices_list[0]

        # Adding all indices together yields a value with the broadcast shape.
        for ind in indices_list[1:]:
            index = opset9.add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            symbolic_helper._unsqueeze_helper(
                g, opset9.expand(g, ind, broadcast_index_shape, None), [-1]
            )
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace index_put node with masked_scatter or masked_fill
        # when inputs to the index_put node contains a single boolean input.
        #
        # index_put -> masked_fill
        #   * input index contains single tensor of Bool type (e.g.: %24 <- %23).
        #   * input value contains single element (e.g.: %18).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #                aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        #
        # index_put -> masked_scatter
        #   * input index contains single tensor of Bool type (e.g.: %32 <- %31).
        #   * input value contains multiple elements (e.g.: %28).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #                = prim::Constant[value= 1  1  1  1  1  1  1  1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #                = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        index = indices_list[0]
        bool_inp = index
        if bool_inp.type() is not None and bool_inp.type().scalarType() == "Bool":  # type: ignore[attr-defined]
            # TODO(justinchuby): Remove type ignore after #81112 is checked in.
            rank = symbolic_helper._get_tensor_rank(values)
            if rank is not None and rank == 0:
                return opset9.masked_fill(g, self, bool_inp, values)
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = symbolic_helper._unsqueeze_helper(g, index, [-1])

    # Dimensions of self that are not addressed by an index tensor.
    sub_data_shape = symbolic_helper._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize]
    )
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    # Check if values is a singular value and expand accordingly
    rank = symbolic_helper._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = opset9.expand(g, values, values_shape, None)
    values = symbolic_helper._reshape_helper(g, values, values_shape)

    # Make values match the scalar type of self when that type is known.
    self_dtype = self.type().scalarType()
    if self_dtype is not None and self_dtype != values.type().scalarType():
        values = g.op(
            "Cast", values, to_i=symbolic_helper.cast_pytorch_to_onnx[self_dtype]
        )

    if not accumulate:
        return g.op("ScatterND", self, index, values)

    if self_dtype is None:
        # The zero tensor below needs a concrete dtype.
        return symbolic_helper._unimplemented(
            "index_put", "accumulate=True requires a known scalar type for self"
        )
    onnx_type_index = symbolic_helper.scalar_type_to_onnx.index(
        symbolic_helper.cast_pytorch_to_onnx[self_dtype]
    )
    torch_dtype = symbolic_helper.scalar_type_to_pytorch_type[onnx_type_index]
    zeros = g.op(
        "ConstantOfShape",
        g.op("Shape", self),
        value_t=torch.tensor([0], dtype=torch_dtype),
    )
    scattered = g.op("ScatterND", zeros, index, values)
    return add(g, self, scattered)