Пример #1
0
    def cast_axes_for_compound_op(self, inputs):
        """
        Align two tensors' axes so they can be combined element-wise.

        Both inputs are right-aligned and left-padded with new length-1 axes
        up to a common rank and broadcast onto those padded axes. Every
        length-1 axis that faces a longer axis on the other side is then
        indexed away (index 0). Finally, the right operand's remaining axes
        are cast to the positionally matching left axes (when those survive
        slicing), so both results share axis objects where they overlap.

        Arguments:
            inputs: pair ``(left, right)`` of ngraph tensors.

        Returns:
            Tuple ``(left_sliced, right_sliced_casted)``.

        Raises:
            ValueError: if two right-aligned axes have different lengths and
                neither length is 1 (not broadcast-compatible).
        """
        left, right = inputs

        left_dim = len(left.axes)
        right_dim = len(right.axes)
        result_dim = max(left_dim, right_dim)

        # Validate broadcast compatibility on the right-aligned lengths before
        # building any graph ops, so the failure is reported clearly here.
        left_lengths = [1] * (result_dim - left_dim) + [ax.length for ax in left.axes]
        right_lengths = [1] * (result_dim - right_dim) + [ax.length for ax in right.axes]
        for left_len, right_len in zip(left_lengths, right_lengths):
            if left_len != right_len and 1 not in (left_len, right_len):
                raise ValueError(
                    "Incompatible axis lengths for compound op: %d vs %d "
                    "(neither is 1)." % (left_len, right_len))

        # pad left and right axes to the same length, aligned right
        left_axes_pad = [
            ng.make_axis(length=1) for _ in range(result_dim - left_dim)
        ] + list(left.axes)
        right_axes_pad = [
            ng.make_axis(length=1) for _ in range(result_dim - right_dim)
        ] + list(right.axes)

        # broadcast left / right, introducing dummy length 1 axes
        left = ng.broadcast(left, left_axes_pad)
        right = ng.broadcast(right, right_axes_pad)

        # two-way map between positionally matching left / right axes
        lr_axes_map = {}
        for left_axis, right_axis in zip(left.axes, right.axes):
            lr_axes_map[left_axis] = right_axis
            lr_axes_map[right_axis] = left_axis

        # index away a length-1 axis only where the other side is longer
        left_slice = []
        right_slice = []
        for left_axis, right_axis in zip(left.axes, right.axes):
            left_slice.append(
                0 if left_axis.length == 1 and right_axis.length != 1
                else slice(None))
            right_slice.append(
                0 if right_axis.length == 1 and left_axis.length != 1
                else slice(None))

        left_sliced = ng.tensor_slice(left, left_slice)
        right_sliced = ng.tensor_slice(right, right_slice)

        # cast right axes onto their left partners when the partner survived
        right_casted_axes = [
            lr_axes_map[axis]
            if axis in lr_axes_map and lr_axes_map[axis] in left_sliced.axes
            else axis
            for axis in right_sliced.axes
        ]
        right_sliced_casted = ng.cast_axes(right_sliced, right_casted_axes)

        return left_sliced, right_sliced_casted
Пример #2
0
def test_reverse_slice(transformer_factory):
    """Negative-step slices must raise ValueError.

    Covers both ``ng.tensor_slice`` and ``ng.set_item`` on a 10x10
    placeholder.
    """
    row_axis, col_axis = ng.make_axis(length=10), ng.make_axis(length=10)
    placeholder = ng.placeholder([row_axis, col_axis])

    reversed_rows = [slice(0, 10, -1), slice(0, 10)]
    with pytest.raises(ValueError):
        ng.tensor_slice(placeholder, reversed_rows)

    reversed_cols = [0, slice(None, None, -1)]
    with pytest.raises(ValueError):
        ng.set_item(placeholder, reversed_cols, 0)
Пример #3
0
def Slice(onnx_node, ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Produce a slice of the input tensor along multiple axes.

    Reads the ONNX `starts`, `ends` and optional `axes` attributes. When
    `axes` is omitted, `starts[i]` / `ends[i]` apply to input axis `i`.
    Input axes that are not listed are kept whole.

    :param onnx_node: wrapped ONNX Slice node
    :param ng_inputs: ngraph input tensors; only the first one is sliced
    :return: the sliced tensor, passed through `cast_to_pos_axes`
    :raises ValueError: if `starts`/`ends` are missing or of different
        lengths, if `axes` does not have the same length, or if an axis
        index refers past the input's last axis
    """
    x = ng_inputs[0]

    starts = onnx_node.get_attribute_value('starts')
    ends = onnx_node.get_attribute_value('ends')
    if not (starts and ends and len(starts) == len(ends)):
        raise ValueError(
            'Slice node (%s): attributes `starts` and `ends` must be set '
            'and of equal length.' % onnx_node.name)

    axes = onnx_node.get_attribute_value('axes', list(range(len(starts))))
    if len(axes) != len(starts):
        raise ValueError(
            'Slice node (%s): attribute `axes` must have the same length as '
            '`starts` and `ends`.' % onnx_node.name)

    # The highest referenced axis index (not the start offsets) determines
    # how many input axes are required.
    slices_count = max(len(axes), max(axes) + 1)
    if slices_count > len(x.axes):
        raise ValueError(
            'Slice node (%s): specifies %d slices, there are only %d input axes.'
            % (onnx_node.name, slices_count, len(x.axes)))

    # First occurrence wins if an axis is listed twice (as with list.index()).
    bounds = {}
    for axis_number, start, end in zip(axes, starts, ends):
        bounds.setdefault(axis_number, slice(start, end))

    slices = [bounds.get(axis_number, slice(None))
              for axis_number in range(len(x.axes))]

    return cast_to_pos_axes(ng.tensor_slice(x, slices))
Пример #4
0
def test_conv_flatten_deriv(n4_hw12_c3_5x5):
    """
    Check the forward value and the filter/input gradients of a convolution
    whose output has its depth axis indexed away and is then flattened.

    The filter variable lives on fresh "primed" axes that are cast to the
    convolution's filter axes, so gradients also flow through cast_axes,
    expand_dims and axes_with_order.
    """
    cf = ConvParams(**n4_hw12_c3_5x5)

    axes_rsck = ng.make_axes([cf.ax_f[2], cf.ax_f[3], cf.ax_f[0], cf.ax_f[-1]])
    axes_rsck_prime = ng.make_axes(
        [ng.make_axis(name=axis.name + 'p', length=axis.length)
         for axis in axes_rsck])
    axes_nmpqk = ng.make_axes(
        [cf.ax_o[-1], cf.ax_o[1], cf.ax_o[2], cf.ax_o[3], cf.ax_o[0]])

    # input variable and its all-ones value
    input_var = ng.variable(cf.ax_i).named('input')
    input_val = np.ones(input_var.axes.lengths)

    # filter variable on primed axes, mapped onto the CTRSK filter layout
    filter_var = ng.variable(axes_rsck_prime).named('filter')
    filter_rsck = ng.cast_axes(filter_var, axes_rsck).named('frsck')
    filter_trsck = ng.expand_dims(filter_rsck, cf.ax_f[1], 0).named('ftrsck')
    filter_ctrsk = ng.axes_with_order(filter_trsck, axes=cf.ax_f).named('ctrsk')

    output_kmpqn = ng.convolution(cf.conv_params, input_var, filter_ctrsk,
                                  axes=cf.ax_o)
    output_nmpqk = ng.axes_with_order(output_kmpqn, axes=axes_nmpqk)

    # take index 0 of the output depth axis (position 1 in NMPQK), then flatten
    depth_index = [slice(None), 0] + [slice(None)] * 3
    output_npqk = ng.tensor_slice(output_nmpqk, depth_index)
    output = ng.flatten_at(output_npqk, idx=1)

    cost = ng.sum(output, out_axes=())
    filter_val = np.ones(filter_var.axes.lengths)

    with ExecutorFactory() as factory:
        conv_comp = factory.executor(output, filter_var, input_var)
        grad_filter_num_comp = factory.numeric_derivative(cost, filter_var, 1.0,
                                                          input_var)
        grad_filter_sym_comp = factory.derivative(cost, filter_var, input_var)
        grad_input_num_comp = factory.numeric_derivative(cost, input_var, 1.0,
                                                         filter_var)
        grad_input_sym_comp = factory.derivative(cost, input_var, filter_var)

        # expected forward value: product of the filter's non-K axis lengths
        conv_val = conv_comp(filter_val, input_val)
        expected_conv = np.full_like(conv_val, np.prod(cf.ax_f.lengths[:-1]))
        ng.testing.assert_allclose(conv_val, expected_conv)

        ng.testing.assert_allclose(grad_filter_num_comp(filter_val, input_val),
                                   grad_filter_sym_comp(filter_val, input_val))
        ng.testing.assert_allclose(grad_input_num_comp(input_val, filter_val),
                                   grad_input_sym_comp(input_val, filter_val))
Пример #5
0
def test_idempotent_axes_c():
    """
    Autodiff through a broadcast -> slice -> cast -> add -> cast -> reorder
    chain.

    Both operands come from the same (3, 1) variable of ones, so the summed
    result must be 6 and the gradient w.r.t. the variable must be all 2s.
    """
    with ExecutorFactory() as factory:
        axes = ng.make_axes([ng.make_axis(3), ng.make_axis(1)])
        result_axes = [ng.make_axis(length=ax.length) for ax in axes]

        w = ng.variable(axes, initial_value=np.ones((3, 1)))

        # broadcast both operands onto the variable's own axes
        lhs = ng.broadcast(w, axes)
        rhs = ng.broadcast(w, axes)

        # full (no-op) slices on both axes
        full_slice = [slice(None, None, None)] * 2
        lhs_sliced = ng.tensor_slice(lhs, full_slice)
        rhs_sliced = ng.tensor_slice(rhs, full_slice)

        rhs_cast = ng.cast_axes(rhs_sliced, axes)
        total = ng.add(lhs_sliced, rhs_cast)

        # cast onto fresh axes, then dimshuffle into that same order
        total = ng.axes_with_order(ng.cast_axes(total, result_axes), result_axes)

        cost = ng.sum(total, reduction_axes=total.axes)
        grad = ng.deriv(cost, w)

        grad_fn = factory.executor(grad)
        cost_fn = factory.executor(cost)

        cost_val = cost_fn()
        grad_val = grad_fn()
        assert cost_val == 6.0
        assert np.array_equal(grad_val, np.full((3, 1), 2.))
Пример #6
0
def test_slice_nop():
    """
    Full slices (``slice(None)`` and an explicit step of 1) must leave each
    axis of the tensor unchanged.
    """
    x = ng.variable(ng.make_axes([ng.make_axis(1), ng.make_axis(1)]))
    full_slices = [slice(None, None, None), slice(None, None, 1)]
    sliced = ng.tensor_slice(x, full_slices)

    for position in range(2):
        assert sliced.axes[position] == x.axes[position]
Пример #7
0
def Squeeze(onnx_node, ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Remove single-dimensional entries from the shape of a tensor.

    The axes to remove come from the `axes` attribute; negative values count
    from the end. When `axes` is not set (or empty), every length-1 axis is
    removed.

    :param onnx_node: wrapped ONNX Squeeze node
    :param ng_inputs: ngraph input tensors; only the first one is used
    :return: the input tensor with the selected axes indexed away
    :raises ValueError: if an axis index is out of range, or a selected axis
        does not have length 1
    """
    data = ng_inputs[0]
    rank = len(data.axes)
    axes_to_squeeze = onnx_node.get_attribute_value('axes')

    if not axes_to_squeeze:
        squeezed = {index for index, axis in enumerate(data.axes)
                    if axis.length == 1}
    else:
        for index in axes_to_squeeze:
            if not -rank <= index < rank:
                raise ValueError(
                    'Squeeze node (%s): `axes` attribute value %d is out of range.'
                    % (onnx_node.name, index))
        squeezed = {index % rank for index in axes_to_squeeze}
        # Indexing 0 on a longer axis would silently drop data.
        for index in sorted(squeezed):
            if data.axes[index].length != 1:
                raise ValueError(
                    'Squeeze node (%s): axis %d has length %d, only axes of '
                    'length 1 can be squeezed.'
                    % (onnx_node.name, index, data.axes[index].length))

    # an integer index removes the axis; slice(None) keeps it whole
    slices = [0 if index in squeezed else slice(None) for index in range(rank)]

    return ng.tensor_slice(data, slices)
Пример #8
0
def make_pooling_op(onnx_node, ng_inputs, custom_pool_params=None):
    # type: (NodeWrapper, List[TensorOp], Dict) -> Op
    """
    Create an ngraph pooling Op based on an ONNX node.

    4D inputs (N, C, H, W) are pooled as 2D and 5D inputs (N, C, H, W, D) as
    3D. Both are reordered to ngraph's (C, D, H, W, N) layout, pooled, and
    reordered back. For 2D input, the D axis is indexed away from the result.

    :param onnx_node: wrapped ONNX node for a pooling op
    :param ng_inputs: ngraph TensorOp input tensors
    :param custom_pool_params: optional pool_params overriding values based on onnx_node
    :return: ngraph pooling op
    :raises NotImplementedError: if the input is neither 4D nor 5D
    """
    x = ng_inputs[0]
    input_rank = len(x.axes)

    if input_rank == 4:  # 2D pooling
        # Reshape x axes from ONNX (N, C, H, W) to ngraph (C, D, H, W, N)
        x = reorder_axes(x, 'NCHW', 'CDHWN')
    elif input_rank == 5:  # 3D pooling
        # Reshape x axes from ONNX (N, C, H, W, D) to ngraph (C, D, H, W, N)
        x = reorder_axes(x, 'NCHWD', 'CDHWN')
    else:
        raise NotImplementedError(
            '%s node (%s): only 2D and 3D pooling ops are supported.'
            % (onnx_node.op_type, onnx_node.name))

    pool_params = get_pool_params(onnx_node)
    if custom_pool_params:
        pool_params.update(custom_pool_params)

    output_axes = make_pool_output_axes(x, pool_params)

    ng_op = ng.pooling(pool_params, x, output_axes)

    # ONNX output should have axes in the order N, C, H, W, D
    ng_op = reorder_axes(ng_op, 'CDHWN', 'NCHWD')

    if input_rank == 4:  # 2D pooling: index away the trailing D axis
        ng_op = ng.tensor_slice(ng_op, [slice(None)] * 4 + [0])

    return ng_op
Пример #9
0
def test_specific_slice_deriv():
    """
    The derivative of each single-element slice of a 3x4 tensor is one-hot.

    For every (i, j), d x[i, j] / d x must be 0 everywhere except for a 1
    at (i, j).
    """
    with ExecutorFactory() as ex:
        A = ng.make_axis(name='A', length=3)
        B = ng.make_axis(name='B', length=4)
        np_shape = (A.length, B.length)
        # x_np[i, j] = 10 * i + j gives each entry a distinct value
        x_np = np.empty(np_shape, dtype=np.float32)
        for i in range(A.length):
            for j in range(B.length):
                x_np[i, j] = 10 * i + j
        x_ng = ng.persistent_tensor([A, B], initial_value=x_np)
        for i in range(A.length):
            for j in range(B.length):
                # named `element` rather than `slice` to avoid shadowing the builtin
                element = ng.tensor_slice(x_ng, (i, j))
                delement_dx = ng.deriv(element, x_ng)
                delement_dx_fun = ex.executor(delement_dx)
                delement_dx_val = delement_dx_fun()
                expected = np.zeros_like(x_np)
                expected[i, j] = 1
                ng.testing.assert_allclose(delement_dx_val, expected)
Пример #10
0
def squeeze_axes(inputs):
    """
    Index away every length-1 axis of each input tensor.

    Arguments:
        inputs: Iterable of ngraph tensors to be sliced.

    Returns:
        List with one sliced tensor per input, in input order. Each length-1
        axis is indexed at 0 and every other axis is kept whole.
    """
    def _squeeze_spec(tensor):
        # 0 for length-1 axes, a full slice for everything else
        return [0 if axis.length == 1 else slice(None) for axis in tensor.axes]

    return [ng.tensor_slice(tensor, _squeeze_spec(tensor)) for tensor in inputs]
Пример #11
0
def Split(onnx_node,
          ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Tuple[Op]
    """Split a tensor into a list of tensors.

    The split axis comes from the `axis` attribute (default 0; negative
    values count from the end). Part lengths come from the `split`
    attribute. When it is absent, the axis is divided into as many equal
    parts as the node has outputs.

    :param onnx_node: wrapped ONNX Split node
    :param ng_inputs: ngraph input tensors; only the first one is split
    :return: tuple of sliced tensors, one per part, in axis order
    :raises ValueError: if `axis` is out of range, if the axis cannot be
        split into equal parts, or if `split` does not match the number of
        outputs or the axis length
    """
    data = ng_inputs[0]
    rank = len(data.axes)
    count_outputs = len(onnx_node.get_output_names())

    axis_to_split = onnx_node.get_attribute_value('axis', 0)
    if not -rank <= axis_to_split < rank:
        raise ValueError(
            'Split node (%s): `axis` attribute value %d is out of range.'
            % (onnx_node.name, axis_to_split))
    if axis_to_split < 0:
        axis_to_split += rank
    len_axis_to_split = data.axes[axis_to_split].length

    len_parts = onnx_node.get_attribute_value('split')
    if not len_parts:
        if len_axis_to_split % count_outputs:
            raise ValueError(
                'Split node (%s): Tensor cannot be split into %d equal parts, along '
                'axis of length %d'
                % (onnx_node.name, count_outputs, len_axis_to_split))
        len_parts = [len_axis_to_split // count_outputs] * count_outputs
    else:
        if len(len_parts) != count_outputs:
            raise ValueError(
                'Split node (%s): `split` has %d entries but the node has %d outputs.'
                % (onnx_node.name, len(len_parts), count_outputs))
        if sum(len_parts) != len_axis_to_split:
            raise ValueError(
                'Split node (%s): `split` lengths sum to %d, expected axis length %d.'
                % (onnx_node.name, sum(len_parts), len_axis_to_split))

    outputs = []
    start_index = 0
    for len_part in len_parts:
        end_index = start_index + len_part
        # the split axis gets a new, shorter axis with the same name
        output_axes = [
            ng.make_axis(length=len_part, name=axis.name)
            if i == axis_to_split else axis
            for i, axis in enumerate(data.axes)
        ]
        slices = [
            slice(start_index, end_index) if i == axis_to_split else slice(None)
            for i in range(rank)
        ]
        outputs.append(
            ng.tensor_slice(data, slices, axes=ng.make_axes(output_axes)))
        start_index = end_index

    return tuple(outputs)
Пример #12
0
def test_slice(transformer_factory):
    """Check tensor_slice forward values and numeric-vs-symbolic derivatives.

    Each case slices a small 2-D placeholder on axes (C, D) with integer
    indices and/or ``slice`` objects. The forward result is compared with the
    expected values, and the symbolic derivative with a numeric one.
    """
    C = ng.make_axis()
    D = ng.make_axis()

    cases = [
        # (tensor values, slice spec, sliced axes, (C length, D length), expected)
        ([[1, 3], [2, 5]], [0, 1], (), (2, 2), 3),
        ([[1, 3], [2, 5]], [slice(None), 0], (C, ), (2, 2), [1, 2]),
        ([[1, 3], [2, 5]], [1, slice(None)], (D, ), (2, 2), [2, 5]),
        ([[1, 4, 5], [2, 5, 6]], [1, slice(1, 3)], None, (2, 3), [5, 6]),
    ]

    for values, spec, sliced_axes, (c_len, d_len), expected_values in cases:
        with ExecutorFactory() as ex:
            C.length = c_len
            D.length = d_len

            tensor_np = np.array(values, dtype='float32')
            tensor = ng.placeholder((C, D))
            expected = np.array(expected_values, dtype='float32')

            sliced = ng.tensor_slice(tensor, spec, sliced_axes)
            sliced_val_fun = ex.executor(sliced, tensor)

            num_deriv_fun = ex.numeric_derivative(sliced, tensor, delta)
            # backpropagation through the slice
            sym_deriv_fun = ex.derivative(sliced, tensor)

            assert np.array_equal(sliced_val_fun(tensor_np), expected)

            numeric_deriv = num_deriv_fun(tensor_np)
            sym_deriv = sym_deriv_fun(tensor_np)
            assert ng.testing.allclose(numeric_deriv, sym_deriv,
                                       rtol=rtol, atol=atol)
Пример #13
0
def make_convolution_op(onnx_node, ng_inputs, transpose=False):
    # type: (NodeWrapper, List[TensorOp], bool) -> Op
    """
    Create an ngraph convolution or deconvolution Op based on an ONNX node.

    4D inputs are treated as 2D convolution and 5D inputs as 3D convolution.
    For 2D input, the D axis is indexed away from the result.

    :param onnx_node: wrapped ONNX node for Conv or ConvTranspose op
    :param ng_inputs: ngraph TensorOp input tensors (x, weights[, bias])
    :param transpose: should this be a transposed convolution?
    :return: ngraph Op for convolution or deconvolution
    :raises ValueError: if there are not 2 or 3 inputs
    :raises NotImplementedError: if the input is neither 4D nor 5D, or if
        `group` is not 1
    """
    if len(ng_inputs) == 3:
        x, weights, bias = ng_inputs
    elif len(ng_inputs) == 2:
        x, weights = ng_inputs
        bias = ng.constant(0)
    else:
        raise ValueError(
            'Conv node (%s): unexpected number of input values: %d.'
            % (onnx_node.name, len(ng_inputs)))

    input_rank = len(x.axes)

    # Reorder x axes from ONNX convention (N, C, H, W, D) to ngraph (C, D, H, W, N)
    # Reorder weights axes from ONNX (K, J, R, S, T) to ngraph (J, T, R, S, K)
    # Axis names follow https://ngraph.nervanasys.com/index.html/axes.html
    if input_rank == 4:  # 2D convolution
        x = reorder_axes(x, 'NCHW', 'CDHWN')
        weights = reorder_axes(weights, 'KJRS', 'JTRSK')
    elif input_rank == 5:  # 3D convolution
        x = reorder_axes(x, 'NCHWD', 'CDHWN')
        weights = reorder_axes(weights, 'KJRST', 'JTRSK')
    else:
        raise NotImplementedError(
            'Conv node (%s): only 2D and 3D convolutions are supported.'
            % onnx_node.name)

    groups = onnx_node.get_attribute_value('group', 1)
    if groups != 1:
        raise NotImplementedError(
            'Conv node (%s): `group` attribute value %d not supported.'
            % (onnx_node.name, groups))

    # Prepare ngraph convolution operation
    conv_params = get_conv_params(onnx_node)
    output_axes = make_conv_output_axes(x, weights, conv_params)

    if transpose:
        conv = ng.deconvolution(conv_params, x, weights, axes=output_axes)
    else:
        conv = ng.convolution(conv_params, x, weights, axes=output_axes)

    conv = cast_to_pos_axes(conv) + bias

    # ONNX output should have axes in the order N, C, H, W, D
    conv = reorder_axes(conv, 'CDHWN', 'NCHWD')

    if input_rank == 4:  # 2D convolution: index away the trailing D axis
        conv = ng.tensor_slice(conv, [slice(None)] * 4 + [0])

    return conv
Пример #14
0
    def MaxPool(self, tf_node, inputs):
        """
        Performs the max pooling on the input.

        Arguments:
            tf_node: NodeDef object, the tensorflow node to convert.
            inputs: List of ngraph Ops as inputs to this node.

        Returns:
            A ngraph Op corresponding to the tensorflow node.

        Raises:
            NotImplementedError: if the data format is not NHWC, if `ksize`
                or `strides` is not 1 on the batch/channel axes, or if the
                computed padding is asymmetric.
            ValueError: if `ksize` or `strides` does not have 4 entries.

        Inputs to tf_node:
            input

        TODO: assume default tensorflow layout NHWC, RSCK,
              need to support NCHW as well
              need to clean up / merge with conv2d

        Axes:
                      Tensorflow          Ngraph
            in       (N, H, W, C)     (C, D, H, W, N)
            out      (N, P, Q, K)     (K, M, P, Q, N)

        Notes on output shape:
            https://www.tensorflow.org/api_docs/python/nn.html#convolution
        """
        image = inputs[0]

        # TODO: currently NHWC only. Raise (as Conv2D does) rather than
        # `assert`, which is stripped when Python runs with -O.
        if tf_node.attr['data_format'].s.decode("ascii") != "NHWC":
            raise NotImplementedError("Only supports NHWC import for now.")

        # new axes
        C, D, H, W, K, M, P, Q = [ng.make_axis() for _ in range(8)]
        N = ng.make_axis(name='N')
        D.length, M.length = 1, 1  # only supports 2D pooling for now

        # tf's input axes; set_shape gives N, H, W, C their lengths
        # (C.length is read below)
        ax_i_tf = ng.make_axes([N, H, W, C])
        ax_i_tf.set_shape(image.axes.lengths)

        # ksize params
        tf_ksize = [int(s) for s in list(tf_node.attr['ksize'].list.i)]
        if len(tf_ksize) != 4:
            raise ValueError("Length of ksize must be 4.")
        if tf_ksize[0] != 1:
            raise NotImplementedError('Ksize on batch axis (N) must be 1.')
        if tf_ksize[3] != 1:
            raise NotImplementedError('Ksize on channel axis (C) must be 1. '
                                      'Cross map pooling to be implemented.')
        R_length, S_length = tf_ksize[1:3]
        T_length = J_length = 1

        # strides params
        tf_strides = [int(s) for s in list(tf_node.attr['strides'].list.i)]
        if len(tf_strides) != 4:
            raise ValueError("Length of strides must be 4.")
        if tf_strides[0] != 1:
            raise NotImplementedError('Strides on batch axis (N) must be 1.')
        if tf_strides[3] != 1:
            raise NotImplementedError('Strides on channel axis (C) must be 1.')
        str_h, str_w = tf_strides[1], tf_strides[2]

        # padding params; the pooling window is passed as an RSCK-shaped
        # tuple whose K entry is C.length
        padding = tf_node.attr['padding'].s.decode("ascii")
        pad_t, pad_b, pad_l, pad_r = common_conv2d_pool_padding(
            image.axes.lengths, (R_length, S_length, C.length, C.length),
            tf_strides, padding)
        if pad_t != pad_b or pad_l != pad_r:
            raise NotImplementedError("Requires symmetric padding in ngraph: "
                                      "pad_t(%s) == pad_b(%s) and "
                                      "pad_l(%s) == pad_r(%s)" %
                                      (pad_t, pad_b, pad_l, pad_r))
        # pooling params
        params = dict(op='max',
                      pad_d=0, pad_h=pad_t, pad_w=pad_l, pad_c=0,
                      str_d=1, str_h=str_h, str_w=str_w, str_c=1,
                      J=J_length, T=T_length, R=R_length, S=S_length)

        # tf's output axes; set_shape gives N, P, Q, K their lengths
        ax_o_tf = ng.make_axes([N, P, Q, K])
        ax_o_tf.set_shape(common_conv2d_pool_output_shape(image.axes.lengths,
                                                          (R_length, S_length,
                                                           C.length, C.length),
                                                          tf_strides, padding))

        # ngraph's i, f, o axes
        ax_i = ng.make_axes([C, D, H, W, N])
        ax_o = ng.make_axes([K, M, P, Q, N])

        # image NHWC -> CDHWN
        image = ng.cast_axes(image, ng.make_axes([N, H, W, C]))
        image = ng.expand_dims(image, D, 1)  # NHWC -> NDHWC
        image = ng.axes_with_order(image, ax_i)  # NDHWC -> CDHWN

        # pooling
        output = ng.pooling(params, image, axes=ax_o)

        # output KMPQN -> NPQK
        # KMPQN -> NMPQK
        output = ng.axes_with_order(output, ng.make_axes(
            [N, M, P, Q, K]))
        # NMPQK -> NPQK
        output = ng.tensor_slice(output, [slice(None), 0, slice(None),
                                          slice(None), slice(None)])

        return output
Пример #15
0
    def Conv2D(self, tf_node, inputs):
        """
        Computes a 2-D convolution given 4D input and filter tensors.

        Arguments:
            tf_node: NodeDef object, the tensorflow node to convert.
            inputs: List of ngraph Ops as inputs to this node.

        Returns:
            A ngraph Op corresponding to the tensorflow node.

        Raises:
            NotImplementedError: if the data format is not NHWC, if strides
                are not 1 on the batch/channel axes, or if the computed
                padding is asymmetric.
            ValueError: if image or filter is not 4D, if their C dimensions
                differ, or if `strides` does not have 4 entries.

        Inputs to tf_node:
            input, filter

        TODO: assume default tensorflow layout NHWC, RSCK,
              need to support NCHW as well
              need to clean up / merge with maxpool

        Axes:
                      Tensorflow          Ngraph
            in       (N, H, W, C)     (C, D, H, W, N)
            filter   (R, S, C, K)     (C, T, R, S, K)
            out      (N, P, Q, K)     (K, M, P, Q, N)

        Notes on output shape:
            https://www.tensorflow.org/api_docs/python/nn.html#convolution
        """
        image, weight = inputs

        # TODO: currently NHWC only
        if tf_node.attr['data_format'].s.decode("ascii") != "NHWC":
            raise NotImplementedError("Only supports NHWC import for now.")

        # both tensors must be 4D before indexing their lengths below
        image_rank = len(image.axes.lengths)
        weight_rank = len(weight.axes.lengths)
        if image_rank != 4 or weight_rank != 4:
            raise ValueError("Conv2D expects 4D image (NHWC) and filter (RSCK) "
                             "tensors, got %dD and %dD."
                             % (image_rank, weight_rank))

        # check in_C == f_C
        if image.axes.lengths[3] != weight.axes.lengths[2]:
            raise ValueError("Image's C dimension (%s) must be equal to "
                             "filter's C dimension (%s)."
                             % (image.axes.lengths[3], weight.axes.lengths[2]))

        # strides params
        tf_strides = [int(s) for s in list(tf_node.attr['strides'].list.i)]
        if len(tf_strides) != 4:
            raise ValueError("Length of strides must be 4.")
        if tf_strides[0] != 1:
            raise NotImplementedError('Strides on batch axis (N) must be 1.')
        if tf_strides[3] != 1:
            raise NotImplementedError('Strides on channel axis (C) must be 1.')
        str_h, str_w = tf_strides[1], tf_strides[2]

        # padding params
        padding = tf_node.attr['padding'].s.decode("ascii")
        pad_t, pad_b, pad_l, pad_r = common_conv2d_pool_padding(
            image.axes.lengths, weight.axes.lengths, tf_strides, padding)
        if pad_t != pad_b or pad_l != pad_r:
            raise NotImplementedError("Requires symmetric padding in ngraph: "
                                      "pad_t(%s) == pad_b(%s) and "
                                      "pad_l(%s) == pad_r(%s)" %
                                      (pad_t, pad_b, pad_l, pad_r))

        # conv params
        params = dict(pad_d=0, pad_h=pad_t, pad_w=pad_l,
                      str_d=1, str_h=str_h, str_w=str_w,
                      dil_d=1, dil_h=1, dil_w=1)

        # new axes
        C, D, H, W, T, R, S, K, M, P, Q = [ng.make_axis() for _ in range(11)]
        N = ng.make_axis(name='N')
        D.length, T.length, M.length = 1, 1, 1  # only supports 2D conv for now

        # tf's i, f, o axes; set_shape assigns the lengths of the shared axes
        ax_i_tf = ng.make_axes([N, H, W, C])
        ax_f_tf = ng.make_axes([R, S, C, K])
        ax_o_tf = ng.make_axes([N, P, Q, K])
        ax_i_tf.set_shape(image.axes.lengths)
        ax_f_tf.set_shape(weight.axes.lengths)
        ax_o_tf.set_shape(common_conv2d_pool_output_shape(image.axes.lengths,
                                                          weight.axes.lengths,
                                                          tf_strides, padding))

        # ngraph's i, f, o axes
        ax_i = ng.make_axes([C, D, H, W, N])
        ax_f = ng.make_axes([C, T, R, S, K])
        ax_o = ng.make_axes([K, M, P, Q, N])

        # image NHWC -> CDHWN
        image = ng.cast_axes(image, ng.make_axes([N, H, W, C]))
        image = ng.expand_dims(image, D, 1)  # NHWC -> NDHWC
        image = ng.axes_with_order(image, ax_i)  # NDHWC -> CDHWN

        # weights RSCK -> CTRSK
        weight = ng.cast_axes(weight, ng.make_axes([R, S, C, K]))
        weight = ng.expand_dims(weight, T, 0)  # RSCK -> TRSCK
        weight = ng.axes_with_order(weight, ax_f)  # TRSCK -> CTRSK

        # convolution
        output = ng.convolution(params, image, weight, axes=ax_o)

        # output KMPQN -> NPQK
        # KMPQN -> NMPQK
        output = ng.axes_with_order(output, ng.make_axes([N, M, P, Q, K]))
        # NMPQK -> NPQK
        output = ng.tensor_slice(output, [slice(None), 0, slice(None),
                                          slice(None), slice(None)])

        return output
    def __call__(self,
                 in_obj,
                 channel_axes="C",
                 spatial_axes=("D", "H", "W"),
                 **kwargs):
        """
        Apply this convolution-style layer to ``in_obj``.

        When ``self.initialized`` is false, the filter variable(s) are created
        (plain ``W``, or ``v``/``g`` with ``W`` derived from them when
        ``self.weight_norm`` is set). Otherwise the method only checks that
        the new input implies the same filter axes as the existing ``W``.
        ``self.initialized`` is read but not set in this method; presumably
        it is maintained elsewhere (TODO confirm).

        Arguments:
            in_obj (Op): Input op
            channel_axes (str): name of the expected channel axis type - defaults to "C"
            spatial_axes (tuple or dict): names of expected depth, height and width
                axis types - defaults to "D", "H", and "W". A dict maps any of
                "D"/"H"/"W" to the actual name (missing keys keep the default).
                In a tuple, falsy entries fall back to the default.
            **kwargs: accepted but not used by this method.

        Returns:
            Op: the layer output, ordered like ``in_obj``'s original axes and
            followed by any remaining axes that were not in the input.

        Raises:
            ValueError: if a ``spatial_axes`` tuple has fewer than 3 entries,
                or if the layer is already initialized and the new input
                implies different filter axes.
        """
        if isinstance(spatial_axes, dict):
            spatial_axes = tuple(
                spatial_axes.get(name, name) for name in ("D", "H", "W"))
        elif isinstance(spatial_axes, tuple):
            # NOTE(review): only tuples shorter than 3 are rejected; zip()
            # silently drops entries past the third, and lists skip this
            # normalization entirely - confirm this is intended.
            if len(spatial_axes) < 3:
                raise ValueError(
                    "spatial_axes must have length 3 (e.g. ('D', 'H', 'W'))")
            spatial_axes = tuple(
                name if name else default
                for name, default in zip(spatial_axes, ("D", "H", "W")))

        # remember the caller's axis order so the output can be restored to it
        orig_axes = in_obj.axes
        in_obj = reorder_spatial_axes(in_obj, channel_axes, spatial_axes)
        channel_axes = in_obj.axes.get_by_names(channel_axes)
        spatial_axes = in_obj.axes.get_by_names(*spatial_axes)

        filter_axes = self._filter_axes(channel_axes, spatial_axes)

        # mark 'K' as a shadow axis for the initializers.
        # The first key of axes_map replaces the K axis; this assumes
        # axes_map has a single entry (TODO confirm).
        axes_map = shadow_axes_map(filter_axes.find_by_name('K'))
        filter_axes = ng.make_axes([
            axis if axis.name != 'K' else list(axes_map.keys())[0]
            for axis in filter_axes
        ])

        if not self.initialized:
            if not self.weight_norm:
                self.W = ng.variable(axes=filter_axes,
                                     initial_value=self.init,
                                     metadata={
                                         "label": LABELS["weight"]
                                     }).named("W")
            else:
                # weight normalization:
                # W = g * v / sqrt(mean(v ** 2) + 1e-3), with the mean
                # reduced to out_axes (the shadow K axis)
                self.v = ng.variable(axes=filter_axes,
                                     initial_value=self.init,
                                     metadata={
                                         "label": LABELS["weight"]
                                     }).named("v")
                out_axes = ng.make_axes(
                    [filter_axes.get_by_names("K__NG_SHADOW")])
                v_norm = ng.mean(ng.square(self.v), out_axes=out_axes)
                self.g = ng.variable(axes=out_axes,
                                     initial_value=self.init,
                                     metadata={
                                         "label": LABELS["weight"]
                                     }).named("g")
                self.W = self.g * self.v * ng.reciprocal(
                    ng.sqrt(v_norm + 1e-3))
        else:
            # already initialized: the new input must imply the same filter axes
            if filter_axes != self.W.axes:
                raise ValueError(
                    ("{layer_name} layer has already been initialized with an "
                     "input object which has resulted in filter axes: "
                     "{existing_filter_axes}. This new input object has axes: "
                     "{input_axes}, which implies the need for filter axes: "
                     "{new_filter_axes} which are different than the existing "
                     "filter axes.").format(
                         layer_name=self.name,
                         existing_filter_axes=self.W.axes,
                         input_axes=in_obj.axes,
                         new_filter_axes=filter_axes,
                     ))

        # map_roles with axes_map presumably maps the shadow K axis back
        # (see shadow_axes_map) - TODO confirm
        output = ng.map_roles(
            self._conv_op(in_obj, channel_axes, spatial_axes), axes_map)
        # Reorder the output to match the input order
        output_axis_order = ng.make_axes(
            [output.axes.find_by_name(ax.name)[0] for ax in orig_axes])
        # Remove introduced axes. If their length is > 1, then perhaps they should be kept
        slices = [
            0 if (ax not in orig_axes) and ax.length == 1 else slice(None)
            for ax in output.axes
        ]
        output = ng.tensor_slice(output, slices)
        # New axes with length > 1 may have been introduced. Add them to the end.
        output_axis_order = output_axis_order | output.axes
        return ng.axes_with_order(output, output_axis_order)
# Output axis of length n_classes, named "Fo"
out_axis = ng.make_axis(length=n_classes, name="Fo")

# Input axes ordered (batch_axis, feature_axis, time_axis); target axes (batch_axis, out_axis)
in_axes = ng.make_axes([batch_axis, feature_axis, time_axis])
out_axes = ng.make_axes([batch_axis, out_axis])

# Build placeholders for the created axes; `iteration` is a scalar (no axes)
inputs = dict(X=ng.placeholder(in_axes),
              y=ng.placeholder(out_axes),
              iteration=ng.placeholder(axes=()))

# define model
if args.modeltype == "TCN":
    # take only the last timepoint of output sequence to predict sum
    last_timepoint = [
        lambda op: ng.tensor_slice(op, [
            slice(seq_len - 1, seq_len, 1) if ax.name == "W" else slice(None)
            for ax in op.axes
        ])
    ]
    affine_layer = Affine(axes=out_axis,
                          weight_init=GaussianInit(0, 0.01),
                          activation=Identity())

    model = Sequential(
        [lambda op: ng.map_roles(op, {
            'REC': 'W',
            'F': 'C'
        })] +
        tcn(n_features, hidden_sizes, kernel_size=kernel_size,
            dropout=dropout).layers + last_timepoint + [affine_layer])
elif args.modeltype == "LSTM":
    model = recurrent_model.define_model(out_axis,
Пример #18
0
    def Conv(self, c2_op, inputs):
        """
        Computes a 2-D convolution given 4D input and filter tensors.

        Arguments:
            c2_op: NodeDef object, the caffe2 node to convert.
            inputs: List of ngraph Ops as inputs to this node.

        Returns:
            A ngraph Op corresponding to the caffe2 node.

        Inputs to c2_op:
            input, weights, bias

        Supports caffe2's layout NHWC and NCHW as well.

        Raises:
            ValueError: if the node does not carry exactly one "order" or
                exactly one "stride" argument, if the bias is not 1-D, or if
                the bias length differs from the number of output feature maps.
            NotImplementedError: if the order is neither NHWC nor NCHW, or if
                the computed padding is asymmetric.
        """
        X, W, bias = inputs

        order = [val.s for val in c2_op.arg if val.name == "order"]
        if 1 != len(order):
            raise ValueError(
                "Expected exactly one order value in convolution, got {}"
                .format(len(order)))
        order = order[0]
        # String arguments may be stored as bytes (Python 3 protobuf); normalize
        # so the comparisons against str literals and the dict lookups below work.
        if isinstance(order, bytes):
            order = order.decode('utf-8')

        if order not in ("NHWC", "NCHW"):
            raise NotImplementedError(
                "Unsupported order in convolution: {}".format(order))

        # set input axes shape
        ax_N = ng.make_axis(name='N')
        ax_C = ng.make_axis()
        ax_D = ng.make_axis(length=1)
        ax_H = ng.make_axis()
        ax_W = ng.make_axis()

        # set kernel axes shape
        ax_kernel_D = ng.make_axis(length=1)
        ax_kernel_H = ng.make_axis()
        ax_kernel_W = ng.make_axis()
        ax_kernel_ofm = ng.make_axis()

        # create placeholders for output axes
        oC = ng.make_axis(name='C')
        oD = ng.make_axis(name='D', length=1)
        oH = ng.make_axis(name='H')
        oW = ng.make_axis(name='W')

        # Axis layout of the incoming tensors for each supported order; the
        # input channel axis ax_C is shared between X and W.
        axes_order = {
            'NCHW': {
                'X': [ax_N, ax_C, ax_H, ax_W],
                'W': [ax_kernel_ofm, ax_C, ax_kernel_H, ax_kernel_W]
            },
            'NHWC': {
                'X': [ax_N, ax_H, ax_W, ax_C],
                'W': [ax_kernel_ofm, ax_kernel_H, ax_kernel_W, ax_C]
            },
        }

        # Give the freshly created axes the lengths of the actual inputs.
        ng.make_axes(axes_order[order]['X']).set_shape(X.axes.lengths)
        ng.make_axes(axes_order[order]['W']).set_shape(W.axes.lengths)

        if 1 != len(bias.axes):
            raise ValueError("Bias's must be 1D.")
        if ax_kernel_ofm.length != bias.axes.lengths[0]:
            raise ValueError(
                "Bias's length must equal to number of output feature maps.")

        # strides params
        stride_size = [int(val.i) for val in c2_op.arg if val.name == "stride"]
        if len(stride_size) != 1:
            raise ValueError("Stride size must be scalar value")
        str_h = str_w = stride_size[0]

        # padding params
        pad_t, pad_b, pad_l, pad_r = \
            _c2_padding(c2_op,
                        in_NHWC=[ax_N.length, ax_H.length, ax_W.length, ax_C.length],
                        kernel_HWIO=[ax_kernel_H.length, ax_kernel_W.length,
                                     ax_C.length, ax_kernel_ofm.length],
                        stride_NHWC=[1, str_h, str_w, 1])

        if pad_t != pad_b or pad_l != pad_r:
            raise NotImplementedError("Requires symmetric padding in ngraph: "
                                      "pad_t(%s) == pad_b(%s) and "
                                      "pad_l(%s) == pad_r(%s)" %
                                      (pad_t, pad_b, pad_l, pad_r))

        # conv params
        params = dict(pad_d=0,
                      pad_h=pad_t,
                      pad_w=pad_l,
                      str_d=1,
                      str_h=str_h,
                      str_w=str_w,
                      dil_d=1,
                      dil_h=1,
                      dil_w=1)

        # input, weight, output axes
        internal_ax_dict = {
            'X':
            ng.make_axes([ax_C, ax_D, ax_H, ax_W, ax_N]),
            'W':
            ng.make_axes(
                [ax_C, ax_kernel_D, ax_kernel_H, ax_kernel_W, ax_kernel_ofm])
        }

        oC.length = ax_kernel_ofm.length
        oH.length = output_dim(ax_H.length, ax_kernel_H.length,
                               params['pad_h'], params['str_h'])
        oW.length = output_dim(ax_W.length, ax_kernel_W.length,
                               params['pad_w'], params['str_w'])
        internal_ax_dict['Y'] = ng.make_axes([oC, oD, oH, oW, ax_N])

        # broadcast input / filter axes
        # flow for NHWC order:                   |  flow for NCHW order:
        # input:                                 |  input:
        #   expand dims: NHWC -> NDHWC           |    expand dims: NCHW -> NDCHW
        #   reorder:     NDHWC -> CDHWN          |    reorder:     NDCHW -> CDHWN
        # weights:                               |  weights:
        #   expand dims: (ofm)HWC -> D(ofm)HWC   |    expand dims: (ofm)CHW -> D(ofm)CHW
        #   reorder:     D(ofm)HWC -> CDHW(ofm)  |    reorder:     D(ofm)CHW -> CDHW(ofm)

        X = ng.cast_axes(X, ng.make_axes(axes_order[order]['X']))
        X = ng.expand_dims(X, ax_D, 1)
        X = ng.axes_with_order(X, axes=internal_ax_dict['X'])
        W = ng.cast_axes(W, ng.make_axes(axes_order[order]['W']))
        W = ng.expand_dims(W, ax_kernel_D, 0)
        W = ng.axes_with_order(W, axes=internal_ax_dict['W'])

        # convolution
        Y = ng.convolution(params, X, W, axes=internal_ax_dict['Y'])

        # cast back to proper format
        Y = ng.broadcast(Y, ng.make_axes([ax_N, oD, oH, oW, oC])) if "NHWC" == order \
            else ng.broadcast(Y, ng.make_axes([ax_N, oD, oC, oH, oW]))  # NCHW

        # slice away the oD
        out_slicing = [slice(None), 0, slice(None), slice(None), slice(None)]
        Y = ng.tensor_slice(Y, out_slicing)

        def _conv_bias_add(c2_op, inputs):
            # Add the 1-D bias along the output channel axis, which sits at
            # position 1 (NCHW) or 3 (NHWC) once the depth axis is sliced away.
            X, bias = inputs
            bias = ng.cast_axes(bias,
                                axes=ng.make_axes(
                                    [X.axes[1 if 'NCHW' == order else 3]]))
            Y = ng.Add(X, bias)
            return Y

        return _conv_bias_add(c2_op, [Y, bias])
Пример #19
0
    def Pool(self, c2_op, inputs):
        """
        Performs max or average pooling on the input.

        Arguments:
            c2_op: NodeDef object, the caffe2 node to convert.
            inputs: List of ngraph Ops as inputs to this node.

        Returns:
            A ngraph Op corresponding to the c2_op node.

        Inputs to c2_op:
            input

        Only the NCHW layout is implemented.

        Raises:
            NotImplementedError: if the op type is not MaxPool/AveragePool, if
                an explicit "order" argument other than NCHW is given, or if
                the computed padding is asymmetric.
            ValueError: if "kernel" or "stride" is not a single value.
        """
        supported_pooling = {'MaxPool': 'max', 'AveragePool': 'avg'}
        # Fail fast with a clear error instead of a KeyError after all the setup.
        if c2_op.type not in supported_pooling:
            raise NotImplementedError(
                "Unsupported pooling type: {}".format(c2_op.type))

        # Only NCHW is implemented below; reject any other explicit order
        # rather than silently interpreting the data as NCHW.
        for val in c2_op.arg:
            if val.name == "order":
                order = val.s
                # string arguments may arrive as bytes (Python 3 protobuf)
                if isinstance(order, bytes):
                    order = order.decode('utf-8')
                if order != "NCHW":
                    raise NotImplementedError(
                        "Unsupported order in pooling: {}".format(order))

        image = inputs[0]

        # set input axes shape (NCHW)
        ax_N = ng.make_axis(name='N')
        ax_C = ng.make_axis()
        ax_D = ng.make_axis(length=1)
        ax_H = ng.make_axis()
        ax_W = ng.make_axis()
        ng.make_axes([ax_N, ax_C, ax_H, ax_W]).set_shape(image.axes.lengths)

        # create placeholders for output axes
        oC = ng.make_axis(name='C')
        oD = ng.make_axis(length=1, name='D')
        oH = ng.make_axis(name='H')
        oW = ng.make_axis(name='W')

        # spatial kernel size
        kernel_size = [int(val.i) for val in c2_op.arg if val.name == "kernel"]
        if len(kernel_size) != 1:
            raise ValueError("Kernel size must be scalar value")
        # kernel is square
        kernel_h = kernel_w = kernel_size[0]
        kernel_d = kernel_c = 1

        # strides params
        stride_size = [int(val.i) for val in c2_op.arg if val.name == "stride"]
        if len(stride_size) != 1:
            raise ValueError("Stride size must be scalar value")
        stride_h = stride_w = stride_size[0]

        # padding params
        pad_t, pad_b, pad_l, pad_r = \
            _c2_padding(c2_op,
                        in_NHWC=[ax_N.length, ax_H.length, ax_W.length, ax_C.length],
                        kernel_HWIO=[kernel_h, kernel_w, ax_C.length, ax_C.length],
                        stride_NHWC=[1, stride_h, stride_w, 1])
        if pad_t != pad_b or pad_l != pad_r:
            raise NotImplementedError("Requires symmetric padding in ngraph: "
                                      "pad_t(%s) == pad_b(%s) and "
                                      "pad_l(%s) == pad_r(%s)" %
                                      (pad_t, pad_b, pad_l, pad_r))

        # pooling params
        params = dict(op=supported_pooling[c2_op.type],
                      pad_d=0,
                      pad_h=pad_t,
                      pad_w=pad_l,
                      pad_c=0,
                      str_d=1,
                      str_h=stride_h,
                      str_w=stride_w,
                      str_c=1,
                      J=kernel_c,
                      T=kernel_d,
                      R=kernel_h,
                      S=kernel_w)

        # i, o axes
        oC.length = output_dim(ax_C.length, kernel_c, params['pad_c'],
                               params['str_c'])
        oD.length = output_dim(ax_D.length, kernel_d, params['pad_d'],
                               params['str_d'])
        oH.length = output_dim(ax_H.length, kernel_h, params['pad_h'],
                               params['str_h'])
        oW.length = output_dim(ax_W.length, kernel_w, params['pad_w'],
                               params['str_w'])
        ax_i = ng.make_axes([ax_C, ax_D, ax_H, ax_W, ax_N])
        ax_o = ng.make_axes([oC, oD, oH, oW, ax_N])

        # broadcast input / filter axes
        image = ng.cast_axes(image, ng.make_axes([ax_N, ax_C, ax_H, ax_W]))
        image = ng.expand_dims(image, ax_D, 1)  # NCHW -> NDCHW
        image = ng.axes_with_order(image, axes=ax_i)  # NDCHW -> CDHWN

        # pooling
        output = ng.pooling(params, image, axes=ax_o)

        # cast back to NDCHW
        output = ng.broadcast(output, ng.make_axes([ax_N, oD, oC, oH, oW]))

        # slice away the oD
        out_slicing = [slice(None), 0, slice(None), slice(None), slice(None)]
        output = ng.tensor_slice(output, out_slicing)

        return output
Пример #20
0
# One training step: run the optimizer on the main loss plus the auxiliary
# loss weighted by 0.4, then compute the mean of train_loss_main reduced to a
# scalar (out_axes=()).
# NOTE(review): presumably ng.sequential yields the value of its last op, so
# batch_cost is that mean — confirm.
batch_cost = ng.sequential([
    optimizer(train_loss_main + 0.4 * train_loss_aux),
    ng.mean(train_loss_main, out_axes=())
])

# NOTE(review): 'all' presumably makes every placeholder an input of the
# computation — confirm against ng.computation.
train_computation = ng.computation([batch_cost], 'all')

# Build the computations for inference (evaluation)
with Layer.inference_mode_on():
    inference_prob = inception.seq2(inception.seq1(inputs['image']))
    # Index 0 on the axes named "H" and "W", dropping them; keep other axes.
    slices = [
        0 if cx.name in ("H", "W") else slice(None)
        for cx in inference_prob.axes
    ]
    inference_prob = ng.tensor_slice(inference_prob, slices)
    # Presumably renames the "C" axis to "Y" so it matches the label one-hot
    # axes used below — TODO confirm.
    inference_prob = ng.map_roles(inference_prob, {"C": "Y"})
    # Per-sample misclassification: argmax over all axes except ax.N,
    # compared against the integer labels.
    errors = ng.not_equal(ng.argmax(inference_prob, out_axes=[ax.N]),
                          inputs['label'])
    eval_loss = ng.cross_entropy_multi(inference_prob,
                                       y_onehot,
                                       enable_softmax_opt=False)
    # Names for the three outputs of eval_computation, in the same order.
    eval_loss_names = ['cross_ent_loss', 'misclass', 'predictions']
    eval_computation = ng.computation([eval_loss, errors, inference_prob],
                                      "all")

with closing(ngt.make_transformer()) as transformer:
    train_function = transformer.add_computation(train_computation)
    eval_function = transformer.add_computation(eval_computation)

    if args.no_progress_bar:
Пример #21
0
def test_conv_flatten_deriv(transformer_factory):
    """
    Test deriv of conv followed by flatten.

    Convolves an all-ones input with an all-ones filter (no padding, unit
    stride and dilation), slices away the depth axis, flattens the result and
    sums it into a scalar cost. Checks that:
      * every forward output element equals C * T * R * S;
      * the symbolic and numeric derivatives of the cost agree with respect
        to both the filter and the input.

    ``transformer_factory`` is not referenced in the body; presumably it is a
    pytest fixture that selects the backend — TODO confirm.
    """

    # set shape
    # NOTE: N must be >= 4 for GPU, but for CPU this could be decreased to
    # speed up the test
    N = 4
    C, D, H, W = (3, 1, 28, 28)
    T, R, S, K = (1, 5, 5, 8)

    params = dict(pad_d=0,
                  pad_h=0,
                  pad_w=0,
                  str_d=1,
                  str_h=1,
                  str_w=1,
                  dil_d=1,
                  dil_h=1,
                  dil_w=1)

    # i, f, o axes: input (C, D, H, W, N) and filter (C, T, R, S, K) reuse the
    # shared ``ax`` axes; output gets fresh axes named C/D/H/W with feature
    # roles, plus the shared batch axis N.
    ax_i = ng.make_axes([ax.C, ax.D, ax.H, ax.W, ax.N])
    ax_f = ng.make_axes([ax.C, ax.T, ax.R, ax.S, ax.K])
    ax_o = ng.make_axes([
        ng.make_axis(roles=[ar.features_input]).named('C'),
        ng.make_axis(roles=[ar.features_0]).named('D'),
        ng.make_axis(roles=[ar.features_1]).named('H'),
        ng.make_axis(roles=[ar.features_2]).named('W'), ax.N
    ])

    # set_shape assigns lengths to the axes listed above; output spatial sizes
    # are the no-padding, stride-1 sizes (X - kernel + 1).
    ax_i.set_shape((C, D, H, W, N))
    ax_f.set_shape((C, T, R, S, K))
    ax_o.set_shape((K, D - T + 1, H - R + 1, W - S + 1, N))
    # The filter variable is created on "primed" copies of R, S, C, K (names
    # suffixed with 'p', same lengths) and cast onto the real axes below.
    axes_rsck = ng.make_axes([ax.R, ax.S, ax.C, ax.K])
    axes_rsck_prime = ng.make_axes([
        ng.make_axis(axis.length).named(axis.name + 'p') for axis in axes_rsck
    ])
    # Output reordered as (N, D, H, W, C) taken from ax_o's positions.
    axes_nmpqk = ng.make_axes([ax_o[-1], ax_o[1], ax_o[2], ax_o[3], ax_o[0]])

    # broadcast input / filter axes
    input_var = ng.variable(ax_i).named('input')
    input_var.input = True
    input_val = np.ones(input_var.axes.lengths)

    # filter: RSCK' -> cast to RSCK -> insert T at position 0 (TRSCK)
    # -> reorder to CTRSK (ax_f).
    filter_rsck_prime = ng.variable(axes_rsck_prime)
    filter_var = filter_rsck_prime
    filter_rsck = ng.cast_axes(filter_rsck_prime, axes_rsck)
    filter_trsck = ng.expand_dims(filter_rsck, ax.T, 0)
    filter_ctrsk = ng.axes_with_order(filter_trsck, axes=ax_f)

    # convolution
    output_kmpqn = ng.convolution(params, input_var, filter_ctrsk, axes=ax_o)
    output_nmpqk = ng.axes_with_order(output_kmpqn, axes=axes_nmpqk)

    # slice away the oD (index 0 on the second axis, the depth axis)
    out_slicing = [slice(None), 0, slice(None), slice(None), slice(None)]
    output_npqk = ng.tensor_slice(output_nmpqk, out_slicing)

    # NOTE(review): presumably flattens all axes from index 1 onward into one,
    # giving (N, P*Q*K) — confirm against ng.flatten_at.
    output = ng.flatten_at(output_npqk, idx=1)

    # cost and grad: scalar cost = sum over all output elements
    cost = ng.sum(output, out_axes=())

    filter_var.input = True
    filter_var.named('filter')
    filter_val = np.ones(filter_var.axes.lengths)

    with ExecutorFactory() as factory:

        # Argument order of each executor must match the order in which the
        # values are passed to it below.
        conv_comp = factory.executor(output, filter_var, input_var)
        # NOTE(review): the 1.0 is presumably the finite-difference step size.
        grad_filter_num_comp = factory.numeric_derivative(
            cost, filter_var, 1.0, input_var)
        grad_filter_sym_comp = factory.derivative(cost, filter_var, input_var)

        grad_input_num_comp = factory.numeric_derivative(
            cost, input_var, 1.0, filter_var)
        grad_input_sym_comp = factory.derivative(cost, input_var, filter_var)

        # All-ones input and filter: each output element is the size of the
        # filter's receptive field, C * T * R * S.
        conv_val = conv_comp(filter_val, input_val)
        conv_val_num = np.empty_like(conv_val)
        conv_val_num.fill(C * T * R * S)
        assert ng.testing.allclose(conv_val, conv_val_num)

        grad_filter_num_val = grad_filter_num_comp(filter_val, input_val)
        grad_filter_sym_val = grad_filter_sym_comp(filter_val, input_val)
        assert ng.testing.allclose(grad_filter_num_val, grad_filter_sym_val)

        grad_input_num_val = grad_input_num_comp(input_val, filter_val)
        grad_input_sym_val = grad_input_sym_comp(input_val, filter_val)
        assert ng.testing.allclose(grad_input_num_val, grad_input_sym_val)
def broadcast_to(x, out_shape):
    """
    Broadcast tensor x to out_shape.

    Args:
        x: tensor to be broadcasted
        out_shape: tuple (or list) of the targeted shape

    Example:

         [3][2][1][0]    [4][3][2][1][0]
    from (5, 1, 2, 1) to (4, 5, 1, 2, 3)

    # step 1:
                                        [3][2][1]
    collapse 1 that will be broadcasted (5, 1, 2)

    # step 2:
             [4][0]          [4][0][3][2][1]
    add with (4, 3), becomes (4, 3, 5, 1, 2)

    # step 3:
               [4][3][2][1][0]
    reorder to (4, 5, 1, 2, 3)

    Returns:
        x broadcasted to out_shape

    Raises:
        ValueError: if x's shape is not broadcastable to out_shape.
    """
    if not is_compatible_broadcast_shape(x.axes.lengths, out_shape):
        raise ValueError(
            "x's shape {} is not broadcastable to out_shape {}".format(
                x.axes.lengths, out_shape))
    x_ndims = len(x.axes)

    if x_ndims == 0:
        # special case: scalar -- adding a zero tensor with the target shape
        # broadcasts the scalar over every output position.
        zero = ng.constant(0., axes=make_pos_axes(out_shape))
        return x + zero

    elif tuple(x.axes.lengths) == tuple(out_shape):
        # special case: x's shape is same as out_shape, nothing to broadcast.
        # (tuple() on both sides so a list out_shape compares correctly.)
        return x

    else:
        # collapse (collapse all dimension 1 axes that will be broadcasted)
        x_slice = []
        sliced_indices = []
        for index, (x_len, out_len) in enumerate(
                zip(x.axes.lengths, out_shape[-len(x.axes):])):
            if x_len == 1 and out_len != 1:
                x_slice.append(0)
                sliced_indices.append(index)
            else:
                x_slice.append(slice(None))
        x = ng.tensor_slice(x, x_slice)

        # get the axes for the dummy zero: the positions that were collapsed
        # plus the leading positions x does not have at all
        zero_positions = [x_ndims - i - 1 for i in sliced_indices]
        zero_positions += list(range(x_ndims, len(out_shape)))
        zero_shape = [out_shape[-i - 1] for i in zero_positions]
        zero = ng.constant(0.,
                           axes=make_pos_axes(zero_shape,
                                              positions=zero_positions))

        # broadcast and reorder
        x = reorder_pos_axes(x + zero)

        return x
Пример #23
0
    def _element_wise_binary(self, ng_op, inputs, name=None):
        """
        Element-wise binary operation with numpy-style broadcast.

        Args:
            ng_op: ngraph Op to be applied.
            inputs: List of ngraph Ops as inputs to this node.
            name: name of the ngraph op; applied to the returned op on every
                code path, including the broadcasting one.

        Returns:
            A ngraph Op corresponding to the element-wise binary op

        Raises:
            AssertionError: if the input shapes are not numpy-compatible
                (checked with ``assert``, so skipped under ``python -O``).
        """
        # get inputs
        left, right = inputs

        # check shape compatibility
        left_shape = left.axes.lengths
        right_shape = right.axes.lengths
        assert is_compatible_numpy_shape(left_shape, right_shape), \
            "Incompatible shapes for element-wise op: {} and {}".format(
                left_shape, right_shape)

        if left_shape and right_shape and left_shape != right_shape:
            # Cast axes in numpy broadcast mapping rule
            # 1. introduce dummy length 1 axes to match left / right length
            # 2. keep maps for matching left / right / result axes
            # 3. slice left / right to remove length 1 axes if not both of
            #    them are length 1
            # 4. cast right to left by matching axes
            # 5. perform binary op
            # 6. cast and broadcast result

            left_dim = len(left.axes)
            right_dim = len(right.axes)

            # pad left and right axis to be the same length, align right
            result_dim = max(left_dim, right_dim)
            left_axes_pad = [
                ng.make_axis(length=1) for _ in range(result_dim - left_dim)
            ] + list(left.axes)
            right_axes_pad = [
                ng.make_axis(length=1) for _ in range(result_dim - right_dim)
            ] + list(right.axes)
            result_axes = [
                ng.make_axis(length=max(l.length, r.length))
                for l, r in zip(left_axes_pad, right_axes_pad)
            ]

            # broadcast left / right, introducing dummy length 1 axes
            left = ng.broadcast(left, left_axes_pad)
            right = ng.broadcast(right, right_axes_pad)

            # make two-way map of lr matching axes and map for result axes
            lr_axes_map = dict()
            result_axes_map = dict()
            for l, r, re in zip(left.axes, right.axes, result_axes):
                lr_axes_map[l] = r
                lr_axes_map[r] = l
                result_axes_map[l] = re
                result_axes_map[r] = re

            # get left / right slice: drop a length-1 axis only when the
            # other side is longer along that position
            left_slice = []
            right_slice = []
            for l, r in zip(left.axes, right.axes):
                if l.length == 1 and r.length != 1:
                    left_slice.append(0)
                else:
                    left_slice.append(slice(None))
                if r.length == 1 and l.length != 1:
                    right_slice.append(0)
                else:
                    right_slice.append(slice(None))

            # perform slicing
            left_sliced = ng.tensor_slice(left, left_slice)
            right_sliced = ng.tensor_slice(right, right_slice)

            # now cast the right_sliced to left_sliced from the axis map
            right_casted_axes = []
            for r in right_sliced.axes:
                if r in lr_axes_map and lr_axes_map[r] in left_sliced.axes:
                    right_casted_axes.append(lr_axes_map[r])
                else:
                    right_casted_axes.append(r)
            right_sliced_casted = ng.cast_axes(right_sliced, right_casted_axes)

            # perform binary op
            result_op = ng_op(left_sliced, right_sliced_casted)

            # cast result axis and broadcast to full result axes
            trimmed_result_axes = [
                result_axes_map[re] for re in result_op.axes
            ]
            result_op = ng.cast_axes(result_op, trimmed_result_axes)
            result_op = ng.axes_with_order(result_op, axes=result_axes)
            # name the final op, consistent with the non-broadcast branches
            result_op = result_op.named(name)

        elif left_shape == right_shape:
            # cast right axes to be the same as left
            right = ng.cast_axes(right, left.axes)
            result_op = ng_op(left, right).named(name)

        else:
            # at least one side is a scalar: no need for casting
            result_op = ng_op(left, right).named(name)

        # return op
        return result_op