Example #1
`_flatten_helper` builds the shape argument for a flatten at export time: it slices the input's runtime `Shape` around the flattened span and splices in a `-1` wildcard so `Reshape` can infer the collapsed size.
def _flatten_helper(g, input, start_dim, end_dim, dim):
    # Runtime shape of the input as a 1-D int64 tensor.
    input_size = g.op("Shape", input)
    # Dimensions before start_dim are kept unchanged.
    slice1 = _slice_helper(g,
                           input_size,
                           axes=[0],
                           starts=[0],
                           ends=[start_dim])
    # A -1 wildcard stands in for the flattened span; Reshape infers its size.
    slices = [
        slice1,
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long))
    ]
    if end_dim < dim - 1:
        # Dimensions after end_dim are also kept unchanged.
        slice3 = _slice_helper(g,
                               input_size,
                               axes=[0],
                               starts=[end_dim + 1],
                               ends=[dim])
        slices = [
            slice1,
            g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
            slice3
        ]

    # Concatenate the pieces into the final shape tensor and reshape.
    final_shape = g.op("Concat", *slices, axis_i=0)
    from torch.onnx.symbolic_opset9 import _reshape_from_tensor
    return _reshape_from_tensor(g, input, final_shape)
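
To see what the emitted Slice → Constant(-1) → Concat subgraph computes, here is a minimal plain-Python mirror of the shape logic (the name flatten_shape is mine, not part of PyTorch):

import numpy as np

def flatten_shape(input_shape, start_dim, end_dim):
    # Mirrors slice1 / the -1 wildcard / slice3 from _flatten_helper.
    dim = len(input_shape)
    head = list(input_shape[:start_dim])                                  # slice1
    tail = list(input_shape[end_dim + 1:]) if end_dim < dim - 1 else []   # slice3
    return head + [-1] + tail

x = np.zeros((2, 3, 4, 5))
print(flatten_shape(x.shape, 1, 2))                   # [2, -1, 5]
print(x.reshape(flatten_shape(x.shape, 1, 2)).shape)  # (2, 12, 5)
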
Example #2
`flatten` is the symbolic function behind `torch.flatten`: cases whose output is 2-D map straight onto ONNX's `Flatten` operator, and everything else falls back to a constant-shape `Reshape`, which requires the input's static shape to be known at export time.
def flatten(g, input, start_dim, end_dim):
    dim = input.type().dim()
    # Use ONNX's Flatten operator for cases where the output shape is 2D.
    if start_dim == 1:
        if (end_dim == -1 or (end_dim is not None and end_dim == dim - 1)):
            return g.op("Flatten", input, axis_i=start_dim)
    elif start_dim == 0:
        if (end_dim == -2 or (end_dim is not None and end_dim == dim - 2)):
            return g.op("Flatten", input, axis_i=end_dim + 1)
    # Use Reshape for cases where the output shape is not 2D; this path
    # needs the full static shape, so bail out when it is unknown.
    if not input.isCompleteTensor():
        return _unimplemented(
            "flatten", "input size not accessible "
            "(consider using reshape op instead of flatten op to export to ONNX)"
        )
    # Normalize a negative end_dim to an absolute index.
    if end_dim < 0:
        end_dim = dim + end_dim
    input_dims = input.type().sizes()
    output_dims = []
    for i in range(0, dim):
        if start_dim < i and end_dim >= i:
            # Fold every dimension in (start_dim, end_dim] into the
            # entry appended at index start_dim.
            output_dims[start_dim] = output_dims[start_dim] * input_dims[i]
        else:
            output_dims.append(input_dims[i])
    shape = g.op("Constant", value_t=torch.LongTensor(output_dims))
    from torch.onnx.symbolic_opset9 import _reshape_from_tensor
    return _reshape_from_tensor(g, input, shape)
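
The dimension-folding loop is the subtle part. A standalone sketch (the helper name folded_dims is mine) shows the shapes it produces:

def folded_dims(input_dims, start_dim, end_dim):
    # Plain-Python copy of the loop above; start_dim < i <= end_dim is
    # equivalent to `start_dim < i and end_dim >= i`.
    dim = len(input_dims)
    if end_dim < 0:
        end_dim = dim + end_dim
    output_dims = []
    for i in range(dim):
        if start_dim < i <= end_dim:
            output_dims[start_dim] *= input_dims[i]
        else:
            output_dims.append(input_dims[i])
    return output_dims

print(folded_dims([2, 3, 4, 5], 1, -1))  # [2, 60]
print(folded_dims([2, 3, 4, 5], 0, 2))   # [24, 5]
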
Example #3
`tensordot` lowers a general contraction to a matrix multiply: the contracted dimensions are permuted to the boundary of each operand, both operands are reshaped to 2-D, contracted with a single einsum("ij,jk->ik"), and the result is reshaped back to the surviving outer dimensions.
def tensordot(g, input_a, input_b, dims_a, dims_b, out=None):
    # sym_help, _unimplemented, permute, einsum, _reshape_from_tensor and
    # maxsize (sys.maxsize) come from the surrounding torch.onnx modules.
    if out is not None:
        _unimplemented("Tensordot", "Out parameter is not supported for tensordot.")

    # Both ranks must be known at export time.
    dim_count_a = sym_help._get_tensor_rank(input_a)
    if dim_count_a is None:
        raise RuntimeError("Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.")

    dim_count_b = sym_help._get_tensor_rank(input_b)
    if dim_count_b is None:
        raise RuntimeError("Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.")

    # Normalize negative dimension indices.
    dims_a = [(dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i] for i in range(len(dims_a))]
    dims_b = [(dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i] for i in range(len(dims_b))]

    # Dimensions that survive the contraction.
    left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)]
    left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)]

    # Move contracted dims to the end of a and to the front of b.
    new_input_a = permute(g, input_a, left_dims_a + dims_a)
    new_input_b = permute(g, input_b, dims_b + left_dims_b)

    # Reshape a in two steps: first keep the outer sizes and collapse the
    # trailing contracted dims into a -1 wildcard ...
    input_shape = g.op("Shape", new_input_a)
    left_sizes_a = sym_help._slice_helper(g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)])
    shape_sizes = [left_sizes_a, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long))]
    output_a = _reshape_from_tensor(g, new_input_a, shape_sizes)

    # ... then read that collapsed size back and flatten the outer dims,
    # giving a 2-D tensor [prod(left), prod(contracted)].
    input_shape = g.op("Shape", output_a)
    slices = sym_help._slice_helper(g, input_shape, axes=[0], starts=[-1], ends=[maxsize])
    shape_sizes = [g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), slices]
    output_a = _reshape_from_tensor(g, new_input_a, shape_sizes)

    # Same two-step reshape for b, giving [prod(contracted), prod(left)].
    input_shape = g.op("Shape", new_input_b)
    left_sizes_b = sym_help._slice_helper(g, input_shape, axes=[0], starts=[len(dims_b)], ends=[maxsize])
    slices = sym_help._slice_helper(g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)])
    shape_sizes = [slices, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long))]
    output_b = _reshape_from_tensor(g, new_input_b, shape_sizes)

    input_shape = g.op("Shape", output_b)
    slices = sym_help._slice_helper(g, input_shape, axes=[0], starts=[-1], ends=[maxsize])
    shape_sizes = [g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), slices]
    output_b = _reshape_from_tensor(g, new_input_b, shape_sizes)

    # The contraction itself is a plain matrix multiply.
    output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b]))

    # Restore the outer dimensions of both operands.
    shape_sizes = [left_sizes_a, left_sizes_b]
    return _reshape_from_tensor(g, output, shape_sizes)
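
A quick way to sanity-check this strategy against numpy's own tensordot is a plain-numpy mirror of the same steps (the name tensordot_via_matmul is mine, not part of PyTorch):

import numpy as np

def tensordot_via_matmul(a, b, dims_a, dims_b):
    # Same plan as the symbolic above: permute contracted dims to the
    # boundary, flatten both operands to 2-D, matmul, restore outer dims.
    dims_a = [d % a.ndim for d in dims_a]
    dims_b = [d % b.ndim for d in dims_b]
    left_a = [i for i in range(a.ndim) if i not in dims_a]
    left_b = [i for i in range(b.ndim) if i not in dims_b]
    a2 = a.transpose(left_a + dims_a).reshape(
        int(np.prod([a.shape[i] for i in left_a], dtype=np.int64)), -1)
    b2 = b.transpose(dims_b + left_b).reshape(
        -1, int(np.prod([b.shape[i] for i in left_b], dtype=np.int64)))
    out = a2 @ b2  # the einsum("ij,jk->ik") step
    return out.reshape([a.shape[i] for i in left_a] + [b.shape[i] for i in left_b])

a = np.random.rand(2, 3, 4)
b = np.random.rand(4, 3, 5)
ref = np.tensordot(a, b, axes=([1, 2], [1, 0]))
assert np.allclose(tensordot_via_matmul(a, b, [1, 2], [1, 0]), ref)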