def verify_inferred_shape(graph):
    """Check that every node output in ``graph`` has type information.

    Each output has to be a Tensor, a list of Tensors, or None. Outputs that
    are Tensors must additionally report a non-None ``scalarType()`` and a
    non-None ``dim()`` on their type.

    Raises:
        RuntimeError: for the first output failing one of the checks. The
            exception carries two args: the message and the offending output.
    """
    for node in graph.nodes():
        for value in node.outputs():
            recognised = (_is_tensor_list(value) or _is_tensor(value)
                          or _is_none(value))
            if not recognised:
                raise RuntimeError(
                    "Output of node is neither type Tensor nor type list of Tensor: ",
                    value)
            if not _is_tensor(value):
                # Only Tensor outputs are inspected for dtype / dim below.
                continue
            if value.type().scalarType() is None:
                raise RuntimeError(
                    "Output of node does not have type assigned", value)
            if value.type().dim() is None:
                raise RuntimeError(
                    "Output of node does not have shape assigned", value)
def repeat_interleave(g, self, repeats, dim=None):
    """Build the ONNX subgraph for ``repeat_interleave`` with a Loop node.

    The Loop-based path is taken only when ``repeats`` is a scalar or a
    one-element 1-D tensor AND the size of ``self`` along ``dim`` is unknown
    at export time. Every other case is delegated to the opset 9 symbolic.

    Args:
        g: graph being populated; nodes are created through ``g.op``.
        self: value to repeat.
        repeats: repeat count(s); a non-tensor is wrapped in a Constant.
        dim: axis to repeat along. When it is none, ``self`` is flattened
            first and axis 0 is used.

    Returns:
        The graph value holding the result.

    Raises:
        RuntimeError: if the rank or sizes of ``repeats``, or the sizes of
            the (possibly flattened) input, are unknown.
    """
    from torch.onnx.symbolic_opset9 import reshape

    requested_dim = dim
    # Without an explicit dim the operation runs on the flattened input and
    # yields a flat output.
    if sym_help._is_none(dim):
        data = reshape(g, self, g.op("Constant", value_t=torch.tensor([-1])))
        axis = 0
    else:
        data = self
        axis = sym_help._maybe_get_scalar(dim)

    repeats_rank = sym_help._get_tensor_rank(repeats)
    repeats_shape = sym_help._get_tensor_sizes(repeats)
    data_shape = sym_help._get_tensor_sizes(data)
    if repeats_rank is None:
        raise RuntimeError(
            "Unsupported: ONNX export of repeat_interleave for unknown "
            "repeats rank.")
    if repeats_shape is None:
        raise RuntimeError(
            "Unsupported: ONNX export of repeat_interleave for unknown "
            "repeats size.")
    if data_shape is None:
        raise RuntimeError(
            "Unsupported: ONNX export of repeat_interleave for unknown "
            "input size.")

    # Normalise a negative axis.
    if axis < 0:
        axis += len(data_shape)

    # Unknown extents are written as 0 in the output shape and as -1 in the
    # input shape (the input shape list is updated in place).
    result_shape = data_shape.copy()
    for position, extent in enumerate(data_shape):
        if extent is None:
            result_shape[position], data_shape[position] = 0, -1

    # Identity permutation over rank + 1 axes with axis 0 and `axis` swapped;
    # used after the loop to move the stacked scan axis into place.
    transpose_perm = list(range(len(data_shape) + 1))
    transpose_perm[0], transpose_perm[axis] = (transpose_perm[axis],
                                               transpose_perm[0])

    single_repeat = repeats_rank == 0 or (repeats_rank == 1
                                          and repeats_shape[0] == 1)
    if not (single_repeat and result_shape[axis] == 0):
        # Repeats may be a 1-d tensor with several entries while `dim` points
        # at a dynamic axis, e.g. input.shape -> [1, 1, *] with dim = 2.
        # PyTorch can interleave when * equals the number of repeats (* -> 2
        # needs 2 repeats). Those cases go to the opset 9 implementation.
        return torch.onnx.symbolic_opset9.repeat_interleave(
            g, self, repeats, requested_dim)

    # Single repeat value with an unknown size along `axis`: broadcast the
    # repeat count to one entry per element along that axis.
    if not sym_help._is_tensor(repeats):
        repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
    trip_count = sym_help._size_helper(g, data, axis)
    trip_count = unsqueeze(g, trip_count, 0)
    repeats = g.op("Expand", repeats, trip_count)

    unit_sizes = g.op("ConstantOfShape",
                      g.op("Shape", repeats),
                      value_t=torch.tensor([1], dtype=torch.long))
    repeat_pieces = split(g, repeats, unit_sizes, 0)
    data_pieces = split(g, data, unit_sizes, axis)

    result_shape[axis], data_shape[axis] = -1, 1

    # The Loop walks over every element along `axis` and interleaves it with
    # its own repeat count. Loop follows this pattern:
    # input (trip_count, cond)
    #   int trip_count = ...;
    #   bool cond = ...;
    #   for (int i=0; i < trip_count && cond; ++i) {
    #     cond = ...;
    #   }

    # Loop condition: constant 1 cast with to_i=9 (ONNX TensorProto BOOL).
    keep_going = g.op("Constant", value_t=torch.tensor(1))
    keep_going = g.op("Cast", keep_going, to_i=9)
    loop = g.op("Loop", trip_count, keep_going)

    # Loop body inputs: iteration counter, then the condition. The condition
    # input is registered on the block but never read by the body.
    body = _add_block(loop.node())
    iteration = _add_input_to_block(body)
    _add_input_to_block(body)

    repeat_i = body.op("SequenceAt", repeat_pieces, iteration)
    piece_i = body.op("SequenceAt", data_pieces, iteration)

    piece_i = unsqueeze(body, piece_i, axis + 1)
    leading = body.op("Constant",
                      value_t=torch.LongTensor(data_shape[:axis + 1]))
    trailing = body.op("Constant",
                       value_t=torch.LongTensor(data_shape[axis + 1:]))
    expand_shape = body.op("Concat", leading, repeat_i, trailing, axis_i=0)
    piece_i = expand(body, piece_i, expand_shape, None)
    piece_i = reshape(body, piece_i,
                      g.op("Constant", value_t=torch.LongTensor(result_shape)))

    # Loop body outputs: condition first, then the per-iteration result.
    body_cond = body.op("Cast", keep_going, to_i=9)
    _add_output_to_block(body, body_cond)
    _add_output_to_block(body, piece_i)
    stacked = loop.node().output()

    # Scan outputs are concatenated along axis 0 by default; transpose and
    # reshape so the concatenation ends up along the requested axis instead.
    stacked = g.op("Transpose", stacked, perm_i=transpose_perm)
    return reshape(g, stacked,
                   g.op("Constant", value_t=torch.LongTensor(result_shape)))
Exemplo n.º 3
0
def repeat_interleave(g, self, repeats, dim=None, output_size=None):
    """Build the ONNX subgraph for ``repeat_interleave`` using sequences.

    When the input size along ``dim`` is unknown at export time, or
    ``repeats`` is a 1-D tensor of unknown length, a Loop node expands each
    slice along ``dim`` by its repeat count, collects the pieces in a
    sequence and joins them with ``ConcatFromSequence``. All other cases are
    delegated to the opset 9 symbolic.

    Args:
        g: graph being populated; nodes are created through ``g.op``.
        self: value to repeat.
        repeats: repeat count(s); a non-tensor is wrapped in a Constant.
        dim: axis to repeat along. When it is none, ``self`` is flattened
            first and axis 0 is used.
        output_size: accepted as part of the signature; not used by this
            function.

    Returns:
        The graph value holding the result.

    Raises:
        RuntimeError: if the rank or sizes of ``repeats``, or the sizes of
            the (possibly flattened) input, are unknown.
    """
    input = self
    final_dim = dim
    # if dim is None flatten
    # By default, use the flattened input array, and return a flat output array
    if sym_help._is_none(dim):
        input = sym_help._reshape_helper(
            g, self, g.op("Constant", value_t=torch.tensor([-1])))
        dim = 0
    else:
        dim = sym_help._maybe_get_scalar(dim)

    repeats_dim = sym_help._get_tensor_rank(repeats)
    repeats_sizes = sym_help._get_tensor_sizes(repeats)
    input_sizes = sym_help._get_tensor_sizes(input)
    if repeats_dim is None:
        raise RuntimeError(
            "Unsupported: ONNX export of repeat_interleave for unknown "
            "repeats rank.")
    if repeats_sizes is None:
        raise RuntimeError(
            "Unsupported: ONNX export of repeat_interleave for unknown "
            "repeats size.")
    if input_sizes is None:
        raise RuntimeError(
            "Unsupported: ONNX export of repeat_interleave for unknown "
            "input size.")
    # Handle cases where dim is negative
    if dim < 0:
        dim += len(input_sizes)

    # Unknown extents are written as 0 in the output shape and as -1 in the
    # input shape (the input shape list is updated in place).
    output_sizes = input_sizes.copy()
    for idx, input_size in enumerate(input_sizes):
        if input_size is None:
            output_sizes[idx], input_sizes[idx] = 0, -1

    cond_dynamic_repeats = (repeats_dim == 1 and repeats_sizes[0] is None)
    # If input size is dynamic or repeats vector is dynamic
    if output_sizes[dim] == 0 or cond_dynamic_repeats:
        reps = sym_help._size_helper(g, input, dim)
        reps = unsqueeze(g, reps, 0)
        # Check if repeats vector is a single integer value
        # or a single dimension tensor with non-dynamic values
        if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
            if not sym_help._is_tensor(repeats):
                repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
            repeats = g.op("Expand", repeats, reps)
        # Check if repeats is dynamic
        # As repeats is dynamic, we use a where node as a substitute for the if statement
        # If repeats holds exactly one element at runtime, expand it;
        # otherwise use the original tensor
        elif cond_dynamic_repeats:
            repeat_dim = sym_help._size_helper(
                g, repeats, g.op("Constant", value_t=torch.LongTensor([0])))
            repeat_cond = g.op("Equal", repeat_dim,
                               g.op("Constant", value_t=torch.LongTensor([1])))
            repeats = where(g, repeat_cond, g.op("Expand", repeats, reps),
                            repeats)
    # There are cases when the repeats are 1-d tensor with multiple repeats, but dim
    # provided along one of the dynamic axes provided. A simple example would be
    # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2
    # Now, repeat interleaving can be performed in pytorch when the value of * matches
    # with the number of elements in repeat, for example if * -> 2, number of repeats
    # should be 2 as well.
    else:
        return torch.onnx.symbolic_opset9.repeat_interleave(
            g, self, repeats, final_dim)

    reps_like = g.op("ConstantOfShape",
                     g.op("Shape", repeats),
                     value_t=torch.tensor([1], dtype=torch.long))
    r_splits = split(g, repeats, reps_like, 0)
    i_splits = split(g, input, reps_like, dim)

    output_sizes[dim], input_sizes[dim] = -1, 1

    # Create a loop to iterate over each value along the dimension
    # and perform individual interleaving using the repeats tensor
    # Loop is of the following pattern
    # input (trip_count, cond)
    #   int trip_count = ...;
    #   bool cond = ...;
    #   for (int i=0; i < trip_count && cond; ++i) {
    #     cond = ...;
    #   }

    # Loop conditions: constant 1 cast with to_i=9 (ONNX TensorProto BOOL)
    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    loop_condition = g.op("Cast", loop_condition, to_i=9)
    loop_len = reps

    # Create an empty sequence to store final expansions
    final_splits = g.op("SequenceEmpty")
    loop = g.op("Loop", loop_len, loop_condition, final_splits)

    # Loop inputs: iteration counter, condition, carried sequence. The
    # condition input is registered on the block but never read by the body.
    loop_block = _add_block(loop.node())
    block_input_iter = _add_input_to_block(loop_block)
    _add_input_to_block(loop_block)
    final_splits = _add_input_to_block(loop_block)

    r_split = loop_block.op("SequenceAt", r_splits, block_input_iter)
    i_split = loop_block.op("SequenceAt", i_splits, block_input_iter)

    i_split = unsqueeze(loop_block, i_split, dim + 1)
    r_concat = [
        loop_block.op("Constant",
                      value_t=torch.LongTensor(input_sizes[:dim + 1])),
        r_split,
        loop_block.op("Constant",
                      value_t=torch.LongTensor(input_sizes[dim + 1:]))
    ]
    r_concat = loop_block.op("Concat", *r_concat, axis_i=0)
    i_split = expand(loop_block, i_split, r_concat, None)
    i_split = sym_help._reshape_helper(
        loop_block, i_split,
        g.op("Constant", value_t=torch.LongTensor(output_sizes)))
    final_splits = loop_block.op("SequenceInsert", final_splits, i_split)

    # Loop outputs: condition first, then the carried sequence
    cond_out = loop_block.op("Cast", loop_condition, to_i=9)
    _add_output_to_block(loop_block, cond_out)
    _add_output_to_block(loop_block, final_splits)

    # Join the collected pieces along the requested dimension
    loop_out = loop.node().output()
    loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim)
    return loop_out
Exemplo n.º 4
0
def tensor_split(g, self, indices_or_sections, dim, _outputs=None):
    """Build the ONNX subgraph for ``tensor_split``.

    Four cases are handled, depending on ``indices_or_sections``:

    * static with a non-scalar value: one ``Slice`` per output, returned as
      a Python list;
    * static scalar: chunk sizes are computed in Python and a single
      ``Split`` node is emitted;
    * non-static rank-1 tensor: a ``Loop`` slices between consecutive
      indices and collects the pieces in a sequence;
    * anything else (treated as a scalar tensor): chunk sizes are computed
      in-graph with ``Div``/``Mod``/``Tile``.

    Args:
        g: graph being populated; nodes are created through ``g.op``.
        self: value to split.
        indices_or_sections: split indices or number of sections.
        dim: axis to split along; used both as a Python value and as a
            Constant in the graph.
        _outputs: number of outputs when known, otherwise None.

    Returns:
        A list of slice values, a ``Split`` result, or a sequence value,
        depending on the case above.

    Raises:
        errors.SymbolicValueError: in the static scalar case when the size
            of ``self`` along ``dim`` is unknown and ``_outputs`` is None.
    """
    axis = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
    axis = opset11.unsqueeze(g, axis, 0)
    one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long))

    if symbolic_helper._is_split_static(indices_or_sections, _outputs):
        static_value = symbolic_helper._node_get(
            indices_or_sections.node(), "value"
        )

        if static_value.dim() > 0:
            # Explicit indices: slice [previous index, current index) for
            # each index, then one final slice up to the end of the axis.
            lower = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
            pieces = []
            assert _outputs is not None
            for position in range(_outputs - 1):
                index_const = g.op(
                    "Constant", value_t=torch.tensor([position], dtype=torch.long)
                )
                upper = g.op("Gather", indices_or_sections, index_const, axis_i=0)
                pieces.append(g.op("Slice", self, lower, upper, axis))
                lower = upper

            upper = symbolic_helper._size_helper(g, self, axis)
            pieces.append(g.op("Slice", self, lower, upper, axis))
            return pieces

        # Scalar: the constant is the number of sections to produce.
        sections = symbolic_helper._get_const(
            indices_or_sections, "i", "indices_or_sections"
        )

        dim_size = symbolic_helper._get_tensor_dim_size(self, dim)
        if dim_size is None:
            if _outputs is None:
                raise errors.SymbolicValueError(
                    "Unknown dimension size not supported", self
                )
            dim_size = sections * _outputs

        base = dim_size // sections
        extra = dim_size % sections

        # The first `extra` chunks get one additional element.
        chunk_sizes = extra * [base + 1] + (sections - extra) * [base]

        sizes_const = g.op(
            "Constant", value_t=torch.tensor(chunk_sizes, dtype=torch.long)
        )
        return g.op("Split", self, sizes_const, axis_i=dim, outputs=_outputs)

    is_index_vector = (
        symbolic_helper._is_tensor(indices_or_sections)
        and symbolic_helper._get_tensor_rank(indices_or_sections) == 1
    )
    if is_index_vector:
        trip_count = symbolic_helper._size_helper(
            g, indices_or_sections, g.op("Constant", value_t=torch.tensor(0))
        )
        trip_count = opset11.unsqueeze(g, trip_count, 0)
        keep_going = g.op("Cast", one, to_i=_C_onnx.TensorProtoDataType.BOOL)

        # Prepend a zero so the first slice in the loop has an initial
        # start position.
        zero_pad = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
        boundaries = g.op("Concat", zero_pad, indices_or_sections, axis_i=0)

        collected = g.op("SequenceEmpty")
        loop = g.op("Loop", trip_count, keep_going, collected)

        # Loop inputs: iteration counter, condition, carried sequence. The
        # condition input is registered on the block but never read.
        body = utils._add_block(loop.node())
        iteration = utils._add_input_to_block(body)
        utils._add_input_to_block(body)
        carried = utils._add_input_to_block(body)

        lower = body.op("Gather", boundaries, iteration, axis_i=0)
        next_iteration = body.op("Add", iteration, one)
        upper = body.op("Gather", boundaries, next_iteration, axis_i=0)

        piece = body.op("Slice", self, lower, upper, axis)
        carried = body.op("SequenceInsert", carried, piece)

        # Loop outputs: condition first, then the carried sequence.
        body_cond = body.op("Identity", keep_going)
        utils._add_output_to_block(body, body_cond)
        utils._add_output_to_block(body, carried)

        sequence_out = loop.node().output()

        # Final slice: from the last boundary to the end of the axis.
        minus_one = g.op("Constant", value_t=torch.tensor(-1, dtype=torch.long))
        tail_start = g.op("Gather", boundaries, minus_one, axis_i=0)
        tail_start = opset11.unsqueeze(g, tail_start, 0)
        tail_end = symbolic_helper._size_helper(g, self, axis)

        tail = g.op("Slice", self, tail_start, tail_end, axis)

        return g.op("SequenceInsert", sequence_out, tail)

    # scalar tensor: compute the chunk sizes inside the graph
    dim_size = symbolic_helper._size_helper(g, self, axis)
    base = g.op("Div", dim_size, indices_or_sections)
    base_plus_one = g.op("Add", base, one)
    extra = g.op("Mod", dim_size, indices_or_sections)
    larger_chunks = g.op("Tile", base_plus_one, extra)
    sections_1d = opset11.unsqueeze(g, indices_or_sections, 0)
    remaining = g.op("Sub", sections_1d, extra)
    smaller_chunks = g.op("Tile", base, remaining)

    chunk_sizes = g.op("Concat", larger_chunks, smaller_chunks, axis_i=0)
    if _outputs is None:
        return g.op("SplitToSequence", self, chunk_sizes, axis_i=dim)
    return g.op("Split", self, chunk_sizes, axis_i=dim, outputs=_outputs)