コード例 #1
0
    def __call__(self, in_obj):
        """Apply batch normalization to ``in_obj``.

        Statistics are reduced over every axis except the channel axis; if
        the input has no channel axis, the recurrent and batch axes are
        reduced instead. ``gamma``/``beta`` (and, if not already set, the
        running ``gmean``/``gvar``) are created on the first call.

        In inference mode the running statistics normalize the input.
        Otherwise the running statistics are blended with the batch
        statistics using ``self.rho`` and the batch statistics normalize the
        input.

        Arguments:
            in_obj (Op): Input tensor.

        Returns:
            Op: The normalized tensor, passed through ``ng.unflatten``.
        """
        axes = in_obj.axes
        channel = axes.channel_axis()
        if channel is not None:
            red_axes = axes - channel
        else:
            red_axes = ng.make_axes(axes.recurrent_axis()) + axes.batch_axes()

        out_axes = axes - red_axes
        flat_in = ng.flatten(in_obj, out_axes | red_axes.flatten(force=True))

        if self.gamma is None:
            # Running statistics may already exist; only create missing ones.
            self.gvar = self.gvar or ng.persistent_tensor(axes=out_axes, initial_value=1.0)
            self.gmean = self.gmean or ng.persistent_tensor(axes=out_axes, initial_value=0.0)
            self.gamma = ng.variable(axes=out_axes, initial_value=self.init_gamma,
                                     scope=self.scope).named('gamma')
            self.beta = ng.variable(axes=out_axes, initial_value=self.init_beta,
                                    scope=self.scope).named('beta')

        batch_mean = ng.mean(flat_in, out_axes=out_axes)
        batch_var = ng.variance(flat_in, out_axes=out_axes)

        def _normalize(mean, var):
            # gamma * (x - mean) / sqrt(var + eps) + beta, restored to input axes.
            return ng.unflatten(self.gamma * ((flat_in - mean) *
                                ng.reciprocal(ng.sqrt(var + self.eps))) + self.beta)

        if Layer.inference_mode:
            return _normalize(self.gmean, self.gvar)

        return ng.sequential([
            ng.assign(self.gmean, self.gmean * self.rho + batch_mean * (1.0 - self.rho)),
            ng.assign(self.gvar, self.gvar * self.rho + batch_var * (1.0 - self.rho)),
            _normalize(batch_mean, batch_var),
        ])
コード例 #2
0
ファイル: layer.py プロジェクト: kkasravi/ngraph
    def train_outputs(self, in_obj):
        """Build the training-mode batch-norm graph for ``in_obj``.

        Reduction axes are the batch axes plus, when feature-input axes are
        present, every non-feature sample axis. Missing parameters and
        running statistics are created lazily. The running mean/variance are
        blended with the batch statistics using ``self.rho``.

        Arguments:
            in_obj (Op): Input tensor.

        Returns:
            Op: ``ng.sequential`` of the two running-statistic updates followed
            by ``gamma * (in_obj - mean) / sqrt(var + eps) + beta``.
        """
        sample_axes = in_obj.axes.sample_axes()
        feature_axes = sample_axes.role_axes(ar.features_input)

        red_axes = ng.make_axes()
        if len(feature_axes) != 0:
            red_axes += sample_axes.sample_axes() - feature_axes
        red_axes += in_obj.axes.batch_axes()
        out_axes = sample_axes - red_axes

        self.gamma = self.gamma or ng.variable(axes=out_axes, initial_value=1.0).named('gamma')
        self.beta = self.beta or ng.variable(axes=out_axes, initial_value=0.0).named('beta')
        self.gvar = self.gvar or ng.persistent_tensor(axes=out_axes, initial_value=1.0)
        self.gmean = self.gmean or ng.persistent_tensor(axes=out_axes, initial_value=0.0)

        batch_mean = ng.mean(in_obj, reduction_axes=red_axes)
        batch_var = ng.variance(in_obj, reduction_axes=red_axes)

        mean_update = ng.assign(self.gmean,
                                self.gmean * self.rho + batch_mean * (1.0 - self.rho))
        var_update = ng.assign(self.gvar,
                               self.gvar * self.rho + batch_var * (1.0 - self.rho))
        normalized = (self.gamma * (in_obj - batch_mean) / ng.sqrt(batch_var + self.eps) +
                      self.beta)
        return ng.sequential([mean_update, var_update, normalized])
コード例 #3
0
    def __call__(self, cost_func):
        """Build RMSProp update ops that minimize ``cost_func``.

        The cost is summed to a scalar, and each variable's gradient is
        divided by the batch size. A global scale factor comes from
        ``clip_gradient_norm``. Each gradient is then passed through
        ``clip_gradient_value``, and a per-variable running mean of squared
        gradients (decay ``self.decay_rate``) scales the step.

        Arguments:
            cost_func (Op): Cost with a batch axis.

        Returns:
            Op: ``ng.doall`` over one ``ng.sequential`` update per variable.
        """
        batch_cost = ng.sum(cost_func, out_axes=())
        batch_size = cost_func.axes.batch_axes()[0].length
        params = list(batch_cost.variables())

        grads = [ng.deriv(batch_cost, param) / batch_size for param in params]
        scale_factor = clip_gradient_norm(grads, batch_size, self.gradient_clip_norm)

        epsilon, decay = self.epsilon, self.decay_rate
        updates = []
        for param, grad in zip(params, grads):
            grad = clip_gradient_value(grad, self.gradient_clip_value)
            mean_square = ng.persistent_tensor(axes=param.axes, initial_value=0.)
            ms_update = ng.assign(mean_square,
                                  decay * mean_square + (1.0 - decay) * ng.square(grad))
            step = (scale_factor * grad * self.lrate) / (ng.sqrt(mean_square + epsilon) + epsilon)
            updates.append(ng.sequential([ms_update, ng.assign(param, param - step)]))

        return ng.doall(updates)
コード例 #4
0
ファイル: optimizer.py プロジェクト: psdurley/ngraph
def clip_gradient_norm(grad_list, clip_norm=None):
    """
    Returns a scaling factor to apply to the gradients.

    The factor is chosen so that the global L2 norm of the scaled gradients
    (the square root of the summed squared L2 norms of every gradient in
    ``grad_list``) is at most ``clip_norm``. The factor is always <= 1, so
    the gradients are never scaled up.

    Arguments:
        grad_list (iterable): Gradient ops, one per parameter.
        clip_norm (float, optional): Target norm for the gradients. If not provided
                                     the returned scale_factor will equal 1.

    Returns:
        The integer ``1`` when ``clip_norm`` is None or ``grad_list`` is
        empty; otherwise an op computing ``clip_norm / max(norm, clip_norm)``.
    """
    if clip_norm is None:
        return 1

    total = None
    for grad in grad_list:
        term = ng.squared_L2(grad, out_axes=None)
        total = term if total is None else total + term

    if total is None:
        # No gradients means nothing to clip. Previously this reached
        # ng.sqrt(None) and failed.
        return 1

    norm = ng.sqrt(total)
    return clip_norm / ng.maximum(norm, clip_norm)
コード例 #5
0
ファイル: ops_bridge.py プロジェクト: rsumner31/ngraph
def BatchNormalization(onnx_node,
                       ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Convert an ONNX BatchNormalization node to an ngraph op (inference only).

    Computes ``scale * ((x - mean) * 1 / sqrt(var + epsilon)) + bias``. The
    ``scale``, ``bias``, ``mean`` and ``var`` inputs are renamed to a single
    ``C`` axis.

    :param onnx_node: The ONNX node. Its ``is_test``, ``spatial`` and
        ``epsilon`` attributes are read (defaults 1, 1 and 1e-3).
    :param ng_inputs: ``[x, scale, bias, mean, var]``.
    :return: The normalized op after ``cast_to_pos_axes``.
    :raises NotImplementedError: If ``is_test`` or ``spatial`` is falsy.
    """
    x, scale, bias, mean, var = ng_inputs

    is_test = onnx_node.get_attribute_value('is_test', 1)
    spatial = onnx_node.get_attribute_value('spatial', 1)
    epsilon = onnx_node.get_attribute_value('epsilon', 1e-3)
    # @TODO: Implement learning mode support
    # momentum = onnx_node.get_attribute_value('momentum', 0.99)

    # Format the node name into the message. Passing it as a second argument
    # to the exception left the '%s' placeholder unfilled.
    if not is_test:
        raise NotImplementedError(
            'BatchNormalization node (%s): only `is_test` mode is currently '
            'supported.' % onnx_node.name)
    if not spatial:
        raise NotImplementedError(
            'BatchNormalization node (%s): only `spatial` mode is currently '
            'supported.' % onnx_node.name)

    # 5-D inputs get a trailing D axis; everything else is treated as NCHW.
    if len(x.axes) == 5:
        x = rename_axes(x, 'NCHWD')
    else:
        x = rename_axes(x, 'NCHW')

    mean = rename_axes(mean, 'C')
    scale = rename_axes(scale, 'C')
    bias = rename_axes(bias, 'C')
    var = rename_axes(var, 'C')

    ng_op = ng.unflatten(scale *
                         ((x - mean) * ng.reciprocal(ng.sqrt(var + epsilon))) +
                         bias)

    return cast_to_pos_axes(ng_op)
コード例 #6
0
def BatchNormalization(onnx_node, ng_inputs):  # type: (NodeWrapper, List[NgraphNode]) -> NgraphNode
    """Carry out batch normalization (inference mode only).

    Computes ``scale * ((x - mean) * (1 / sqrt(var + epsilon))) + bias``. The
    per-channel inputs and ``epsilon`` are broadcast to ``x.shape`` along
    axis 1.

    :param onnx_node: The ONNX node. Its ``is_test``, ``spatial`` and
        ``epsilon`` attributes are read (defaults 1, 1 and 1e-3).
    :param ng_inputs: ``[x, scale, bias, mean, var]``.
    :return: The normalized node.
    :raises NotImplementedError: If ``is_test`` or ``spatial`` is falsy.
    """
    x, scale, bias, mean, var = ng_inputs

    is_test = onnx_node.get_attribute_value('is_test', 1)
    spatial = onnx_node.get_attribute_value('spatial', 1)
    epsilon = onnx_node.get_attribute_value('epsilon', 1e-3)

    # @TODO: Implement learning mode support
    # momentum = onnx_node.get_attribute_value('momentum', 0.99)

    # Format the node name into the message. Passing it as a second argument
    # to the exception left the '%s' placeholder unfilled.
    if not is_test:
        raise NotImplementedError('BatchNormalization node (%s): only `is_test` mode is currently '
                                  'supported.' % onnx_node.name)
    if not spatial:
        raise NotImplementedError('BatchNormalization node (%s): only `spatial` mode is currently '
                                  'supported.' % onnx_node.name)

    mean = ng.broadcast(mean, x.shape, axis=1)
    scale = ng.broadcast(scale, x.shape, axis=1)
    var = ng.broadcast(var, x.shape, axis=1)
    bias = ng.broadcast(bias, x.shape, axis=1)
    # epsilon becomes a constant of x's element type so the addition is type-consistent.
    epsilon = ng.broadcast(ng.constant(epsilon, dtype=get_dtype(x.get_element_type())),
                           x.shape, axis=1)
    return (scale * ((x - mean) * (1 / (ng.sqrt(var + epsilon)))) + bias)
コード例 #7
0
def test_variance_sqrt_inverse(transformer_factory, input_tensor):
    """Check fprop/bprop of ``1 / sqrt(var(x) + eps)`` against NumPy.

    The ngraph variance reduces over the batch axes of ``input_tensor``. The
    NumPy reference reduces over axis 1, which presumes the batch axis is the
    second axis of the fixture (TODO confirm).
    """
    x = input_tensor
    y = ng.placeholder(x.axes)

    eps = 1e-3

    batch_var = ng.variance(x, reduction_axes=x.axes.batch_axes())
    inv_std = ng.reciprocal(ng.sqrt(batch_var + eps))
    loss = ng.sum(inv_std - y, out_axes=())
    grad_x = ng.deriv(loss, x)

    with executor([loss, grad_x], x, y) as run:
        x_val = rng.uniform(-1, 1, x.axes)
        y_val = rng.uniform(-1, 1, y.axes)
        ng_loss, ng_grad = run(x_val, y_val)

        var_eps = np.var(x_val, axis=1, keepdims=True) + eps
        inv_std_ref = 1.0 / np.sqrt(var_eps)

        # d var / dx carries a 1/N factor. It cancels against the N-fold
        # broadcast of inv_std across y's batch axis inside the sum, so it
        # is omitted here.
        centered = x_val - np.mean(x_val, axis=1, keepdims=True)
        grad_ref = -0.5 * inv_std_ref / var_eps * (2 * centered)

        loss_ref = np.sum(inv_std_ref - y_val)

        ng.testing.assert_allclose(loss_ref, ng_loss, atol=1e-4, rtol=1e-4)
        ng.testing.assert_allclose(grad_ref, ng_grad, atol=1e-4, rtol=1e-4)
コード例 #8
0
ファイル: test_ops.py プロジェクト: avitial/openvino
def unary_op(op_str, a):
    """Apply the ngraph unary op selected by ``op_str`` to ``a``.

    Keys are matched case-sensitively; note that ``log``, ``exp`` and
    ``negative`` are lowercase. An unrecognised ``op_str`` returns ``None``.
    """
    ng_name = {
        "Abs": "abs",
        "Acos": "acos",
        "Asin": "asin",
        "Atan": "atan",
        "Ceiling": "ceiling",
        "Cos": "cos",
        "Cosh": "cosh",
        "Floor": "floor",
        "log": "log",
        "exp": "exp",
        "negative": "negative",
        "Sign": "sign",
        "Sin": "sin",
        "Sinh": "sinh",
        "Sqrt": "sqrt",
        "Tan": "tan",
        "Tanh": "tanh",
    }.get(op_str)
    if ng_name is None:
        return None
    # Resolve the ngraph function lazily so only the requested op is looked up.
    return getattr(ng, ng_name)(a)
コード例 #9
0
 def variable_update(self, variable, grad, scale_factor):
     """Return the Adam update ops for ``variable`` given ``grad``.

     First (``m``) and second (``v``) moment estimates live in persistent
     tensors initialised to 0. The variable steps by
     ``scale_factor * self.ell * m / (sqrt(v) + self.epsilon)``, where
     ``self.ell`` is set elsewhere.
     """
     first_moment = ng.persistent_tensor(axes=grad.axes, initial_value=0.)
     second_moment = ng.persistent_tensor(axes=grad.axes, initial_value=0.)
     m_update = ng.assign(first_moment,
                          first_moment * self.beta_1 + (1 - self.beta_1) * grad)
     v_update = ng.assign(second_moment,
                          second_moment * self.beta_2 + (1 - self.beta_2) * grad * grad)
     step = (scale_factor * self.ell * first_moment) / (ng.sqrt(second_moment) + self.epsilon)
     var_update = ng.assign(variable, variable - step)
     return ng.sequential([m_update, v_update, var_update])
コード例 #10
0
 def variable_update(self, variable, grad, scale_factor):
     """Return the RMSProp update ops for ``variable``.

     The gradient is passed through ``clip_gradient_value``. A running mean
     of squared gradients (decay ``self.decay_rate``, initialised to 0) then
     scales the step
     ``scale_factor * grad * lrate / (sqrt(state + eps) + eps)``.
     """
     eps = self.epsilon
     decay = self.decay_rate
     clipped = clip_gradient_value(grad, self.gradient_clip_value)
     mean_square = ng.persistent_tensor(axes=variable.axes, initial_value=0.)
     ms_update = ng.assign(mean_square,
                           decay * mean_square + (1.0 - decay) * ng.square(clipped))
     step = (scale_factor * clipped * self.lrate) / (ng.sqrt(mean_square + eps) + eps)
     return ng.sequential([ms_update, ng.assign(variable, variable - step)])
コード例 #11
0
 def variable_update(self, variable, grad, scale_factor):
     """Return the Adagrad update ops for ``variable``.

     Squared (value-clipped) gradients accumulate in a persistent tensor
     initialised to 0. The variable steps by
     ``scale_factor * lrate * grad / sqrt(accumulated + epsilon)``.
     """
     clipped = clip_gradient_value(grad, self.gradient_clip_value)
     sq_sum = ng.persistent_tensor(axes=clipped.axes, initial_value=0.)
     accumulate = ng.assign(sq_sum, sq_sum + ng.square(clipped))
     step = (scale_factor * self.lrate * clipped) / ng.sqrt(sq_sum + self.epsilon)
     return ng.sequential([accumulate, ng.assign(variable, variable - step)])
コード例 #12
0
    def __call__(self, *args, **kwargs):
        """Advance Adam's step counter and effective learning rate, then delegate.

        While ``self.ops`` is empty, ``beta_1``/``beta_2`` are wrapped as
        float32 constants and a scalar persistent counter ``t`` (starting at
        0) is created. Each call rebinds ``self.t`` to an op that increments
        the counter and yields it. It also sets ``self.ell`` to the
        bias-corrected rate ``lrate * sqrt(1 - beta_2**t) / (1 - beta_1**t)``.

        NOTE(review): ``self.t`` is rebound to the sequential op. A later call
        made after ``self.ops`` is populated would therefore assign to that op
        rather than to the persistent counter. Confirm this is intended.
        """
        first_call = len(self.ops) == 0
        if first_call:
            self.beta_1 = ng.constant(self.beta_1, dtype=np.float32)
            self.beta_2 = ng.constant(self.beta_2, dtype=np.float32)
            self.t = ng.persistent_tensor(axes=(), initial_value=0)

        counter = self.t
        self.t = ng.sequential([ng.assign(counter, counter + 1), counter])
        self.ell = self.lrate * ng.sqrt(1 - self.beta_2 ** self.t) / (
            1 - self.beta_1 ** self.t)

        return super(Adam, self).__call__(*args, **kwargs)
コード例 #13
0
def ReduceL2(onnx_node, ng_inputs):  # type: (NodeWrapper, List[NgraphNode]) -> NgraphNode
    """Compute the L2 norm of the input tensor's elements along the provided axes.

    The input is squared elementwise, summed with ``make_reduction_op`` (which
    reads the reduction settings from ``onnx_node``), and square-rooted.

    :param onnx_node: The ONNX node representing this operation.
    :param ng_inputs: The input tensors; only the first is used.
    :return: The tensor with applied ReduceL2 operation.
    """
    data = ng_inputs[0]
    sum_of_squares = make_reduction_op(ng.sum, onnx_node, data * data)
    return ng.sqrt(sum_of_squares)
コード例 #14
0
ファイル: ops_unary.py プロジェクト: rsumner31/ngraph
    def Sqrt(self, cntk_op, inputs):
        """
        Import a CNTK square-root op as an elementwise ``ng.sqrt``.

        Arguments:
            cntk_op: CNTK operation to be imported; its ``uid`` names the result.
            inputs: List of inputs to this node; exactly one is expected.

        Returns:
            A ngraph Op.
        """
        assert len(inputs) == 1

        operand = inputs[0]
        result = ng.sqrt(operand)
        return result.named(cntk_op.uid)
コード例 #15
0
 def variable_update(self, variable, grad, scale_factor):
     """Return RMSProp-with-momentum update ops for ``variable``.

     A running mean of squared gradients is initialised to 1 and decayed by
     ``self.decay_rate``. A velocity tensor (initialised to 0, named
     ``<variable>_vel``) accumulates the scaled step plus weight decay
     (``lrate * wdecay * variable``). The variable then moves by
     ``-velocity``.
     """
     epsilon, decay = self.epsilon, self.decay_rate
     clipped = clip_gradient_value(grad, self.gradient_clip_value)
     mean_square = ng.persistent_tensor(axes=variable.axes, initial_value=1.)
     velocity = ng.persistent_tensor(
         axes=variable.axes, initial_value=0.).named(variable.name + '_vel')

     ms_update = ng.assign(mean_square,
                           decay * mean_square + (1.0 - decay) * ng.square(clipped))
     vel_update = ng.assign(
         velocity,
         velocity * self.momentum +
         (self.lrate * scale_factor * clipped / ng.sqrt(mean_square + epsilon)) +
         self.lrate * self.wdecay * variable)
     var_update = ng.assign(variable, variable - velocity)
     return ng.sequential([ms_update, vel_update, var_update])
コード例 #16
0
def ngraph_l2_norm(np_array):
    """
    Evaluate ``sqrt(squared_L2(x))`` for a NumPy array using ngraph.

    Each dimension ``i`` of ``np_array`` becomes an axis named ``axis<i>``
    with the same length. The array is loaded into an ngraph variable, and
    the expression is run through ``executor``.

    Arguments:
      np_array: NumPy array to take the norm of.

    Returns:
      Whatever the executor returns for the computation (presumably a scalar
      array; TODO confirm).
    """
    axes = tuple(ng.make_axis(name='axis%s' % dim, length=size)
                 for dim, size in enumerate(np_array.shape))

    initial = ng.constant(np_array, axes)
    var = ng.variable(axes, initial_value=initial)
    norm = ng.sqrt(ng.squared_L2(var))
    return executor(norm)()
コード例 #17
0
    def construct_batchnorm_fprop_pattern(self):
        """
        Generate a graph pattern matching the batchnorm fprop computation::

            self.gamma * ((in_obj - xmean) * ng.reciprocal(ng.sqrt(xvar + self.eps))) + self.beta

        The label strings for the bound sub-ops are stored on ``self`` as
        ``batchnorm_fprop_*_label`` attributes.

        Returns:
               Single pattern that matches batchnorm fprop op
        """
        def of_type(op_type):
            # Predicate factory: accepts ops that are instances of ``op_type``.
            return lambda op: isinstance(op, op_type)

        self.batchnorm_fprop_input_tensor_label = "in_obj"
        self.batchnorm_fprop_gamma_label = "gamma"
        self.batchnorm_fprop_beta_label = "beta"
        self.batchnorm_fprop_variance_label = "variance"
        self.batchnorm_fprop_epsilon_label = "epsilon"
        self.batchnorm_fprop_mean_label = "mean"

        # Labelled leaves of the pattern, each restricted to one op type.
        in_obj = PatternLabelOp(self.batchnorm_fprop_input_tensor_label, of_type(ContiguousOp))
        flatten_tensor = PatternSkipOp(in_obj, of_type(Flatten))
        gamma = PatternLabelOp(self.batchnorm_fprop_gamma_label, of_type(BroadcastOp))
        beta = PatternLabelOp(self.batchnorm_fprop_beta_label, of_type(BroadcastOp))
        variance = PatternLabelOp(self.batchnorm_fprop_variance_label, of_type(Divide))
        epsilon = PatternLabelOp(self.batchnorm_fprop_epsilon_label, of_type(BroadcastOp))
        mean = PatternLabelOp(self.batchnorm_fprop_mean_label, of_type(Divide))

        # ng.reciprocal(ng.sqrt(xvar + self.eps)), optionally behind a broadcast
        inv_std = ng.reciprocal(ng.sqrt(ng.add(variance, epsilon)))
        inv_std_bcast = ng.PatternSkipOp(inv_std, of_type(BroadcastOp))
        mean_bcast = ng.PatternSkipOp(mean, of_type(BroadcastOp))

        # (in_obj - xmean) * ng.reciprocal(ng.sqrt(xvar + self.eps))
        normalized = ng.multiply(ng.subtract(flatten_tensor, mean_bcast), inv_std_bcast)
        # ... * self.gamma + self.beta, then unflattened
        scaled = ng.multiply(normalized, gamma)
        return ng.Unflatten(ng.Add(scaled, beta))
コード例 #18
0
    def __call__(self, cost_func):
        """Build RMSProp updates for every variable of ``cost_func``.

        The cost is summed to a scalar and each gradient is divided by the
        batch size. When ``self.gradient_clip_norm`` is set, a global scale
        factor from ``clip_gradient_norm`` is applied. Each gradient is also
        passed through ``clip_gradient_value`` and scaled by a per-variable
        running mean of squared gradients. The learning rate is advanced via
        ``self.schedule``, and ``self.iteration_index`` is incremented once
        per call.

        Arguments:
            cost_func (Op): Cost with a batch axis.

        Returns:
            Op: ``ng.doall`` over the state, parameter and learning-rate updates.
        """
        with ng.Op.saved_user_deps():
            state_updates, param_updates = [], []
            batch_cost = ng.sum(cost_func, out_axes=())
            batch_size = cost_func.axes.batch_axes()[0].length

            grads = [
                ng.deriv(batch_cost, v) / batch_size
                for v in batch_cost.variables()
            ]
            # Pass the configured norm through. With the signature
            # clip_gradient_norm(grad_list, clip_norm=None), calling it with
            # the gradients alone returns 1, which silently disabled clipping.
            if self.gradient_clip_norm:
                scale_factor = clip_gradient_norm(
                    grads, clip_norm=self.gradient_clip_norm)
            else:
                scale_factor = 1

            epsilon, decay = (self.epsilon, self.decay_rate)
            for i, (variable,
                    grad) in enumerate(zip(batch_cost.variables(), grads)):
                grad = clip_gradient_value(grad, self.gradient_clip_value)

                state = ng.persistent_tensor(axes=variable.axes,
                                             initial_value=0.)
                state_updates.append(
                    ng.assign(lvalue=state,
                              rvalue=decay * state +
                              (1.0 - decay) * ng.square(grad)).named(
                                  'state_u_%s' % i))

                param_updates.append(
                    ng.assign(
                        lvalue=variable,
                        rvalue=variable -
                        ((scale_factor * grad * self.learning_rate) /
                         (ng.sqrt(state + epsilon) + epsilon)),
                    ).named('var_u_%s' % i))

            lr_update = [
                ng.assign(
                    self.learning_rate,
                    self.schedule.get_learning_rate(self.learning_rate,
                                                    self.iteration_index))
            ]

            updates = ng.doall(state_updates + param_updates + lr_update)
            self.iteration_index += 1

        return updates
コード例 #19
0
ファイル: test_ops.py プロジェクト: ravirajpinnamaraju/ngraph
def unary_op(op_str, a):
    """Apply the ngraph unary op selected by ``op_str`` to ``a``.

    Keys are case-sensitive (``log``, ``exp`` and ``negative`` are
    lowercase). ``Reverse`` reverses along axis 1 in ``'index'`` mode. An
    unrecognised ``op_str`` returns ``None``.
    """
    # Lambdas defer the ngraph attribute lookup until an op is actually chosen.
    dispatch = {
        'Abs': lambda x: ng.abs(x),
        'Acos': lambda x: ng.acos(x),
        'Asin': lambda x: ng.asin(x),
        'Atan': lambda x: ng.atan(x),
        'Ceiling': lambda x: ng.ceiling(x),
        'Cos': lambda x: ng.cos(x),
        'Cosh': lambda x: ng.cosh(x),
        'Floor': lambda x: ng.floor(x),
        'log': lambda x: ng.log(x),
        'exp': lambda x: ng.exp(x),
        'negative': lambda x: ng.negative(x),
        'Reverse': lambda x: ng.reverse(x, np.array([1]), 'index'),
        'Sign': lambda x: ng.sign(x),
        'Sin': lambda x: ng.sin(x),
        'Sinh': lambda x: ng.sinh(x),
        'Sqrt': lambda x: ng.sqrt(x),
        'Tan': lambda x: ng.tan(x),
        'Tanh': lambda x: ng.tanh(x),
    }
    build = dispatch.get(op_str)
    return None if build is None else build(a)
コード例 #20
0
ファイル: layer.py プロジェクト: kkasravi/ngraph
 def inference_outputs(self, in_obj):
     """Normalize ``in_obj`` with the running statistics ``gmean``/``gvar``.

     Computes ``gamma * (in_obj - gmean) / sqrt(gvar + eps) + beta``.
     """
     centered = in_obj - self.gmean
     std = ng.sqrt(self.gvar + self.eps)
     return self.gamma * centered / std + self.beta
コード例 #21
0
ファイル: optimizer.py プロジェクト: psdurley/ngraph
 def _pre_call_hook(self):
     """Increment the step counter ``t`` and refresh the bias-corrected rate ``ell``.

     ``self.t`` is rebound to an op that increments the previous ``t`` and
     yields it. ``ell = lrate * sqrt(1 - beta_2**t) / (1 - beta_1**t)``.
     """
     incremented = ng.sequential([ng.assign(self.t, self.t + 1), self.t])
     self.t = incremented
     self.ell = (self.lrate * ng.sqrt(1 - self.beta_2 ** incremented)
                 / (1 - self.beta_1 ** incremented))
コード例 #22
0
ファイル: ops_bridge.py プロジェクト: rsumner31/ngraph
def Sqrt(onnx_node, ng_inputs):  # type: (NodeWrapper, List[TensorOp]) -> Op
    """Apply an elementwise square root to the first input tensor."""
    operand = ng_inputs[0]
    return ng.sqrt(operand)
コード例 #23
0
def Sqrt(onnx_node, ng_inputs):  # type: (NodeWrapper, List[NgraphNode]) -> NgraphNode
    """Apply f(x) = x^0.5 (square root) elementwise to the first input tensor.

    :param onnx_node: The ONNX node (unused).
    :param ng_inputs: Input tensors; only the first is used.
    :return: The square-root node.
    """
    operand = ng_inputs[0]
    return ng.sqrt(operand)
コード例 #24
0
    def __call__(self,
                 in_obj,
                 channel_axes="C",
                 spatial_axes=("D", "H", "W"),
                 **kwargs):
        """
        Apply the convolution to ``in_obj``, creating the filter on the first call.

        Arguments:
            in_obj (Op): Input op
            channel_axes (str): name of the expected channel axis type - defaults to "C"
            spatial_axes (tuple): names of expected depth, height and width axis types - defaults
                                  to "D", "H", and "W". A dict mapping any of "D", "H", "W"
                                  to another name is also accepted; falsy tuple entries fall
                                  back to the corresponding default name.
            **kwargs: Accepted but not used by this method.

        Returns:
            Op: The convolution output, reordered to follow the input's axis
            order. Introduced length-1 axes are sliced away, and introduced
            axes of length > 1 are appended at the end.

        Raises:
            ValueError: If ``spatial_axes`` is a tuple whose length is not 3, or
                if the layer was already initialized with different filter axes.
        """
        if isinstance(spatial_axes, dict):
            spatial_axes = tuple(
                spatial_axes.get(name, name) for name in ("D", "H", "W"))
        elif isinstance(spatial_axes, tuple):
            # Require exactly three names. With the previous `< 3` check, extra
            # entries were silently dropped by the zip below.
            if len(spatial_axes) != 3:
                raise ValueError(
                    "spatial_axes must have length 3 (e.g. ('D', 'H', 'W'))")
            # Falsy entries (e.g. None or "") fall back to the default name.
            spatial_axes = tuple(
                name if name else default
                for name, default in zip(spatial_axes, ("D", "H", "W")))

        orig_axes = in_obj.axes
        in_obj = reorder_spatial_axes(in_obj, channel_axes, spatial_axes)
        channel_axes = in_obj.axes.get_by_names(channel_axes)
        spatial_axes = in_obj.axes.get_by_names(*spatial_axes)

        filter_axes = self._filter_axes(channel_axes, spatial_axes)

        # mark 'K' as a shadow axis for the initializers.
        axes_map = shadow_axes_map(filter_axes.find_by_name('K'))
        filter_axes = ng.make_axes([
            axis if axis.name != 'K' else list(axes_map.keys())[0]
            for axis in filter_axes
        ])

        if not self.initialized:
            if not self.weight_norm:
                self.W = ng.variable(axes=filter_axes,
                                     initial_value=self.init,
                                     metadata={
                                         "label": LABELS["weight"]
                                     }).named("W")
            else:
                # Weight normalization: W = g * v / sqrt(mean(v ** 2) + 1e-3),
                # where the mean of v ** 2 is reduced onto the shadow 'K' axis.
                self.v = ng.variable(axes=filter_axes,
                                     initial_value=self.init,
                                     metadata={
                                         "label": LABELS["weight"]
                                     }).named("v")
                out_axes = ng.make_axes(
                    [filter_axes.get_by_names("K__NG_SHADOW")])
                v_norm = ng.mean(ng.square(self.v), out_axes=out_axes)
                self.g = ng.variable(axes=out_axes,
                                     initial_value=self.init,
                                     metadata={
                                         "label": LABELS["weight"]
                                     }).named("g")
                self.W = self.g * self.v * ng.reciprocal(
                    ng.sqrt(v_norm + 1e-3))
        else:
            if filter_axes != self.W.axes:
                raise ValueError(
                    ("{layer_name} layer has already been initialized with an "
                     "input object which has resulted in filter axes: "
                     "{existing_filter_axes}. This new input object has axes: "
                     "{input_axes}, which implies the need for filter axes: "
                     "{new_filter_axes} which are different than the existing "
                     "filter axes.").format(
                         layer_name=self.name,
                         existing_filter_axes=self.W.axes,
                         input_axes=in_obj.axes,
                         new_filter_axes=filter_axes,
                     ))

        output = ng.map_roles(
            self._conv_op(in_obj, channel_axes, spatial_axes), axes_map)
        # Reorder the output to match the input order
        output_axis_order = ng.make_axes(
            [output.axes.find_by_name(ax.name)[0] for ax in orig_axes])
        # Remove introduced axes. If their length is > 1, then perhaps they should be kept
        slices = [
            0 if (ax not in orig_axes) and ax.length == 1 else slice(None)
            for ax in output.axes
        ]
        output = ng.tensor_slice(output, slices)
        # New axes with length > 1 may have been introduced. Add them to the end.
        output_axis_order = output_axis_order | output.axes
        return ng.axes_with_order(output, output_axis_order)