def matmul(left, right, transpose_a=False, transpose_b=False, name=None):
    """
    Multiply two 2-D tensors (``left @ right``). Only 2-D matmul is
    supported for now.

    Args:
        left: 2-D ngraph tensor.
        right: 2-D ngraph tensor.
        transpose_a: transpose ``left`` before multiplying.
        transpose_b: transpose ``right`` before multiplying.
        name: name given to the resulting op.

    Returns:
        The product, cast back to positional axes and named ``name``.
    """
    # Transpose
    if transpose_a:
        left = ng.Transpose(left)
    if transpose_b:
        right = ng.Transpose(right)

    # Check shape
    assert len(left.axes) == len(right.axes) == 2
    assert left.axes[1].length == right.axes[0].length

    # step 1: cast left (pos_1, pos_0), right (pos_1, pos_0) =>
    #         left (temp, pos_1), right (pos_1, pos_0)
    #         i.e. left's second axis becomes the same Axis as right's first,
    #         so that shared axis is the one contracted by ng.dot.
    # step 2: perform left dot right, result (temp, pos_0)
    # step 3: cast back to (pos_1, pos_0)
    left_temp_axes = ng.make_axes(
        [ng.make_axis(left.axes[0].length), right.axes[0]])
    left = ng.cast_axes(left, axes=left_temp_axes)

    # Only the final op is named; naming the intermediate dot as well would
    # give two distinct graph ops the same name.
    result_op = cast_to_pos_axes(ng.dot(left, right))
    return result_op.named(name)
def Gemm(onnx_node, ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Import an ONNX Gemm node: ``Y = alpha * (A @ B) + beta * C``.

    `transA` / `transB` transpose A / B before the product. A `broadcast`
    value of 0 only produces a warning; it is not otherwise honoured.
    """
    input_a, input_b, input_c = ng_inputs
    attr = onnx_node.get_attribute_value

    alpha = attr('alpha', 1)  # scalar multiplier for A @ B
    beta = attr('beta', 1)  # scalar multiplier for input tensor C
    broadcast = attr('broadcast', 1)  # should C be broadcast?
    trans_a = attr('transA', False)  # should A be transposed?
    trans_b = attr('transB', False)  # should B be transposed?

    if not broadcast:
        logger.warning(
            'Gemm node (%s): import does not support broadcast value %s',
            onnx_node.name, broadcast)

    input_a = ng.Transpose(input_a) if trans_a else input_a
    input_b = ng.Transpose(input_b) if trans_b else input_b

    lhs, rhs = cast_axes_for_matmul(input_a, input_b)
    product = cast_to_pos_axes(ng.dot(lhs, rhs))
    return alpha * product + beta * input_c
def BatchNormalization(onnx_node, ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Import an ONNX BatchNormalization node (inference, spatial mode only).

    Computes ``scale * (x - mean) / sqrt(var + epsilon) + bias``. The 1-D
    ``scale``, ``bias``, ``mean`` and ``var`` inputs are renamed to the ``C``
    axis so they line up with the channel axis of ``x``.

    :raises NotImplementedError: if `is_test` or `spatial` is disabled.
    """
    x, scale, bias, mean, var = ng_inputs

    is_test = onnx_node.get_attribute_value('is_test', 1)
    spatial = onnx_node.get_attribute_value('spatial', 1)
    # The ONNX specification gives epsilon a default of 1e-5.
    epsilon = onnx_node.get_attribute_value('epsilon', 1e-5)

    # TODO: Implement training mode support (would use the `momentum` attribute).
    if not is_test:
        raise NotImplementedError(
            'BatchNormalization node (%s): only `is_test` mode is currently '
            'supported.' % onnx_node.name)
    if not spatial:
        raise NotImplementedError(
            'BatchNormalization node (%s): only `spatial` mode is currently '
            'supported.' % onnx_node.name)

    if len(x.axes) == 5:
        x = rename_axes(x, 'NCHWD')
    else:
        x = rename_axes(x, 'NCHW')

    mean = rename_axes(mean, 'C')
    scale = rename_axes(scale, 'C')
    bias = rename_axes(bias, 'C')
    var = rename_axes(var, 'C')

    ng_op = ng.unflatten(
        scale * ((x - mean) * ng.reciprocal(ng.sqrt(var + epsilon))) + bias)

    return cast_to_pos_axes(ng_op)
def Slice(onnx_node, ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Produce a slice of the input tensor along multiple axes.

    Attributes read from the node:
        starts: start index for each sliced axis.
        ends: end index for each sliced axis.
        axes: optional axis numbers that `starts` / `ends` refer to; defaults
              to the leading ``len(starts)`` axes.

    Axes not listed in `axes` are kept whole.

    :raises ValueError: if `starts` / `ends` are missing or of unequal length,
                        if `axes` has a different length, or if an axis number
                        is beyond the input's rank.
    """
    x = ng_inputs[0]

    starts = onnx_node.get_attribute_value('starts')
    ends = onnx_node.get_attribute_value('ends')
    if not (starts and ends and len(starts) == len(ends)):
        raise ValueError(
            'Slice node (%s): attributes `starts` and `ends` must be set '
            'and of equal length.' % onnx_node.name)

    axes = onnx_node.get_attribute_value('axes', list(range(len(starts))))
    if len(axes) != len(starts):
        raise ValueError(
            'Slice node (%s): attribute `axes` must have the same length as '
            '`starts` and `ends`.' % onnx_node.name)

    # Number of leading input axes the requested slices need. This depends on
    # the axis numbers, not on the start offsets.
    slices_count = max(len(axes), max(axes) + 1)
    if slices_count > len(x.axes):
        raise ValueError(
            'Slice node (%s): specifies %d slices, there are only %d input axes.'
            % (onnx_node.name, slices_count, len(x.axes)))

    # Map axis number -> (start, end); setdefault keeps the first occurrence,
    # matching the original `axes.index` lookup for duplicated axes.
    bounds = {}
    for axis_number, start, end in zip(axes, starts, ends):
        bounds.setdefault(axis_number, (start, end))

    slices = [
        slice(*bounds[axis_number]) if axis_number in bounds else slice(None)
        for axis_number in range(len(x.axes))
    ]
    return cast_to_pos_axes(ng.tensor_slice(x, slices))
def Flatten(onnx_node, ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Flatten the input tensor into a 2D matrix.

    The split point is the `axis` attribute (default 1), which must lie in
    ``[0, rank]``. Per the ONNX Flatten spec, axes before it form the first
    output dimension and the rest form the second.

    :raises ValueError: if `axis` is outside ``[0, rank]``.
    """
    data = ng_inputs[0]
    axis = onnx_node.get_attribute_value('axis', 1)

    if axis < 0 or axis > len(data.axes):
        raise ValueError(
            'Flatten node (%s): %d is not a valid value for `axis`.'
            % (onnx_node.name, axis))

    return cast_to_pos_axes(ng.flatten_at(data, axis))
def Pad(onnx_node, ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Pad the input tensor according to the node's `pads` attribute.

    Only ``mode='constant'`` with ``value=0`` is supported.

    :raises NotImplementedError: for any other mode or padding value.
    """
    pads = onnx_node.get_attribute_value('pads')
    constant = 'constant'
    mode = onnx_node.get_attribute_value('mode', constant)  # 'constant', 'reflect' or 'edge'
    value = onnx_node.get_attribute_value('value', 0)

    if mode != constant or value != 0:
        raise NotImplementedError(
            'Pad node (%s): only constant padding with value=0 '
            'is supported.' % onnx_node.name)

    # Split paddings into pairs for each axis
    pads = list(split_pads_into_pairs(pads))

    return cast_to_pos_axes(ng.pad(ng_inputs[0], pads))
def _reduction(input_tensor, ng_op, axis=None, keep_dims=False, name=None):
    """
    Apply an ngraph reduction op over some or all axes of a tensor.

    Args:
        input_tensor: ngraph tensor to reduce.
        ng_op: ngraph reduction callable accepting ``reduction_axes=``.
        axis: int or list of ints giving the positions of the axes to reduce;
              ``None`` reduces over all axes.
        keep_dims: must be False; keeping reduced dims is not supported.
        name: name given to the resulting op.

    Returns:
        The reduced tensor, cast back to positional axes.

    Raises:
        NotImplementedError: if ``keep_dims`` is True.
    """
    if keep_dims:
        raise NotImplementedError("ngraph only supports keep_dims=False now.")

    if axis is None:
        ng_reduction_axes = input_tensor.axes
    else:
        try:
            iter(axis)
        except TypeError:
            # A single int: wrap it so it can be iterated below
            # (``list(axis)`` would raise TypeError again).
            axis = [axis]
        ng_reduction_axes = ng.make_axes(
            [input_tensor.axes[ind] for ind in axis])

    res = ng_op(input_tensor, reduction_axes=ng_reduction_axes)
    return cast_to_pos_axes(res).named(name)
def Transpose(onnx_node, ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Permute the axes of the input tensor, similar to numpy.transpose.

    Without a (non-empty) `perm` attribute the dimensions are reversed.
    Otherwise output axis ``i`` is taken from input axis ``perm[i]``.
    """
    data = ng_inputs[0]
    perm = onnx_node.get_attribute_value('perm')

    if not perm:
        return cast_to_pos_axes(ng.Transpose(data))

    # Label each input axis with a letter, then spell out the permuted order.
    source_labels = ''.join(ascii_letters[idx] for idx in range(len(data.axes)))
    target_labels = ''.join(ascii_letters[idx] for idx in perm)
    permuted = reorder_axes(data, source_labels, target_labels)
    return cast_to_pos_axes(permuted)
def sparse_softmax_cross_entropy_with_logits(labels=None, logits=None,
                                             name=None):
    """
    Computes softmax cross entropy between logits and integer class labels.

    The inputs `logits` are unscaled log probabilities. `labels[i]` is the
    class index of row ``i``; it is converted to a one-hot vector over the
    class axis before the cross entropy is taken.

    Args:
        labels: of axis (N,) for (POS_0,)
        logits: of axis (N, Y) for (POS_1, POS_0)
        name: name of the ngraph op

    Returns:
        Per-example cross entropy with axis (N,), cast to positional axes.

    Raises:
        NotImplementedError: if logits is not (N, Y) or labels is not (N,).
    """
    # Check input dimension: logits (N, Y) = (pos_1, pos_0), labels (N,).
    # An explicit check is used because `assert` is stripped under `python -O`.
    if not (len(logits.axes) == 2 and len(labels.axes) == 1
            and logits.axes[0].length == labels.axes[0].length):
        raise NotImplementedError("logits' shape must be (N, Y), "
                                  "labels' shape must be (N,), "
                                  "other shapes not supported yet.")

    # get axis
    axis_n, axis_y = logits.axes

    # convert labels to one-hot labels
    labels = ng.cast_axes(labels, ng.make_axes(axis_n))
    labels = ng.one_hot(labels, axis=axis_y)
    labels = ng.axes_with_order(labels, axes=logits.axes)

    # predicts: (N, Y)
    predicts = ng.softmax(logits, normalization_axes=axis_y)

    # cross_entropy: (N)
    res = ng.cross_entropy_multi(predicts, labels, out_axes=(axis_n, ))
    return cast_to_pos_axes(res).named(name)
def make_convolution_op(onnx_node, ng_inputs, transpose=False):  # type: (NodeWrapper, List[TensorOp], bool) -> Op
    """
    Create an ngraph convolution or deconvolution Op based on an ONNX node.

    :param onnx_node: wrapped ONNX node for Conv or ConvTranspose op
    :param ng_inputs: ngraph TensorOp input tensors (x, weights[, bias])
    :param transpose: should this be a transposed convolution?
    :return: ngraph Op for convolution or deconvolution
    :raises ValueError: if the node does not have 2 or 3 inputs
    :raises NotImplementedError: for inputs that are not 4-D or 5-D,
                                 or a `group` value other than 1
    """
    if len(ng_inputs) == 3:
        x, weights, bias = ng_inputs
    elif len(ng_inputs) == 2:
        x, weights = ng_inputs
        bias = ng.constant(0)
    else:
        raise ValueError(
            'Conv node (%s): unexpected number of input values: %d.'
            % (onnx_node.name, len(ng_inputs)))

    # Reorder x axes from ONNX convention (N, C, H, W, D) to ngraph (C, D, H, W, N)
    # Reorder weights axes from ONNX (K, J, R, S, T) to ngraph (J, T, R, S, K)
    # Axis names follow https://ngraph.nervanasys.com/index.html/axes.html
    is_2d = len(x.axes) == 4
    if is_2d:  # 2D convolution
        x = reorder_axes(x, 'NCHW', 'CDHWN')
        weights = reorder_axes(weights, 'KJRS', 'JTRSK')
    elif len(x.axes) == 5:  # 3D convolution
        x = reorder_axes(x, 'NCHWD', 'CDHWN')
        weights = reorder_axes(weights, 'KJRST', 'JTRSK')
    else:
        raise NotImplementedError(
            'Conv node (%s): only 2D and 3D convolutions are supported.'
            % onnx_node.name)

    groups = onnx_node.get_attribute_value('group', 1)
    if groups != 1:
        raise NotImplementedError(
            'Conv node (%s): `group` attribute value %d not supported.'
            % (onnx_node.name, groups))

    # Prepare ngraph convolution operation
    conv_params = get_conv_params(onnx_node)
    output_axes = make_conv_output_axes(x, weights, conv_params)

    if transpose:
        conv = ng.deconvolution(conv_params, x, weights, axes=output_axes)
    else:
        conv = ng.convolution(conv_params, x, weights, axes=output_axes)

    conv = cast_to_pos_axes(conv) + bias

    # ONNX output should have axes in the order N, C, H, W, D
    conv = reorder_axes(conv, 'CDHWN', 'NCHWD')

    if is_2d:
        # 2D convolution: slice away the D axis introduced by the reorder above
        conv = ng.tensor_slice(conv, [slice(None)] * 4 + [0])

    return conv
def Constant(onnx_node, ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Build an ngraph constant from the node's `value` tensor attribute."""
    tensor_proto = onnx_node.get_attribute_value('value')
    array = tensor_proto.to_array()
    return cast_to_pos_axes(ng.constant(array))
def GlobalAveragePool(onnx_node, ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Equivalent to AveragePool with kernel size equal to spatial dimensions of input tensor."""
    pooled = make_global_pooling_op(onnx_node, ng_inputs)
    return cast_to_pos_axes(pooled)
def MaxPool(onnx_node, ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Import an ONNX MaxPool node via the shared pooling-op builder."""
    pooled = make_pooling_op(onnx_node, ng_inputs)
    return cast_to_pos_axes(pooled)
def ConvTranspose(onnx_node, ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Import an ONNX ConvTranspose node as an ngraph deconvolution."""
    deconv = make_convolution_op(onnx_node, ng_inputs, transpose=True)
    return cast_to_pos_axes(deconv)
def MatMul(onnx_node, ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Import an ONNX MatMul node as an ngraph dot product of its two inputs."""
    lhs, rhs = cast_axes_for_matmul(*ng_inputs)
    product = ng.dot(lhs, rhs)
    return cast_to_pos_axes(product)