Example #1
0
def Split(onnx_node, ng_inputs):  # type: (NodeWrapper, List[NgraphNode]) -> Tuple[NgraphNode, ...]
    """Split a tensor into a tuple of tensors along a single axis.

    The axis comes from the node's `axis` attribute (default 0). When the `split` attribute is
    present it lists the length of every part; otherwise the axis is divided evenly between the
    node's outputs.

    :param onnx_node: wrapper of the Split node; provides attributes, output names and the name
    :param ng_inputs: input nodes; only the first one (the tensor to split) is used
    :return: one slice of the input per part, ordered along the split axis
    :raises ValueError: if the axis is negative or not smaller than the input rank, if the axis
        length is not divisible by the number of outputs, or if the `split` lengths do not sum
        up to the axis length
    """
    data = ng_inputs[0]
    count_outputs = len(onnx_node.get_output_names())
    axis_to_split = onnx_node.get_attribute_value('axis', 0)

    if axis_to_split < 0 or axis_to_split >= len(data.shape):
        # Interpolate explicitly: exceptions do not apply %-formatting to extra arguments.
        raise ValueError('Split node (%s) provided split axis is out of input tensor dimensions'
                         ' range.' % onnx_node.name)

    len_axis_to_split = data.shape[axis_to_split]
    len_parts = onnx_node.get_attribute_value('split')

    if len_parts is None:
        if len_axis_to_split % count_outputs:
            raise ValueError('Split node (%s): Tensor cannot be split into %d equal parts, along '
                             'axis of length %d'
                             % (onnx_node.name, count_outputs, len_axis_to_split))
        # Floor division keeps the computation in exact integer arithmetic.
        len_parts = [len_axis_to_split // count_outputs] * count_outputs
    elif sum(len_parts) != len_axis_to_split:
        raise ValueError('Split node (%s): provided lengths of split parts does not sum up to '
                         'length of axis we split on: %d != %d'
                         % (onnx_node.name, sum(len_parts), len_axis_to_split))

    outputs = []
    start_index = 0

    # Consecutive parts: each slice starts where the previous one ended.
    for len_part in len_parts:
        end_index = start_index + len_part
        outputs.append(make_slice_op(data, [axis_to_split], [start_index], [end_index]))
        start_index = end_index

    return tuple(outputs)
Example #2
0
def Slice(onnx_node, ng_inputs):  # type: (NodeWrapper, List[NgraphNode]) -> NgraphNode
    """Produce a slice of the input tensor along multiple axes.

    The `starts` and `ends` attributes are mandatory and must have equal length. When the
    optional `axes` attribute is absent, the leading `len(starts)` axes are used.

    :param onnx_node: wrapper of the Slice node; provides the attributes and the node name
    :param ng_inputs: input nodes; only the first one (the tensor to slice) is used
    :return: result of `make_slice_op` applied to the input with the resolved axes and bounds
    :raises ValueError: if `starts` or `ends` is missing or empty, if their lengths differ, or
        if any explicitly given axis is negative or not smaller than the input rank
    """
    input_node = ng_inputs[0]

    starts = onnx_node.get_attribute_value('starts')
    ends = onnx_node.get_attribute_value('ends')
    if not (starts and ends and len(starts) == len(ends)):
        # Interpolate explicitly: exceptions do not apply %-formatting to extra arguments.
        raise ValueError('Slice node (%s): attributes `starts` and `ends` must be set '
                         'and of equal length.' % onnx_node.name)

    axes = onnx_node.get_attribute_value('axes')
    if axes is None:
        # No explicit axes: slice the leading axes, one per (start, end) pair.
        axes = list(range(len(starts)))
    else:
        rank = len(input_node.shape)
        if any(not 0 <= axis < rank for axis in axes):
            raise ValueError('Slice node (%s): specified axes are out of node\' dimensions '
                             'bounds' % onnx_node.name)

    return make_slice_op(input_node, axes, starts, ends)