Ejemplo n.º 1
0
def _convert_transpose(converter: ONNXConverter, onnx_op: INodeProto):
    """Convert an ONNX ``Transpose`` node into a WebDNN :class:`Transpose`.

    When the ``perm`` attribute is absent, ONNX semantics default to
    reversing all axes, which is mirrored here.
    """
    attrs = attribute_dict(onnx_op)
    x = converter.get_variable(onnx_op.input[0])

    if "perm" in attrs:
        perm = list(attrs["perm"].ints)
    else:
        perm = list(reversed(range(x.ndim)))

    y, = Transpose(None)(x)
    y.change_order(Order([x.order.axes[i] for i in perm]))

    converter.set_variable(onnx_op.output[0], y)
Ejemplo n.º 2
0
def _replace_input(op: Operator, var_name: str, target_orders: Union[Order, List[Order]]):
    """Ensure the input *var_name* of *op* uses one of *target_orders*.

    Returns ``False`` when the input's order is already acceptable.
    Otherwise a Transpose converting the input to the first acceptable
    order is inserted in front of *op* and ``True`` is returned.
    """
    candidates = [target_orders] if isinstance(target_orders, Order) else target_orders
    current = op.inputs[var_name]
    if current.order in candidates:
        return False

    converted, = Transpose(None)(current)
    converted.change_order(candidates[0])
    op.remove_input(current)
    op.append_input(var_name, converted)
    return True
Ejemplo n.º 3
0
def _replace_input(op: Operator, var_name: str,
                   target_orders: Union[Order, List[Order]]):
    """Force the input *var_name* of *op* into an acceptable data order.

    Returns ``False`` when the input already uses one of *target_orders*;
    otherwise inserts a Transpose converting it to the first acceptable
    order and returns ``True``.
    """
    candidates = ([target_orders]
                  if isinstance(target_orders, Order) else target_orders)
    original = op.inputs[var_name]
    if original.order in candidates:
        return False

    converted, = Transpose(None)(original)
    op.replace_input(original, converted, with_assert=False)
    converted.change_order(candidates[0])
    return True
Ejemplo n.º 4
0
def transpose_handler(converter: TensorFlowConverter, tf_op: "tf.Operation"):
    """Convert a TensorFlow ``Transpose`` operation into a WebDNN Transpose.

    The permutation input must be a compile-time constant; dynamic
    permutations are rejected.
    """
    x = converter.get_variable(tf_op.inputs[0])
    perm_var = converter.get_variable(tf_op.inputs[1])

    if not isinstance(perm_var, ConstantVariable):
        raise NotImplementedError(
            "[TensorFlowConverter] Operator 'Transpose' with dynamic indices is not supported yet."
        )

    perm = perm_var.data.astype(int).flatten().tolist()  # type: List[int]

    y, = Transpose(None)(x)
    y.change_order(Order([x.order.axes[i] for i in perm]))

    converter.set_variable(tf_op.outputs[0], y)
Ejemplo n.º 5
0
    def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
        """Lower every :class:`Linear` operator in *graph* to an Sgemm.

        Each Linear is removed and rebuilt as
        ``transpose(x) , transpose(w) -> Sgemm -> transpose -> y``.
        Returns the graph and a flag telling whether anything changed.
        """
        flag_changed = False
        for op in traverse.filter_nodes(traverse.listup_operators(graph),
                                        Linear):
            x = op.inputs["x"]
            w = op.inputs["w"]
            y = op.outputs["y"]

            # Every Linear found is rewritten unconditionally.
            flag_changed = True
            op.remove_all()

            # a_k is the contracted (channel) axis, a_n is w's other axis,
            # and axes_m collects all of x's remaining axes.
            a_k = Axis.C
            a_n = w.order.axes[0] if w.order.axes[1] == a_k else w.order.axes[1]
            axes_m = [a for a in x.order.axes if a != a_k]

            K = x.shape_dict[a_k]
            M = x.size // K
            N = w.shape_dict[a_n]

            # Rebind x/w to transposed copies in the orders Sgemm is
            # configured for below (transpose_A=False, transpose_B=True).
            x, = Transpose(None)(x)
            x.change_order(Order([a_k] + axes_m))

            w, = Transpose(None)(w)
            w.change_order(Order([a_k, a_n]))

            new_y, = Sgemm(None,
                           M=M,
                           N=N,
                           K=K,
                           out_shape=[x.shape_dict[a] for a in axes_m] + [N],
                           out_order=Order(axes_m + [a_n]),
                           transpose_A=False,
                           transpose_B=True)(x, w)
            # NOTE(review): this final Transpose's output order is not set
            # here; presumably a later optimize pass aligns it with y's
            # order — confirm.
            new_y, = Transpose(None)(new_y)

            OptimizeRule.replace_variable(graph, new_y, y)

        return graph, flag_changed
Ejemplo n.º 6
0
    def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
        """Normalize data orders for backend-supported operators.

        Reshape inputs/outputs are forced to their declared in/out orders,
        2D-image operators are forced to NHWC, and an N-dimensional
        Softmax is lowered to transpose -> reshape -> 2D softmax ->
        reshape -> transpose (each stage skipped when already satisfied).
        Returns the graph and a flag telling whether anything changed.
        """
        flag_changed = False
        for op in traverse.listup_operators(graph):
            if isinstance(op, Reshape):
                flag_changed |= _replace_input(op, "x",
                                               op.parameters["in_order"])
                flag_changed |= _replace_output(op, "y",
                                                op.parameters["out_order"])
                continue

            elif isinstance(op, (Convolution2D, MaxPooling2D, AveragePooling2D,
                                 Deconvolution2D, Space2Depth, Depth2Space)):
                flag_changed |= _replace_input(op, "x", OrderNHWC)
                flag_changed |= _replace_output(op, "y", OrderNHWC)
                continue

            elif isinstance(op, Softmax):
                x = op.inputs["x"]
                y = op.outputs["y"]
                target_axis = op.parameters["axis"]

                # Already a 2D softmax over the innermost axis: nothing to do.
                if not (x.ndim == 2
                        and x.order.axes_dict[target_axis] == x.ndim - 1):
                    """
                    Before)
                    | x   |              | y   |
                    |-----| -{softmax}-> |-----|
                    | XYZ |   axis=Y     | XYZ |
                    
                    After)
                    | x   |                | hx1 |              | hx2 |              | hy1 |              | hy2 |                | y   |
                    |-----| -{transpose}-> |-----| -{reshape}-> |-----| -{softmax}-> |-----| -{reshape}-> |-----| -{transpose}-> |-----|
                    | XYZ |                | XZY |              | NC  |   axis=C     | NC  |              | XZY |                | XYZ |
                                              :                    :
                                        order_nd = XZY       order_2d = NC
                    """
                    op.remove_all()

                    # Move the softmax axis to the innermost position.
                    axes_nd = list(x.order.axes)
                    axes_nd.remove(target_axis)
                    axes_nd.append(target_axis)
                    order_nd = Order(axes_nd)
                    shape_nd = tuple([x.shape_dict[axis] for axis in axes_nd])

                    # Collapse everything except the softmax axis into N.
                    order_2d = OrderNC
                    shape_2d = tuple([
                        x.size // x.shape_dict[target_axis],
                        x.shape_dict[target_axis]
                    ])

                    if x.order == order_nd:
                        hx1 = x

                    else:
                        hx1, = Transpose(None)(x)
                        hx1.change_order(order_nd)
                        flag_changed = True

                    if hx1.order == order_2d and hx1.shape == shape_2d:
                        hx2 = hx1

                    else:
                        hx2, = Reshape(None,
                                       in_order=hx1.order,
                                       out_order=order_2d,
                                       out_shape=shape_2d)(hx1)
                        flag_changed = True

                    hy1, = Softmax(None, axis=Axis.C)(hx2)

                    # Undo the reshape/transpose on the output side.
                    if hy1.order == order_nd and hy1.shape == shape_nd:
                        hy2 = hy1

                    else:
                        hy2, = Reshape(None,
                                       in_order=hy1.order,
                                       out_order=order_nd,
                                       out_shape=shape_nd)(hy1)
                        flag_changed = True

                    if hy2.order == y.order:
                        y_dummy = hy2

                    else:
                        y_dummy, = Transpose(None)(hy2)
                        y_dummy.change_order(y.order)
                        flag_changed = True

                    y_dummy.replace(y)

                    continue

        return graph, flag_changed
    def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
        """Fold a constant ElementwiseMul following an Sgemm into the
        Sgemm's constant weight.

        Matches the pattern ``Sgemm -> Variable -> ElementwiseMul`` where
        one Sgemm input (A or B) and one multiplier input are constant,
        pre-multiplies the weight data, and removes the ElementwiseMul.
        Returns the graph and a flag telling whether anything changed.
        """
        flag_changed = False
        matches = traverse.search_sub_structure(
            graph, [Sgemm, Variable, ElementwiseMul])
        while len(matches) > 0:
            match = matches.pop()
            sgemm = match[0]  # type: Sgemm
            elementwise_mul = match[2]  # type:  ElementwiseMul

            out_order = sgemm.parameters["out_order"]
            out_shape = sgemm.parameters["out_shape"]

            # Fresh axis object representing the contracted K dimension.
            axis_k = Axis('AxisK')

            if not isinstance(sgemm.inputs["A"],
                              ConstantVariable) and not isinstance(
                                  sgemm.inputs["B"], ConstantVariable):
                # neither x nor w1 is constant
                continue

            elif isinstance(sgemm.inputs["A"], ConstantVariable):
                w1 = sgemm.inputs["A"]  # type: ConstantVariable

                if sgemm.transpose_A:
                    # w1.shape = (M, K)

                    # Collect leading output axes until they account for M.
                    shape = []
                    axes = []
                    for axis, size in zip(out_order.axes, out_shape):
                        shape.append(size)
                        axes.append(axis)

                        if mul(shape) >= sgemm.M:
                            break

                    if mul(shape) != sgemm.M:
                        # output axes are derived from both w1 and x
                        continue

                    w1_virtual_order = Order(axes + [axis_k])
                    w1_virtual_shape = shape + [sgemm.K]

                else:
                    # w1.shape = (K, M)

                    shape = [sgemm.K]
                    axes = [axis_k]
                    for axis, size in zip(out_order.axes, out_shape):
                        shape.append(size)
                        axes.append(axis)

                        if mul(shape) >= w1.size:
                            break

                    if mul(shape) != w1.size:
                        # output axes are derived from both w1 and x
                        continue

                    w1_virtual_order = Order(axes)
                    w1_virtual_shape = shape

            else:
                w1 = sgemm.inputs["B"]  # type: ConstantVariable

                if sgemm.transpose_B:
                    # w1.shape = (K, N)

                    # Collect trailing output axes until they account for N.
                    shape = []
                    axes = []
                    for axis, size in reversed(
                            list(zip(out_order.axes, out_shape))):
                        shape.insert(0, size)
                        axes.insert(0, axis)

                        if mul(shape) >= sgemm.N:
                            break

                    if mul(shape) != sgemm.N:
                        # output axes are derived from both w1 and x
                        continue

                    w1_virtual_order = Order([axis_k] + axes)
                    w1_virtual_shape = [sgemm.K] + shape

                else:
                    # w1.shape = (N, K)
                    shape = [sgemm.K]
                    axes = [axis_k]
                    for axis, size in reversed(
                            list(zip(out_order.axes, out_shape))):
                        shape.insert(0, size)
                        axes.insert(0, axis)

                        if mul(shape) >= w1.size:
                            break

                    if mul(shape) != w1.size:
                        # output axes are derived from both w1 and x
                        continue

                    w1_virtual_order = Order(axes)
                    w1_virtual_shape = shape

            h = sgemm.outputs["C"]  # type: Variable

            # The multiplier w2 is whichever ElementwiseMul input is NOT
            # the Sgemm output; it must be constant.
            x0 = elementwise_mul.inputs["x0"]
            x1 = elementwise_mul.inputs["x1"]
            if h == x1:
                if not isinstance(x0, ConstantVariable):
                    # w2 is not constant
                    continue

                w2 = x0  # type: ConstantVariable

            else:
                if not isinstance(x1, ConstantVariable):
                    # w2 is not constant
                    continue

                w2 = x1  # type: ConstantVariable

            y = elementwise_mul.outputs["y"]  # type: Variable

            if not all(axis in w1_virtual_order.axes
                       for axis in w2.order.axes):
                # w2's axes are derived from both w1 and x
                continue

            # Remove the multiplication and splice the Sgemm output
            # (order-adjusted) into its place.
            elementwise_mul.remove_all()
            y_dummy, = Transpose(None)(h)
            y_dummy.change_order(y.order)
            y_dummy.replace(y)

            # Pre-multiply the weight: view w1 through its virtual
            # order/shape, broadcast-multiply by w2, write back in place.
            w2.change_order(w1_virtual_order)
            w_new = ConstantVariable(
                w1.data.reshape(w1_virtual_shape),
                w1_virtual_order) * w2  # type: ConstantVariable
            w1.data = w_new.data.reshape(w1.shape)

            # Re-scan: the rewrite may have created/destroyed matches.
            flag_changed = True
            matches = traverse.search_sub_structure(
                graph, [Sgemm, Variable, ElementwiseMul])

        return graph, flag_changed
Ejemplo n.º 8
0
    def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
        """Normalize data orders and simplify Transpose operators.

        Removes no-op or absorbable Transposes, forces Reshape and
        2D-image operators to their required orders, and lowers an
        N-dimensional Softmax to the 2D form via transpose/reshape pairs.
        Returns the graph and a flag telling whether anything changed.
        """
        flag_changed = False
        for op in traverse.listup_operators(graph):
            if isinstance(op, Transpose):
                x = op.inputs["x0"]
                y = op.outputs["y"]
                # Case 1: orders already match — the Transpose is a no-op.
                if x.order == y.order:
                    op.remove_all()
                    x.replace(y)
                    flag_changed = True

                # Case 2: every consumer can absorb an order change, so the
                # Transpose can be bypassed by rewiring consumers to x.
                # NOTE(review): if case 1 fired, this calls remove_all() on
                # an already-removed op, and this branch never sets
                # flag_changed — both look unintended; confirm.
                if all(isinstance(op2, (Elementwise, SplitAxis)) for op2 in y.input_to):
                    op.remove_all()
                    for op2 in list(y.input_to):
                        name = op2._get_input_name(y)
                        op2.remove_input(y)
                        op2.append_input(name, x)

            elif isinstance(op, Reshape):
                flag_changed |= _replace_input(op, "x", op.parameters["in_order"])
                flag_changed |= _replace_output(op, "y", op.parameters["out_order"])

            elif isinstance(op, (Convolution2D, MaxPooling2D, AveragePooling2D, Deconvolution2D)):
                flag_changed |= _replace_input(op, "x", OrderNHWC)
                flag_changed |= _replace_output(op, "y", OrderNHWC)

            elif isinstance(op, Softmax):
                x = op.inputs["x"]
                y = op.outputs["y"]

                if x.ndim > 2:
                    """
                    Before)
                    | x    |              | y    |
                    |------| -{softmax}-> |------|
                    | NCHW |              | NCHW |

                    After)
                    | x    |                | hx1  |              | hx2 |              | hy1 |              | hy2  |                | y    |
                    |------| -{transpose}-> |------| -{reshape}-> |-----| -{softmax}-> |-----| -{reshape}-> |------| -{transpose}-> |------|
                    | NCHW |                | NHWC |              | NC  |              | NC  |              | NHWC |                | NCHW |
                    """
                    op.remove_all()

                    # Move the softmax axis to the innermost position.
                    target_axis = op.parameters["axis"]
                    axes_nd = list(x.order.axes)
                    axes_nd.remove(target_axis)
                    axes_nd.append(target_axis)
                    order_nd = Order(axes_nd)
                    shape_nd = [x.shape_dict[axis] for axis in axes_nd]

                    # Collapse everything except the softmax axis into N.
                    order_2d = OrderNC
                    shape_2d = [x.size // x.shape_dict[target_axis], x.shape_dict[target_axis]]

                    hx1, = Transpose(None)(x)
                    hx1.change_order(order_nd)

                    hx2, = Reshape(None, in_order=hx1.order, out_order=order_2d, out_shape=shape_2d)(hx1)

                    hy1, = Softmax(None, axis=Axis.C)(hx2)

                    hy2, = Reshape(None, in_order=hy1.order, out_order=order_nd, out_shape=shape_nd)(hy1)

                    # Restore the original output order.
                    y_dummy, = Transpose(None)(hy2)
                    y_dummy.change_order(y.order)

                    y_dummy.replace(y)
                    flag_changed = True

                else:
                    flag_changed |= _replace_input(op, "x", OrderNC)
                    flag_changed |= _replace_output(op, "y", OrderNC)

        return graph, flag_changed