Example #1
def matmul(left, right, transpose_a=False, transpose_b=False, name=None):
    """
    Only support 2d matmul for now.
    """
    # Transpose
    if transpose_a:
        left = ng.Transpose(left)
    if transpose_b:
        right = ng.Transpose(right)

    # Check shape
    assert len(left.axes) == len(right.axes) == 2
    assert left.axes[1].length == right.axes[0].length

    # step 1: cast left (pos_1, pos_0), right (pos_1, pos_0) =>
    #              left (temp , pos_1), right (pos_1, pos_0)
    # step 2: perform left dot right, result
    #         (temp, pos_0)
    # step 3: cast back to (pos_1, pos_0)
    left_temp_axes = ng.make_axes(
        [ng.make_axis(left.axes[0].length), right.axes[0]])
    left = ng.cast_axes(left, axes=left_temp_axes)

    # Result op: contract over the shared axis, then cast back to
    # positional axes before naming the final op
    result_op = ng.dot(left, right)
    result_op = cast_to_pos_axes(result_op)

    # Return
    return result_op.named(name)
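A minimal usage sketch for matmul (the axis lengths below are illustrative; ng.make_axis, ng.make_axes, and ng.placeholder are standard ngraph API calls):

import ngraph as ng

N = ng.make_axis(length=4)
K = ng.make_axis(length=3)
M = ng.make_axis(length=5)

a = ng.placeholder(ng.make_axes([N, K]))  # shape (4, 3)
b = ng.placeholder(ng.make_axes([K, M]))  # shape (3, 5)
y = matmul(a, b, name='y')                # shape (4, 5)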
Example #2
def Gemm(onnx_node, ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    # Y = alpha * (A @ B) + beta * C
    input_a, input_b, input_c = ng_inputs
    alpha = onnx_node.get_attribute_value('alpha', 1)  # Scalar multiplier for A @ B
    beta = onnx_node.get_attribute_value('beta', 1)  # Scalar multiplier for input tensor C
    broadcast = onnx_node.get_attribute_value('broadcast', 1)  # Should C be broadcast?
    trans_a = onnx_node.get_attribute_value('transA', False)  # Should A be transposed?
    trans_b = onnx_node.get_attribute_value('transB', False)  # Should B be transposed?

    if not broadcast:
        logger.warning(
            'Gemm node (%s): import does not support broadcast value %s',
            onnx_node.name, broadcast)

    if trans_a:
        input_a = ng.Transpose(input_a)

    if trans_b:
        input_b = ng.Transpose(input_b)

    input_a, input_b = cast_axes_for_matmul(input_a, input_b)
    a_dot_b = ng.dot(input_a, input_b)
    a_dot_b = cast_to_pos_axes(a_dot_b)
    return alpha * a_dot_b + beta * input_c
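As a plain NumPy reference (a sketch of the ONNX Gemm semantics handled above, not part of the importer):

import numpy as np

def gemm_reference(a, b, c, alpha=1.0, beta=1.0, trans_a=False, trans_b=False):
    # Y = alpha * (A @ B) + beta * C, with optional input transposes
    if trans_a:
        a = a.T
    if trans_b:
        b = b.T
    return alpha * (a @ b) + beta * c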
Example #3
def BatchNormalization(onnx_node,
                       ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    x, scale, bias, mean, var = ng_inputs

    is_test = onnx_node.get_attribute_value('is_test', 1)
    spatial = onnx_node.get_attribute_value('spatial', 1)
    epsilon = onnx_node.get_attribute_value('epsilon', 1e-3)
    # @TODO: Implement learning mode support
    # momentum = onnx_node.get_attribute_value('momentum', 0.99)

    if not is_test:
        raise NotImplementedError(
            'BatchNormalization node (%s): only `is_test` mode is currently '
            'supported.' % onnx_node.name)
    if not spatial:
        raise NotImplementedError(
            'BatchNormalization node (%s): only `spatial` mode is currently '
            'supported.' % onnx_node.name)

    if len(x.axes) == 5:
        x = rename_axes(x, 'NCHWD')
    else:
        x = rename_axes(x, 'NCHW')

    mean = rename_axes(mean, 'C')
    scale = rename_axes(scale, 'C')
    bias = rename_axes(bias, 'C')
    var = rename_axes(var, 'C')

    # y = scale * (x - mean) / sqrt(var + epsilon) + bias
    ng_op = ng.unflatten(scale *
                         ((x - mean) * ng.reciprocal(ng.sqrt(var + epsilon))) +
                         bias)

    return cast_to_pos_axes(ng_op)
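The same inference-mode formula as a NumPy sketch (assumes NCHW(D) layout with per-channel statistics; not part of the importer):

import numpy as np

def batch_norm_reference(x, scale, bias, mean, var, epsilon=1e-3):
    # Broadcast the per-channel vectors over N, H, W (and D if present)
    shape = (1, -1) + (1,) * (x.ndim - 2)
    return (scale.reshape(shape) * (x - mean.reshape(shape))
            / np.sqrt(var.reshape(shape) + epsilon) + bias.reshape(shape))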
Example #4
def Slice(onnx_node, ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Produce a slice of the input tensor along multiple axes."""
    x = ng_inputs[0]

    starts = onnx_node.get_attribute_value('starts')
    ends = onnx_node.get_attribute_value('ends')
    if not (starts and ends and len(starts) == len(ends)):
        raise ValueError(
            'Slice node (%s): attributes `starts` and `ends` must be set '
            'and of equal length.' % onnx_node.name)

    axes = onnx_node.get_attribute_value('axes', list(range(len(starts))))
    slices_count = max(len(axes), *starts)
    if slices_count > len(x.axes):
        raise ValueError(
            'Slice node (%s): specifies %d slices, but there are only %d '
            'input axes.' % (onnx_node.name, slices_count, len(x.axes)))

    slices = [
        slice(starts[axes.index(axis_number)], ends[axes.index(axis_number)])
        if (axis_number in axes) else slice(None)
        for axis_number in range(len(x.axes))
    ]

    return cast_to_pos_axes(ng.tensor_slice(x, slices))
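To make the slice-building comprehension concrete, here is what it produces for a small case (pure Python, runnable on its own):

starts, ends, axes, rank = [1], [3], [1], 3
slices = [slice(starts[axes.index(i)], ends[axes.index(i)])
          if i in axes else slice(None)
          for i in range(rank)]
assert slices == [slice(None), slice(1, 3), slice(None)]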
Example #5
def Flatten(onnx_node, ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Flatten the input tensor into a 2D matrix."""
    data = ng_inputs[0]
    axis = onnx_node.get_attribute_value('axis', 1)

    if axis < 0 or axis > len(data.axes):
        raise ValueError(
            'Flatten node (%s): %d is not a valid value for `axis`.'
            % (onnx_node.name, axis))

    return cast_to_pos_axes(ng.flatten_at(data, axis))
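A NumPy sketch of the ONNX Flatten semantics implemented by ng.flatten_at here: collapse all dimensions before axis into one, and all dimensions from axis onward into another:

import numpy as np

def flatten_reference(data, axis=1):
    new_shape = (int(np.prod(data.shape[:axis], dtype=int)),
                 int(np.prod(data.shape[axis:], dtype=int)))
    return data.reshape(new_shape)

x = np.zeros((2, 3, 4, 5))
assert flatten_reference(x, axis=1).shape == (2, 60)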
Example #6
def Pad(onnx_node, ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    pads = onnx_node.get_attribute_value('pads')
    mode = onnx_node.get_attribute_value('mode', 'constant')  # 'constant', 'reflect' or 'edge'
    value = onnx_node.get_attribute_value('value', 0)

    if mode != 'constant' or value != 0:
        raise NotImplementedError(
            'Pad node (%s): only constant padding with value=0 '
            'is supported.' % onnx_node.name)

    # Split the flat ONNX padding list into (begin, end) pairs, one per axis
    pads = list(split_pads_into_pairs(pads))
    return cast_to_pos_axes(ng.pad(ng_inputs[0], pads))
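For context, a sketch of what split_pads_into_pairs presumably does (ONNX packs pads as all begin values followed by all end values, while ng.pad wants one (begin, end) pair per axis; the actual helper may differ):

def split_pads_into_pairs_sketch(pads):
    half = len(pads) // 2
    return list(zip(pads[:half], pads[half:]))

assert split_pads_into_pairs_sketch([0, 1, 0, 2]) == [(0, 0), (1, 2)]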
Example #7
def _reduction(input_tensor, ng_op, axis=None, keep_dims=False, name=None):
    """
    Args:
        axis: int or list of ints
    """
    if keep_dims:
        raise NotImplementedError("ngraph only support keep_dims=True now.")

    if axis is None:
        ng_reduction_axes = input_tensor.axes
    else:
        try:
            iter(axis)
        except TypeError:
            axis = list(axis)
        ng_reduction_axes = ng.make_axes(
            [input_tensor.axes[ind] for ind in axis])
    res = ng_op(input_tensor, reduction_axes=ng_reduction_axes)
    return cast_to_pos_axes(res).named(name)
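A usage sketch (assumes ng.sum as the reduction op, which accepts reduction_axes as shown in the wrapper):

import ngraph as ng

N = ng.make_axis(length=8)
C = ng.make_axis(length=3)
x = ng.placeholder(ng.make_axes([N, C]))

# Reduce over axis 0 only; `axis` may also be a bare int or None (all axes)
col_sums = _reduction(x, ng.sum, axis=[0], name='col_sums')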
Example #8
def Transpose(onnx_node,
              ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Transpose the input tensor similar to numpy.transpose.

    By default, reverse the dimensions, but if `perm` attribute is specified
    permute the axes according to the values given.
    """
    data = ng_inputs[0]
    permute_axes = onnx_node.get_attribute_value('perm')

    if permute_axes:
        input_template = ''.join(
            [ascii_letters[i] for i in range(len(data.axes))])
        output_template = ''.join([ascii_letters[i] for i in permute_axes])
        ng_op = reorder_axes(data, input_template, output_template)
    else:
        ng_op = ng.Transpose(data)

    return cast_to_pos_axes(ng_op)
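The template strings are built like einsum subscripts; for example (pure Python):

from string import ascii_letters

perm = [0, 2, 3, 1]
rank = 4
input_template = ''.join(ascii_letters[i] for i in range(rank))
output_template = ''.join(ascii_letters[i] for i in perm)
assert (input_template, output_template) == ('abcd', 'acdb')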
Example #9
def sparse_softmax_cross_entropy_with_logits(labels=None,
                                             logits=None,
                                             name=None):
    """
    Computes sparse softmax cross entropy. The inputs `logits` are unscaled
    log probabilities, and each entry of `labels` must be a valid class index
    into the Y axis of `logits`.

    Args:
        labels: of axis (N,) for (POS_0,)
        logits: of axis (N, Y) for (POS_1, POS_0)
        name: name of the ngraph op
    """
    # Check input dimension
    #         (    N,     Y),         (    N)
    # logits: (pos_1, pos_0), labels: (pos_0)
    try:
        assert len(logits.axes) == 2
        assert len(labels.axes) == 1
        assert logits.axes[0].length == labels.axes[0].length
    except AssertionError:
        raise NotImplementedError("logits' shape must be (N, Y), "
                                  "labels' shape must be (N,), "
                                  "other shapes not supported yet.")
    # get axis
    axis_n, axis_y = logits.axes

    # convert labels to one-hot labels
    labels = ng.cast_axes(labels, ng.make_axes(axis_n))
    labels = ng.one_hot(labels, axis=axis_y)
    labels = ng.axes_with_order(labels, axes=logits.axes)

    # predicts: (N, Y)
    predicts = ng.softmax(logits, normalization_axes=axis_y)

    # cross_entropy: (N)
    res = ng.cross_entropy_multi(predicts, labels, out_axes=(axis_n, ))
    return cast_to_pos_axes(res).named(name)
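A NumPy reference for the same computation (a numerically-stabilized sketch, not the ngraph path above):

import numpy as np

def sparse_xent_reference(labels, logits):
    # Log-softmax over the class axis, then negative log-likelihood
    # of the integer labels; returns one loss value per row
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]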
Example #10
def make_convolution_op(onnx_node, ng_inputs, transpose=False):
    # type: (NodeWrapper, List[TensorOp], bool) -> Op
    """
    Create an ngraph convolution or deconvolution Op based on an ONNX node.

    :param onnx_node: wrapped ONNX node for Conv or ConvTranspose op
    :param ng_inputs: ngraph TensorOp input tensors
    :param transpose: should this be a transposed convolution?
    :return: ngraph Op for convolution or deconvolution
    """
    if len(ng_inputs) == 3:
        x, weights, bias = ng_inputs
    elif len(ng_inputs) == 2:
        x, weights = ng_inputs
        bias = ng.constant(0)
    else:
        raise ValueError(
            'Conv node (%s): unexpected number of input values: %d.'
            % (onnx_node.name, len(ng_inputs)))

    # Reorder x axes from ONNX convention (N, C, H, W, D) to ngraph (C, D, H, W, N)
    # Reorder weights axes from ONNX (K, J, R, S, T) to ngraph (J, T, R, S, K)
    # Axis names follow https://ngraph.nervanasys.com/index.html/axes.html
    if len(x.axes) == 4:  # 2D convolution
        x = reorder_axes(x, 'NCHW', 'CDHWN')
        weights = reorder_axes(weights, 'KJRS', 'JTRSK')
    elif len(x.axes) == 5:  # 3D convolution
        x = reorder_axes(x, 'NCHWD', 'CDHWN')
        weights = reorder_axes(weights, 'KJRST', 'JTRSK')
    else:
        raise NotImplementedError(
            'Conv node (%s): only 2D and 3D convolutions are supported.'
            % onnx_node.name)

    groups = onnx_node.get_attribute_value('group', 1)
    if groups != 1:
        raise NotImplementedError(
            'Conv node (%s): `group` attribute value %d not supported.'
            % (onnx_node.name, groups))

    # Prepare ngraph convolution operation
    conv_params = get_conv_params(onnx_node)
    output_axes = make_conv_output_axes(x, weights, conv_params)

    if transpose:
        conv = ng.deconvolution(conv_params, x, weights, axes=output_axes)
    else:
        conv = ng.convolution(conv_params, x, weights, axes=output_axes)

    conv = cast_to_pos_axes(conv) + bias

    # ONNX output should have axes in the order N, C, H, W, D
    conv = reorder_axes(conv, 'CDHWN', 'NCHWD')

    # 2D convolution: slice away the dummy D axis from the output
    if len(ng_inputs[0].axes) == 4:
        conv = ng.tensor_slice(
            conv, [slice(None), slice(None), slice(None), slice(None), 0])

    return conv
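The output spatial sizes computed by make_conv_output_axes follow the standard convolution arithmetic; a self-contained sketch (the actual helper may differ in details):

def conv_output_length(input_len, kernel_len, pad_begin, pad_end,
                       stride, dilation=1):
    effective_kernel = (kernel_len - 1) * dilation + 1
    return (input_len + pad_begin + pad_end - effective_kernel) // stride + 1

assert conv_output_length(32, 3, 1, 1, stride=1) == 32  # 'same'-style padding
assert conv_output_length(32, 3, 0, 0, stride=2) == 15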
Example #11
def Constant(onnx_node,
             ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Produce a constant tensor from the node's `value` attribute."""
    value_tensor = onnx_node.get_attribute_value('value')
    return cast_to_pos_axes(ng.constant(value_tensor.to_array()))
Example #12
def GlobalAveragePool(onnx_node,
                      ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Equivalent to AveragePool with kernel size equal to spatial dimensions of input tensor."""
    return cast_to_pos_axes(make_global_pooling_op(onnx_node, ng_inputs))
Example #13
def MaxPool(onnx_node, ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Apply max pooling to the input tensor."""
    return cast_to_pos_axes(make_pooling_op(onnx_node, ng_inputs))
Example #14
def ConvTranspose(onnx_node,
                  ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Create a transposed convolution (deconvolution) operation."""
    return cast_to_pos_axes(
        make_convolution_op(onnx_node, ng_inputs, transpose=True))
Example #15
def MatMul(onnx_node, ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Compute the matrix product of the two input tensors."""
    left, right = cast_axes_for_matmul(*ng_inputs)
    return cast_to_pos_axes(ng.dot(left, right))