Example #1
def _get_im2col_output_shape(g, input, kernel_h, kernel_w):
    batch_dim = size(g, input, g.op("Constant", value_t=torch.tensor(0)))
    channel_dim = size(g, input, g.op("Constant", value_t=torch.tensor(1)))
    channel_unfolded = g.op("Mul", channel_dim,
                            g.op("Constant", value_t=torch.tensor(kernel_h * kernel_w)))

    return g.op("Concat",
                g.op("Unsqueeze", batch_dim, axes_i=[0]),
                g.op("Unsqueeze", channel_unfolded, axes_i=[0]),
                g.op("Constant", value_t=torch.tensor([-1])), axis_i=0)

def im2col(g, input, kernel_size, dilation, padding, stride):
    # Input is always 4-D tensor (N, C, H, W)
    # All other args are int[2]

    input_h = size(g, input, g.op("Constant", value_t=torch.tensor(2)))
    input_w = size(g, input, g.op("Constant", value_t=torch.tensor(3)))

    stride_h, stride_w = stride[0], stride[1]
    padding_h, padding_w = padding[0], padding[1]
    dilation_h, dilation_w = dilation[0], dilation[1]
    kernel_h, kernel_w = kernel_size[0], kernel_size[1]

    blocks_row_indices = _get_im2col_indices_along_dim(g, input_h, kernel_h,
                                                       dilation_h, padding_h,
                                                       stride_h)
    blocks_col_indices = _get_im2col_indices_along_dim(g, input_w, kernel_w,
                                                       dilation_w, padding_w,
                                                       stride_w)

    output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w)
    padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w)

    # For a 4D matrix of size (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1
    # [[[[1., 2., 3.,],
    #    [4., 5., 6.,],
    #    [7., 8., 9.,]]]]
    # First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get:
    # [[[[[1., 2., 3.],
    #     [4., 5., 6.]],
    #    [[4., 5., 6.],
    #     [7., 8., 9.]]]]]
    # And then gather along cols (dim=4) with blocks_col_indices = [[0,1], [1,2]] to get:
    # [[[[[[1., 2.],
    #      [2., 3.]],
    #     [[4., 5.],
    #      [5., 6.]]],
    #    [[[4., 5.],
    #      [5., 6.]],
    #     [[7., 8.],
    #      [8., 9.]]]]]]
    # Transpose dims 3 and 4, and then reshape to the output shape (1, 4, 4) to get:
    #  [[[1., 2., 4., 5.],
    #    [2., 3., 5., 6.],
    #    [4., 5., 7., 8.],
    #    [5., 6., 8., 9.]]]
    output = g.op("Gather", padded_input, blocks_row_indices, axis_i=2)
    output = g.op("Gather", output, blocks_col_indices, axis_i=4)
    output = g.op("Transpose", output, perm_i=[0, 1, 2, 4, 3, 5])
    return g.op("Reshape", output, output_shape)
Example #3
    # Nested helper; `g` and `dim` are captured from the enclosing symbolic function
    def reverse(x):
        from torch.onnx.symbolic_opset9 import reshape, transpose, size

        y = transpose(g, x, 0, dim)
        shape = g.op("Shape", y)
        y = reshape(g, y, [0, 1, -1])
        n = size(g, y, g.op("Constant", value_t=torch.LongTensor([0])))
        y = g.op("ReverseSequence", y, n, batch_axis_i=1, time_axis_i=0)
        y = reshape(g, y, shape)
        y = transpose(g, y, 0, dim)
        return y
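
In eager terms the trick is: move `dim` to the front, flatten everything after the first two axes (the Reshape to [0, 1, -1] keeps dim 0 as "time" and dim 1 as "batch"), reverse along the time axis, then undo the reshape and transpose. A hedged sketch (`reverse_eager` is my own name; torch.flip stands in for ReverseSequence with full-length sequences):

import torch

def reverse_eager(x, dim):
    y = x.transpose(0, dim)
    shape = y.shape
    y = y.reshape(shape[0], 1, -1)   # Reshape to [0, 1, -1]: (time, batch=1, rest)
    y = y.flip(0)                    # ReverseSequence over the full time axis
    return y.reshape(shape).transpose(0, dim)

x = torch.arange(24).reshape(2, 3, 4)
for d in range(3):
    assert torch.equal(reverse_eager(x, d), torch.flip(x, dims=[d]))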
Example #4
def masked_scatter(g, self, mask, source):
    from torch.onnx.symbolic_opset9 import nonzero, expand_as, view, size
    index = nonzero(g, expand_as(g, mask, self))
    # NOTE: source can have more elements than needed.
    # It could also have arbitrary shape.
    # This is not supported by ONNX::ScatterND, so we need to flatten and slice the source tensor.
    source = view(g, source, torch.LongTensor([-1]))
    source = sym_help._slice_helper(g, source,
                                    axes=torch.LongTensor([0]),
                                    starts=torch.LongTensor([0]),
                                    ends=size(g, index, torch.LongTensor([0])),
                                    dynamic_slice=True)
    return g.op('ScatterND', self, index, source)
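
The flatten-and-slice step mirrors aten::masked_scatter semantics: `source` is consumed in row-major order and may hold more elements than the mask selects. A small eager check (illustrative):

import torch

self_ = torch.zeros(2, 3)
mask = torch.tensor([[True, False, True],
                     [False, True, False]])
source = torch.arange(10., 20.).reshape(2, 5)   # more elements than needed, 2-D

out = self_.masked_scatter(mask, source)

expected = self_.clone()
expected[mask] = source.reshape(-1)[: mask.sum()]   # flatten, then slice
assert torch.equal(out, expected)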
Example #5
def _prepare_onnx_paddings(g, input, pad):
    if (
        not symbolic_helper._is_packed_list(pad)
        and symbolic_helper._is_list(pad)
        and symbolic_helper._is_scalar_list(pad)
    ):
        pad = g.op("ConcatFromSequence", pad, axis_i=0, new_axis_i=1)
    # The desired order of paddings is
    # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end.
    # n is the dimension of input.
    # Leading dimensions get zero padding: since "pad" is ordered last-dim-first,
    # extend it with zeros at the end until it has length 2 * rank
    pad_len = opset9.size(g, pad, g.op("Constant", value_t=torch.tensor([0])))
    # Set extension = [0] * (rank * 2 - len(pad))
    rank = symbolic_helper._get_tensor_rank(input)
    if rank is None:
        rank = g.op("Size", g.op("Shape", input))
    else:
        rank = g.op("Constant", value_t=torch.tensor(rank, dtype=torch.int64))
    extension = g.op(
        "Sub",
        g.op("Mul", rank, g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))),
        pad_len,
    )
    # Concat pad with extension: paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ... ]
    # Currently ONNX only supports int64 type for Pad
    pad = g.op("Cast", pad, to_i=symbolic_helper.cast_pytorch_to_onnx["Long"])
    paddings = g.op(
        "Concat",
        pad,
        g.op(
            "ConstantOfShape", extension, value_t=torch.tensor([0], dtype=torch.int64)
        ),
        axis_i=0,
    )
    # Reshape into (begin, end) pairs, reverse the dimension order, then collate
    # all beginnings followed by all ends:
    # paddings = [[..., 0, dim_n-1_begin, dim_n_begin],
    #             [..., 0, dim_n-1_end, dim_n_end]]
    # Reshape back to 1-D:
    # paddings = [..., 0, dim_n-1_begin, dim_n_begin, ..., 0, dim_n-1_end, dim_n_end]
    paddings = symbolic_helper._reshape_helper(
        g, paddings, g.op("Constant", value_t=torch.tensor([-1, 2]))
    )
    paddings = g.op("Transpose", opset10.flip(g, paddings, [0]), perm_i=[1, 0])
    paddings = symbolic_helper._reshape_helper(
        g, paddings, g.op("Constant", value_t=torch.tensor([-1]))
    )
    padding_c = g.op(
        "Cast", paddings, to_i=symbolic_helper.cast_pytorch_to_onnx["Long"]
    )
    return padding_c
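
The reshape/flip/transpose dance converts PyTorch's pad order (last dim first, interleaved begin/end pairs) into ONNX Pad order (all begins for dims 0..n-1, then all ends). A plain-Python sketch of the same index shuffle (`prepare_onnx_paddings_eager` is a hypothetical name):

import torch

def prepare_onnx_paddings_eager(rank, pad):
    pad = list(pad) + [0] * (rank * 2 - len(pad))            # extend with zeros
    p = torch.tensor(pad, dtype=torch.int64).reshape(-1, 2)  # rows: [begin, end]
    p = p.flip(0).t()    # dim order 0..n-1; row 0 = begins, row 1 = ends
    return p.reshape(-1).tolist()

# pad=(1, 2, 3, 4) pads W by (1, 2) and H by (3, 4) on a 4-D input:
assert prepare_onnx_paddings_eager(4, (1, 2, 3, 4)) == [0, 0, 3, 1, 0, 0, 4, 2]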
Example #6
def masked_scatter(g, self, mask, source):
    index = opset9.nonzero(g, opset9.expand_as(g, mask, self))
    # NOTE: source can have more elements than needed.
    # It could also have arbitrary shape.
    # This is not supported by ONNX::ScatterND, so we need to flatten and slice the source tensor.
    source = symbolic_helper._reshape_helper(g, source, torch.LongTensor([-1]))
    source = symbolic_helper._slice_helper(
        g,
        source,
        axes=torch.LongTensor([0]),
        starts=torch.LongTensor([0]),
        ends=opset9.size(g, index, torch.LongTensor([0])),
        dynamic_slice=True,
    )
    return g.op("ScatterND", self, index, source)
Example #7
def _len(g, self):
    if _is_tensor_list(self) or self.node().kind() == "onnx::SplitToSequence":
        return g.op("SequenceLength", self)
    sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0])))
    return sym_help._squeeze_helper(g, sz_0, [0])
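
For tensor inputs, aten::len is just the size of dim 0; the Squeeze collapses the 1-element size tensor to the scalar that len() returns in eager mode. A trivial eager check (illustrative):

import torch

t = torch.zeros(5, 3)
assert len(t) == t.size(0) == 5
assert len([t, t, t]) == 3   # tensor lists take the SequenceLength branch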
Example #8
def diagonal(g, self, offset, dim1, dim2):
    dim1_size = opset9.size(
        g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim1]))
    )
    dim2_size = opset9.size(
        g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim2]))
    )

    # Create appropriate mask
    mask_shape = g.op("Concat", dim1_size, dim2_size, axis_i=0)
    mask = opset9.zeros(g, mask_shape, None, None, None)
    mask = g.op("EyeLike", mask, k_i=offset)

    # dim1 and dim2 appended as a dimension at the end of the shape
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is not None:
        axes = list(range(rank))
        axes.remove(dim1)
        axes.remove(dim2)
        self = g.op("Transpose", self, perm_i=axes + [dim1, dim2])
    else:
        return symbolic_helper._unimplemented("diagonal", "unknown input rank")

    # Multiply input and mask to extract the values along the diagonal
    # The mask holds ones at the positions from which diagonal values are taken
    # For example:
    # [[1.1, 1.2, 1.3],   *    [[1, 0, 0]   =   [[1.1, 0, 0],
    #  [2.1, 2.2, 2.3],         [0, 1, 0]        [0, 2.2, 0],
    #  [3.1, 3.2, 3.3]]         [0, 0, 1]]       [0, 0, 3.3]]
    result = g.op("Mul", self, mask)
    result = symbolic_helper._reducesum_helper(g, result, axes_i=[-1], keepdims_i=0)

    # Calculate gather indices based on offset and dims
    # If offset is non-negative, reset it to zero once diag_size is known,
    # which simplifies the calculation of the selection window
    offset_op = g.op("Constant", value_t=torch.LongTensor([offset]))
    if offset >= 0:
        diag_size = g.op(
            "Max",
            g.op("Min", dim1_size, g.op("Sub", dim2_size, offset_op)),
            g.op("Constant", value_t=torch.LongTensor([0])),
        )
        offset = 0
    else:
        diag_size = g.op(
            "Max",
            g.op("Min", g.op("Add", dim1_size, offset_op), dim2_size),
            g.op("Constant", value_t=torch.LongTensor([0])),
        )
    diag_size = g.op("Concat", diag_size, axis_i=0)

    # Calculate which diagonal values to select
    # For example, in cases with offsets:
    # [[0, 1.1, 0]
    #  [0, 0, 2.2]]
    # we need to select the last two columns, so we create a tensor
    # with all columns that are to be selected
    # So in this example, it is [1, 2]
    select_window_ones_fill = opset9.ones(g, diag_size, 4, None, None)
    select_window = g.op(
        "CumSum",
        select_window_ones_fill,
        g.op("Constant", value_t=torch.LongTensor([0])),
    )
    select_window = g.op(
        "Add",
        select_window,
        g.op("Constant", value_t=torch.LongTensor([abs(offset) - 1])),
    )

    gather_shape = [
        opset9.size(g, result, dim=g.op("Constant", value_t=torch.LongTensor([axis])))
        for axis in list(range(rank))[:-2]
    ]
    gather_shape.append(diag_size)
    gather_shape = g.op("Concat", *gather_shape, axis_i=0)
    gather_indices = opset9.zeros(g, gather_shape, 4, None, None)

    # The offset may exceed the number of rows/columns, causing the diagonal
    # to overrun the matrix; in that case diag_size is zero.
    # For example, if
    #       offset = 9, dim1_size = 2 (columns), dim2_size = 4 (rows)
    #       diag_size = max(min(2, (4 - 9)), 0) = 0, based on the calculation above
    # Cases with diagonal overrun always yield diag_size = max(0, negative value) = 0
    # In cases without diagonal overrun, we select the appropriate rows/columns along which we
    # are calculating diagonal values. In cases with diagonal overrun, we return a tensor which has
    # the dimension of the row/column where overrun occurred as 0-dim, as we are essentially
    # returning an empty tensor
    overrun_cond = g.op(
        "Not",
        g.op(
            "Equal",
            diag_size,
            g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)),
        ),
    )
    if_op = g.op("If", overrun_cond)
    if_node = if_op.node()

    if_block = utils._add_block(if_node)
    gather_indices_if_block = if_block.op("Add", gather_indices, select_window)
    gather_indices_if_block = symbolic_helper._unsqueeze_helper(
        if_block, gather_indices_if_block, [rank - 1]
    )
    final_non_overrun_ = if_block.op(
        "GatherND", result, gather_indices_if_block, batch_dims_i=rank - 2
    )
    utils._add_output_to_block(if_block, final_non_overrun_)

    else_block = utils._add_block(if_node)
    final_overrun_ = opset9.zeros(else_block, gather_shape, 6, None, None)
    utils._add_output_to_block(else_block, final_overrun_)
    return if_op
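
The Min/Max arithmetic above computes the diagonal length. A quick eager check of that formula against torch.diagonal (`diag_size` here is my own helper, not a value from the graph):

import torch

def diag_size(dim1_size, dim2_size, offset):
    if offset >= 0:
        return max(min(dim1_size, dim2_size - offset), 0)
    return max(min(dim1_size + offset, dim2_size), 0)

x = torch.randn(5, 3, 4)
for off in range(-6, 7):
    expected = torch.diagonal(x, offset=off, dim1=1, dim2=2).shape[-1]
    assert diag_size(3, 4, off) == expected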
Example #9
def embedding_bag(g, embedding_matrix, indices, offsets, scale_grad_by_freq,
                  mode, sparse, per_sample_weights, include_last_offset):
    if scale_grad_by_freq and sym_help._training_mode:
        return sym_help._onnx_unsupported(
            'embedding_bag with scale_grad_by_freq for training mode')

    from sys import maxsize  # sentinel "end of indices" offset used below
    from torch.onnx.symbolic_opset9 import size, div, select

    # Check whether the initial indices input was 2-D. In functional.py:
    # offsets is set to torch.arange(0, indices.numel(), indices.size(1))
    # Then indices is reshaped to 1D: indices.reshape(-1)
    if len(list(indices.node().inputs())) > 0 and indices.node().inputs().__next__().type().sizes() is not None \
            and len(indices.node().inputs().__next__().type().sizes()) == 2:
        # Assert include_last_offset is False
        assert not include_last_offset
        embeddings = g.op("Gather", embedding_matrix, indices)
        dim_0 = size(g, offsets, g.op("Constant",
                                      value_t=torch.LongTensor([0])))
        dim_1 = div(
            g, size(g, indices, g.op("Constant",
                                     value_t=torch.LongTensor([0]))), dim_0)
        dim_2 = g.op("Constant", value_t=torch.LongTensor([-1]))

        shape = [dim_0, dim_1, dim_2]
        shape = g.op("Concat", *shape, axis_i=0)

        if not sym_help._is_none(per_sample_weights):
            per_sample_weights = g.op("Unsqueeze",
                                      per_sample_weights,
                                      axes_i=[1])
            embeddings = g.op("Mul", embeddings, per_sample_weights)

        embeddings = g.op("Reshape", embeddings, shape)
        if mode == 0:
            embeddings = g.op("ReduceSum",
                              embeddings,
                              axes_i=[1],
                              keepdims_i=0)
        elif mode == 1:
            embeddings = g.op("ReduceMean",
                              embeddings,
                              axes_i=[1],
                              keepdims_i=0)
        else:
            embeddings = g.op("ReduceMax",
                              embeddings,
                              axes_i=[1],
                              keepdims_i=0)
        # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
        # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
        return embeddings, None, None, None
    elif offsets.type().sizes() is not None:
        if include_last_offset:
            offset_len = offsets.type().sizes()[0] - 1
            offsets_extended = offsets
        else:
            offset_len = offsets.type().sizes()[0]
            offsets_extended = [
                offsets,
                g.op("Constant", value_t=torch.tensor([maxsize]))
            ]
            offsets_extended = g.op("Concat", *offsets_extended, axis_i=0)
        list_ = []
        for i in range(offset_len):
            start_ = g.op("Unsqueeze",
                          select(g, offsets_extended, torch.tensor(0),
                                 torch.tensor(i)),
                          axes_i=[0])
            end_ = g.op("Unsqueeze",
                        select(g, offsets_extended, torch.tensor(0),
                               torch.tensor(i + 1)),
                        axes_i=[0])
            axes_ = g.op("Constant", value_t=torch.tensor([0]))
            indices_row = g.op("Slice", indices, start_, end_, axes_)

            embeddings = g.op("Gather", embedding_matrix, indices_row)
            if not sym_help._is_none(per_sample_weights):
                per_sample_weights_row = g.op("Slice", per_sample_weights,
                                              start_, end_, axes_)
                per_sample_weights_row = g.op("Unsqueeze",
                                              per_sample_weights_row,
                                              axes_i=[1])
                embeddings = g.op("Mul", embeddings, per_sample_weights_row)
            if mode == 0:
                embeddings = g.op("ReduceSum",
                                  embeddings,
                                  axes_i=[0],
                                  keepdims_i=0)
            elif mode == 1:
                embeddings = g.op("ReduceMean",
                                  embeddings,
                                  axes_i=[0],
                                  keepdims_i=0)
            else:
                embeddings = g.op("ReduceMax",
                                  embeddings,
                                  axes_i=[0],
                                  keepdims_i=0)

            embeddings = g.op("Unsqueeze", embeddings, axes_i=[0])
            list_.append(embeddings)

        output = g.op("Concat", *list_, axis_i=0)
        # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
        # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
        return output, None, None, None
    else:
        return sym_help._onnx_unsupported(
            'embedding_bag with unknown shape of indices')
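
The fixed-offsets branch unrolls each bag into Slice -> Gather -> Reduce over [offsets[i], offsets[i+1]), with maxsize standing in for "end of indices" on the last bag. A minimal eager check of that per-bag semantics (illustrative):

import torch

weight = torch.randn(10, 3)
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])   # bag 0 = indices[0:4], bag 1 = indices[4:8]

bag = torch.nn.EmbeddingBag.from_pretrained(weight, mode="sum")
out = bag(indices, offsets)

bounds = list(offsets[1:]) + [len(indices)]
rows = [weight[indices[s:e]].sum(dim=0) for s, e in zip(offsets, bounds)]
assert torch.allclose(out, torch.stack(rows))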