Example #1
def _convert_average_pooling1d(converter: KerasConverter,
                               k_op: "keras.layers.AveragePooling1D"):
    x = converter.get_variable(converter.get_input_tensor(k_op)[0])

    # FIXME: more efficient implementation
    y, = Reshape(None,
                 in_order=x.order,
                 out_order=OrderNHWC,
                 out_shape=[x.shape[0], x.shape[1], 1, x.shape[2]])(x)

    if k_op.padding == "valid":
        padding = (0, 0)

    elif k_op.padding == "same":
        padding = (k_op.pool_size[0] // 2, 0)

    else:
        raise NotImplementedError(f"Unknown padding: {k_op.padding}")

    y, = AveragePooling2D(None,
                          ksize=(k_op.pool_size[0], 1),
                          stride=(1, 1),
                          padding=padding)(y)
    z, = Reshape(None,
                 in_order=y.order,
                 out_order=OrderNTC,
                 out_shape=[y.shape[0], y.shape[1], y.shape[3]])(y)

    converter.set_variable(converter.get_output_tensor(k_op)[0], z)
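The reshape-then-pool trick above can be sanity-checked in plain NumPy: pooling windows along T is the same as 2D pooling with a (pool_size, 1) kernel once a singleton W axis is inserted. A minimal sketch with naive, illustrative pooling helpers (not WebDNN APIs), assuming stride 1 and "valid" padding:

import numpy as np

def avg_pool1d(x, k):
    # naive 1D average pooling along T; x: (N, T, C)
    N, T, C = x.shape
    return np.stack([x[:, t:t + k, :].mean(axis=1) for t in range(T - k + 1)], axis=1)

def avg_pool2d(x, kh, kw):
    # naive 2D average pooling over (H, W); x: (N, H, W, C)
    N, H, W, C = x.shape
    rows = [np.stack([x[:, i:i + kh, j:j + kw, :].mean(axis=(1, 2))
                      for j in range(W - kw + 1)], axis=1)
            for i in range(H - kh + 1)]
    return np.stack(rows, axis=1)

x = np.random.rand(2, 7, 3)                        # (N, T, C)
y = avg_pool2d(x.reshape(2, 7, 1, 3), kh=3, kw=1)  # NTC -> NHWC with W=1
assert np.allclose(y.reshape(2, 5, 3), avg_pool1d(x, k=3))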
Example #2
    def optimize_operator(self, graph: Graph, op: Reshape):
        x = op.inputs["x"]
        y = op.outputs["y"]

        if x.order == y.order and x.shape == y.shape:
            _remove_unary_operator(graph, op)
            return True

        if x.shape == y.shape:
            op.remove_all()
            y_dummy, = ReinterpretAxis(None,
                                       in_order=x.order,
                                       out_order=y.order)(x)
            y_dummy.replace(y)
            return True

        if isinstance(x, ConstantVariable) and x.output_from is None:
            _remove_unary_operator(graph, op)
            x.change_order(y.order)
            return True

        if all([
                y not in graph.outputs,
                all(x.stride_dict[axis] == y.stride_dict[axis] for axis in
                    [axis for axis in x.order.axes if axis in y.order.axes]),
                all(isinstance(op2, Elementwise) for op2 in y.input_to)
        ]):
            _remove_unary_operator(graph, op)
            return True

        return False
Example #3
def template(x_order=OrderNHWC,
             x_shape=(2, 3, 4, 5),
             y_order=OrderNHWC,
             y_shape=(1, 12, 2, 5),
             description: str = ""):
    vx = np.random.rand(*x_shape) - 0.5

    x = Variable(vx.shape, order=OrderNHWC)
    y, = Reshape(None, in_order=x_order, out_order=y_order,
                 out_shape=y_shape)(x)

    x.change_order(x_order)
    y.change_order(y_order)

    generate_kernel_test_case(
        description=f"Reshape {description}",
        graph=Graph([x], [y]),
        inputs={
            x:
            np.transpose(vx, [OrderNHWC.axes_dict[a]
                              for a in x.order.axes]).flatten()
        },
        expected={
            y:
            np.transpose(vx, [OrderNHWC.axes_dict[a]
                              for a in y.order.axes]).flatten()
        },
    )
Example #4
def _convert_global_average_pooling1d(converter: KerasConverter, k_op: keras.layers.GlobalAveragePooling1D):
    x = converter.get_variable(converter.get_input_tensor(k_op)[0])

    # FIXME: more efficient implementation
    y, = Reshape(None, in_order=OrderNTC, out_order=OrderNHWC, out_shape=[x.shape[0], x.shape[1], 1, x.shape[2]])(x)
    y, = AveragePooling2D(None, ksize=(x.shape[1], 1), stride=(1, 1), padding=(0, 0))(y)

    # flatten without changing memory layout
    z, = Reshape(None, in_order=y.order, out_order=OrderNC, out_shape=[y.shape[0], mul(y.shape[1:])])(y)
    converter.set_variable(converter.get_output_tensor(k_op)[0], z)
Example #5
def _convert_reshape(converter: KerasConverter, k_op: "keras.layers.Reshape"):
    x = converter.get_variable(converter.get_input_tensor(k_op)[0])

    target_shape = [x.shape[0]] + list(k_op.target_shape)
    if len(target_shape) == 2:
        target_order = OrderNC

    elif len(target_shape) == 3:
        target_order = OrderNTC

    elif len(target_shape) == 4:
        target_order = OrderNHWC

    else:
        raise NotImplementedError(
            f"[KerasConverter] Unknown default order: shape={target_shape}")

    console.warning(
        "[KerasConverter] keras.layers.Reshape is parsed with the default data order (OrderNC in 2D, "
        "OrderNTC in 3D, OrderNHWC in 4D). To customize this, please override the keras.layers.Reshape "
        "converter handler.")

    y, = Reshape(None,
                 in_order=x.order,
                 out_order=target_order,
                 out_shape=target_shape)(x)
    converter.set_variable(converter.get_output_tensor(k_op)[0], y)
Example #6
def convert_layer_global_average_pooling2d(
        converter: KerasConverter,
        k_op: "keras.layers.GlobalAveragePooling2D"):
    x = converter.get_variable(converter.get_input_tensor(k_op)[0])
    if k_op.data_format == "channels_first":
        assert x.order == OrderNCHW

    elif k_op.data_format == "channels_last":
        assert x.order == OrderNHWC

    else:
        raise ValueError(
            f"[KerasConverter] Unknown data format: {k_op.data_format}")

    y, = AveragePooling2D(None,
                          ksize=(x.shape_dict[Axis.H], x.shape_dict[Axis.W]),
                          stride=(1, 1),
                          padding=(0, 0))(x)

    # flatten without changing memory layout
    z, = Reshape(None,
                 in_order=y.order,
                 out_order=OrderNC,
                 out_shape=[y.shape[0], mul(y.shape[1:])])(y)
    converter.set_variable(converter.get_output_tensor(k_op)[0], z)
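The pooling-plus-reshape pair above reduces to a spatial mean; a quick NumPy check of that equivalence (shapes are illustrative):

import numpy as np

x = np.random.rand(2, 8, 8, 16)            # (N, H, W, C)
# AveragePooling2D with ksize=(H, W), stride (1, 1) and no padding is a global mean
y = x.mean(axis=(1, 2), keepdims=True)     # (N, 1, 1, C)
z = y.reshape(y.shape[0], -1)              # flatten to OrderNC without reordering memory
assert np.allclose(z, x.mean(axis=(1, 2)))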
Example #7
def _convert_linear_function(
        converter: ChainerConverter,
        c_op: "chainer.functions.connection.linear.LinearFunction"):
    x = converter.get_variable(c_op.inputs[0])
    w = converter.get_variable(c_op.inputs[1])  # type: ConstantVariable

    x2, = Reshape(None,
                  in_order=x.order,
                  out_order=OrderNC,
                  out_shape=[x.shape[0], mul(x.shape[1:])])(x)
    w2, = ReinterpretAxis(None, in_order=w.order, out_order=OrderNC)(w)
    w2, = Transpose(None)(w2)
    w2.change_order(OrderCN)

    y, = Linear(None)(x2, w2)
    y, = ReinterpretAxis(None,
                         in_order=y.order,
                         out_order=Order([x.order.axes[0],
                                          w.order.axes[0]]))(y)

    if len(c_op.inputs) == 3:
        # with bias
        b = converter.get_variable(c_op.inputs[2])
        check_broadcast_constraints(y, b)
        y = y + b

    converter.set_variable(c_op.outputs[0](), y)
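The operator sequence above amounts to flattening x to 2D and multiplying by the transposed weight (chainer stores W as (out_features, in_features)); a NumPy sketch with hypothetical shapes:

import numpy as np

x = np.random.rand(2, 4, 3, 3)       # e.g. (N, C, H, W)
w = np.random.rand(10, 4 * 3 * 3)    # (out_features, in_features)
b = np.random.rand(10)

x2 = x.reshape(x.shape[0], -1)       # Reshape to OrderNC: (N, in_features)
y = x2 @ w.T + b                     # transpose to (in, out), then Linear and bias add
assert y.shape == (2, 10)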
Example #8
def _convert_reshape(converter: ChainerConverter,
                     c_op: "chainer.functions.Reshape"):
    assert len(c_op.inputs) == 1, \
        f"For 'Reshape' operator in chainer, expected number of inputs is 1, but actual is {len(c_op.inputs)}"

    x = converter.get_variable(c_op.inputs[0])

    out_shape = list(c_op.shape)  # c_op.shape is tuple
    if len(out_shape) == 1:
        out_order = OrderC
    elif len(out_shape) == 2:
        out_order = OrderNC
    elif len(out_shape) == 4:
        out_order = OrderNCHW
    else:
        raise NotImplementedError(
            "Reshaping into a number of dimensions other than 1, 2, or 4 is not supported.")
    assert mul(out_shape) == x.size

    y, = Reshape(None,
                 in_order=x.order,
                 out_order=out_order,
                 out_shape=out_shape)(x)

    converter.set_variable(c_op.outputs[0](), y)
Example #9
def _convert_reshape(converter: ONNXConverter, onnx_op: INodeProto):
    x = converter.get_variable(onnx_op.input[0])
    if converter.opset_version >= 5:
        # The output shape is specified by onnx_op.input[1],
        # which has to be a ConstantVariable.
        # TODO: test for different operator set versions
        shape_var = converter.get_variable(onnx_op.input[1])
        assert isinstance(
            shape_var, ConstantVariable
        ), "The shape specifier of the Reshape operator has to be constant."
        out_shape = [int(d) for d in shape_var.data]
    else:
        # Reshape-1
        attrs = attribute_dict(onnx_op)
        out_shape = [
            r if s == 0 else s for r, s in zip(x.shape, attrs["shape"].ints)
        ]

    if -1 in out_shape:
        i = out_shape.index(-1)
        out_shape.remove(-1)
        out_shape.insert(i, x.size // mul(out_shape))

    out_order = Order([None] * len(out_shape))

    y, = Reshape(None,
                 in_order=x.order,
                 out_order=out_order,
                 out_shape=out_shape)(x)
    converter.set_variable(onnx_op.output[0], y)
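The shape-resolution rules here (a 0 entry keeps the corresponding input dimension, a single -1 entry is inferred from the remaining sizes) are easy to test in isolation; resolve_shape below is an illustrative helper, not part of the converter:

from functools import reduce

def resolve_shape(in_shape, spec):
    # 0 keeps the input dimension at that position; at most one -1 is inferred
    out = [r if s == 0 else s for r, s in zip(in_shape, spec)]
    if -1 in out:
        size = reduce(lambda a, b: a * b, in_shape, 1)
        i = out.index(-1)
        out.remove(-1)
        out.insert(i, size // reduce(lambda a, b: a * b, out, 1))
    return out

assert resolve_shape([2, 3, 4, 5], [0, -1, 2, 5]) == [2, 6, 2, 5]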
Example #10
def _convert_repeat_vector(converter: KerasConverter,
                           k_op: "keras.layers.RepeatVector"):
    x = converter.get_variable(converter.get_input_tensor(k_op)[0])

    assert x.order == OrderNC, f"[KerasConverter] Currently only OrderNC is supported for input variable order of " \
                               f"keras.layers.RepeatVector: x.order={x.order}"

    N = x.shape_dict[Axis.N]
    n = k_op.n
    C = x.shape_dict[Axis.C]

    # TODO: Implement more efficient version
    # ex) x.shape=(N=2, C=3), n=2
    #
    #  x(N, C)  *      w(C, n*C)     =      y(N, n*C)     =       y(N, n, C)
    # -----------------------------------------------------------------------------
    # [1, 2, 3]   [1, 0, 0, 1, 0, 0]   [1, 2, 3, 1, 2, 3]   [[1, 2, 3], [1, 2, 3]]
    # [4, 5, 6] * [0, 1, 0, 0, 1, 0] = [4, 5, 6, 4, 5, 6] = [[4, 5, 6], [4, 5, 6]]
    #             [0, 0, 1, 0, 0, 1]
    #

    w = ConstantVariable(np.tile(np.eye(C), (1, n)), OrderCN)

    y, = Linear(None)(x, w)
    y, = Reshape(None,
                 in_order=OrderNC,
                 out_order=OrderNTC,
                 out_shape=[N, n, C])(y)
    converter.set_variable(converter.get_output_tensor(k_op)[0], y)
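The tiled-identity trick in the comment above is easy to verify with NumPy, using the same N=2, n=2, C=3 as the worked example:

import numpy as np

N, n, C = 2, 2, 3
x = np.arange(1, N * C + 1).reshape(N, C)   # [[1, 2, 3], [4, 5, 6]]
w = np.tile(np.eye(C), (1, n))              # (C, n*C): n copies of the identity
y = (x @ w).reshape(N, n, C)                # each row of x repeated n times
assert np.array_equal(y, np.stack([x] * n, axis=1))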
Example #11
def template(in_order, in_shape, out_order, out_shape):
    op = Reshape(None,
                 in_order=in_order,
                 out_order=out_order,
                 out_shape=[out_shape[a] for a in out_order.axes])
    x = Variable([in_shape[a] for a in in_order.axes], in_order)
    y, = op(x)
    assert_shape(y, out_shape)
Example #12
    def optimize_operator(self, graph: Graph, op: Reshape):
        x = op.inputs["x"]
        y = op.outputs["y"]

        if x.order == y.order and x.shape == y.shape:
            # no reshape is required
            _remove_unary_operator(graph, op)
            return True

        if x.shape == y.shape:
            # only reinterpret_axis is required
            op.remove_all()
            y_dummy = x.reinterpret_axes(y.order)
            OptimizeRule.replace_variable(graph, y_dummy, y)
            return True

        return False
Example #13
def _convert_flatten(converter: KerasConverter, k_op: "keras.layers.Flatten"):
    x = converter.get_variable(converter.get_input_tensor(k_op)[0])

    # flatten without changing memory layout
    y, = Reshape(None,
                 in_order=x.order,
                 out_order=OrderNC,
                 out_shape=[x.shape[0], mul(x.shape[1:])])(x)
    converter.set_variable(converter.get_output_tensor(k_op)[0], y)
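"Without changing memory layout" means the reshape only reinterprets the existing buffer; a transpose, by contrast, would move data. A minimal NumPy illustration:

import numpy as np

x = np.random.rand(2, 3, 4, 5)               # (N, H, W, C)
y = x.reshape(x.shape[0], -1)                # (N, H*W*C): same bytes, new shape
assert np.array_equal(y.ravel(), x.ravel())  # memory order is untouched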
Example #14
    def optimize_operator(self, graph: Graph, op: Reshape):
        x = op.inputs["x"]
        y = op.outputs["y"]

        if x.order == y.order and x.shape == y.shape:
            # no reshape is required
            _remove_unary_operator(graph, op)
            return True

        if x.shape == y.shape:
            # only reinterpret_axis is required
            op.remove_all()
            y_dummy, = ReinterpretAxis(None,
                                       in_order=x.order,
                                       out_order=y.order)(x)
            y_dummy.replace(y)
            return True

        return False
Example #15
def _convert_flatten(converter: ChainerConverter,
                     c_op: "chainer.functions.Flatten"):
    x = converter.get_variable(c_op.inputs[0])
    y, = Reshape(None, in_order=x.order, out_shape=[x.size], out_order=OrderC)(x)
    converter.set_variable(c_op.outputs[0](), y)

    console.warning(
        "[ChainerConverter] In chainer.functions.Flatten, the output data order is parsed as OrderC. To "
        "customize this, please override the chainer.functions.Flatten converter handler."
    )
Example #16
def _convert_reshape(converter: ONNXConverter, onnx_op: INodeProto):
    x = converter.get_variable(onnx_op.input[0])
    attrs = attribute_dict(onnx_op)
    out_shape = [r if s == 0 else s for r, s in zip(x.shape, attrs["shape"].ints)]

    if -1 in out_shape:
        i = out_shape.index(-1)
        out_shape.remove(-1)
        out_shape.insert(i, x.size // mul(out_shape))

    out_order = Order([None] * len(out_shape))

    y, = Reshape(None, in_order=x.order, out_order=out_order, out_shape=out_shape)(x)
    converter.set_variable(onnx_op.output[0], y)
Example #17
def _convert_reshape(converter: ChainerConverter,
                     c_op: "chainer.functions.Reshape"):
    x = converter.get_variable(c_op.inputs[0])

    out_shape = c_op.shape
    # noinspection PyTypeChecker
    out_order = Order([AxisVar() for _ in out_shape])
    assert mul(
        out_shape
    ) == x.size, f"[ChainerConverter] Shape mismatch: mul(out_shape)={mul(out_shape)}, x.size={x.size}"

    y, = Reshape(None,
                 in_order=x.order,
                 out_order=out_order,
                 out_shape=out_shape)(x)

    converter.set_variable(c_op.outputs[0](), y)
Example #18
def _split_reshape(graph: Graph, op: Reshape, v: Variable,
                   v_pair: Sequence[Variable], axis: Axis):
    x = op.inputs["x"]
    y = op.outputs["y"]
    s1 = v_pair[0].shape_dict[axis]
    s2 = v_pair[1].shape_dict[axis]
    op.remove_all()
    in_order = op.in_order
    out_order = op.out_order
    x_shape = [x.shape_dict[a] for a in in_order.axes]
    y_shape = [y.shape_dict[a] for a in out_order.axes]

    if v == x:
        """
        before)

            x -{reshape}- y

        after)

            x_0 -{reshape}- t_0 -+
                                 +-{concat[axis_n]}- t -{reshape}- y
            x_1 -{reshape}- t_1 -+

        shapes and orders change as follows:

                  x.shape = [dx_0, dx_1, ..., dx_m,   ..., dx_M-1]
                x_0.shape = [dx_0, dx_1, ..., dx_m/2, ..., dx_M-1]
            ---------------------------------------------------------------------------------
                t_0.shape = [dy_0, dy_1, ..., dy_n,   ..., dy_k/2, ..., dy_N-1]
                  t.shape = [dy_0, dy_1, ..., dy_n*2, ..., dy_k/2, ..., dy_N-1]
                  y.shape = [dy_0, dy_1, ..., dy_n,   ..., dy_k,   ..., dy_N-1]

            m: split target axis

            find axis_k and axis_n which satisfy the following conditions:

                dy_n * dy_n+1 * ... * dy_N-1 == dx_m * dx_m+1 * ... * dx_M-1
                dy_k % 2 == 0
                n <= k
        """

        x_0, x_1 = v_pair
        dx_prod = mul(x_shape[in_order.axes_dict[axis]:])
        dy_prod = 1
        axis_k_candidate = []
        for axis_n in reversed(out_order.axes):
            dy_prod *= y.shape_dict[axis_n]
            if y.shape_dict[axis_n] % 2 == 0:
                axis_k_candidate.append(axis_n)

            if dx_prod == dy_prod:
                # Split the largest axis
                axis_k = (sorted(axis_k_candidate,
                                 key=lambda a: y.shape_dict[a],
                                 reverse=True))[0]

                t_0_shape = [y.shape_dict[a] for a in out_order.axes]
                t_0_shape[out_order.axes_dict[axis_k]] = t_0_shape[
                    out_order.axes_dict[axis_k]] // 2  # TODO
                t_0, = Reshape(None,
                               in_order=in_order,
                               out_order=out_order,
                               out_shape=t_0_shape)(x_0)

                t_1_shape = [y.shape_dict[a] for a in out_order.axes]
                t_1_shape[out_order.axes_dict[axis_k]] = t_1_shape[
                    out_order.axes_dict[axis_k]] // 2  # TODO
                t_1, = Reshape(None,
                               in_order=in_order,
                               out_order=out_order,
                               out_shape=t_1_shape)(x_1)

                t, = Concat(None, axis=axis_n)(t_0, t_1)
                y_new, = Reshape(None,
                                 in_order=out_order,
                                 out_order=out_order,
                                 out_shape=y_shape)(t)

                OptimizeRule.replace_variable(graph, y_new.transpose_like(y),
                                              y)
                break

            elif dy_prod > (s1 + s2) * dx_prod:
                raise NotImplementedError(
                    f"Variable is too large to handle in WebGL backend: {v}")

    elif v == y:
        """
        The algorithm is almost the same as in the case `v == x` (above).

        before)

            x -{reshape}- y

        after)

                                    +- t_0 -{reshape}- y_0
            x -{reshape}- t-{split}-+
                                    +- t_1 -{reshape}- y_1

        shapes and orders change as follows:

                  x.shape = [dx_0, dx_1, ..., dx_m,   ..., dx_k,   ..., dx_M-1]
                  t.shape = [dx_0, dx_1, ..., dx_m*2, ..., dx_k/2, ..., dx_M-1]
                t_0.shape = [dx_0, dx_1, ..., dx_m,   ..., dx_k/2, ..., dx_M-1]
            ---------------------------------------------------------------------------------
                y_0.shape = [dy_0, dy_1, ..., dy_n/2, ..., dy_N-1]
                  y.shape = [dy_0, dy_1, ..., dy_n,   ..., dy_N-1]

            m: split target axis

            find axis_k and axis_m which satisfy the following conditions:

                dx_m * dx_m+1 * ... * dx_M-1 == dy_n * dy_n+1 * ... * dy_N-1
                dx_k % 2 == 0
                m <= k
        """

        y_0, y_1 = v_pair
        dx_prod = 1
        dy_prod = mul(y_shape[out_order.axes_dict[axis]:])
        axis_k_candidate = []
        for axis_m in reversed(in_order.axes):
            dx_prod *= x.shape_dict[axis_m]
            if x.shape_dict[axis_m] % 2 == 0:
                axis_k_candidate.append(axis_m)

            if dx_prod == dy_prod:
                # Split the largest axis
                axis_k = (sorted(axis_k_candidate,
                                 key=lambda a: x.shape_dict[a],
                                 reverse=True))[0]

                t_shape = [x.shape_dict[a] for a in in_order.axes]
                t_shape[in_order.axes_dict[axis_m]] = 2 * t_shape[
                    in_order.axes_dict[axis_m]]  # TODO
                t_shape[in_order.axes_dict[axis_k]] = t_shape[
                    in_order.axes_dict[axis_k]] // 2  # TODO
                t, = Reshape(None,
                             in_order=in_order,
                             out_order=in_order,
                             out_shape=t_shape)(x)

                t_0, t_1 = SplitAxis(None,
                                     axis=axis_m,
                                     sections=[t.shape_dict[axis_m] // 2
                                               ])(t)  # TODO

                y_0_new, = Reshape(None,
                                   in_order=in_order,
                                   out_order=out_order,
                                   out_shape=y_0.shape)(t_0)
                y_1_new, = Reshape(None,
                                   in_order=in_order,
                                   out_order=out_order,
                                   out_shape=y_1.shape)(t_1)

                OptimizeRule.replace_variable(graph, y_0_new.reshape_like(y_0),
                                              y_0)
                OptimizeRule.replace_variable(graph, y_1_new.reshape_like(y_1),
                                              y_1)
                break

            elif dx_prod > dy_prod:
                raise NotImplementedError(
                    f"Variable is too large to handle in WebGL backend: {v}")

    else:
        raise UnexpectedAndPleaseReportError
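The `v == x` branch above claims that reshaping each half and concatenating along axis_n reproduces the original reshape. A NumPy sketch of that equivalence with hand-picked shapes (NumPy's row-major reshape stands in for WebDNN's Reshape):

import numpy as np

# x -{reshape}- y, with x split in half along its first axis (m = 0)
x = np.arange(24).reshape(2, 4, 3)
y = x.reshape(2, 2, 6)

x_0, x_1 = x[:1], x[1:]
# dy_prod matches dx_prod at axis_n = 0; the largest even axis of y is axis_k = 2
# (size 6), so each half is reshaped to y's shape with axis_k halved: (2, 2, 3)
t_0 = x_0.reshape(2, 2, 3)
t_1 = x_1.reshape(2, 2, 3)
t = np.concatenate([t_0, t_1], axis=0)        # concat along axis_n
assert np.array_equal(t.reshape(2, 2, 6), y)  # final reshape reproduces y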
Example #19
    def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
        flag_changed = False
        for op in traverse.listup_operators(graph):
            if isinstance(op, Transpose):
                x = op.inputs["x0"]
                y = op.outputs["y"]
                if x.order == y.order:
                    op.remove_all()
                    x.replace(y)
                    flag_changed = True

                elif all(isinstance(op2, (Elementwise, SplitAxis)) for op2 in y.input_to):
                    op.remove_all()
                    for op2 in list(y.input_to):
                        name = op2._get_input_name(y)
                        op2.remove_input(y)
                        op2.append_input(name, x)
                    flag_changed = True

            elif isinstance(op, Reshape):
                flag_changed |= _replace_input(op, "x", op.parameters["in_order"])
                flag_changed |= _replace_output(op, "y", op.parameters["out_order"])

            elif isinstance(op, (Convolution2D, MaxPooling2D, AveragePooling2D, Deconvolution2D)):
                flag_changed |= _replace_input(op, "x", OrderNHWC)
                flag_changed |= _replace_output(op, "y", OrderNHWC)

            elif isinstance(op, Softmax):
                x = op.inputs["x"]
                y = op.outputs["y"]

                if x.ndim > 2:
                    """
                    Before)
                    | x    |              | y    |
                    |------| -{softmax}-> |------|
                    | NCHW |              | NCHW |

                    After)
                    | x    |                | hx1  |              | hx2 |              | hy1 |              | hy2  |                | y    |
                    |------| -{transpose}-> |------| -{reshape}-> |-----| -{softmax}-> |-----| -{reshape}-> |------| -{transpose}-> |------|
                    | NCHW |                | NHWC |              | NC  |              | NC  |              | NHWC |                | NCHW |
                    """
                    op.remove_all()

                    target_axis = op.parameters["axis"]
                    axes_nd = list(x.order.axes)
                    axes_nd.remove(target_axis)
                    axes_nd.append(target_axis)
                    order_nd = Order(axes_nd)
                    shape_nd = [x.shape_dict[axis] for axis in axes_nd]

                    order_2d = OrderNC
                    shape_2d = [x.size // x.shape_dict[target_axis], x.shape_dict[target_axis]]

                    hx1, = Transpose(None)(x)
                    hx1.change_order(order_nd)

                    hx2, = Reshape(None, in_order=hx1.order, out_order=order_2d, out_shape=shape_2d)(hx1)

                    hy1, = Softmax(None, axis=Axis.C)(hx2)

                    hy2, = Reshape(None, in_order=hy1.order, out_order=order_nd, out_shape=shape_nd)(hy1)

                    y_dummy, = Transpose(None)(hy2)
                    y_dummy.change_order(y.order)

                    y_dummy.replace(y)
                    flag_changed = True

                else:
                    flag_changed |= _replace_input(op, "x", OrderNC)
                    flag_changed |= _replace_output(op, "y", OrderNC)

        return graph, flag_changed
Example #20
def _split_sgemm(graph: Graph, op: Sgemm, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    s1 = v_pair[0].shape_dict[axis]
    s2 = v_pair[1].shape_dict[axis]
    A = op.inputs["A"]
    B = op.inputs["B"]
    C = op.outputs["C"]
    transpose_A, transpose_B = op.transpose_A, op.transpose_B
    M, K, N = op.M, op.K, op.N
    axis_M, axis_K, axis_N = Axis(None), Axis(None), Axis(None)

    op.remove_all()

    def decompose_logical_axes(logical_shape: Tuple[int, int], v: Variable):
        """
        Decompose logical axes into real axes

        Examples::

            A.order, A.shape
            >>> "NCHW", (1, 128, 8, 8)

            M = 128
            K = 64
            decompose_logical_axes([M, K], A)
            >>> ["<Axis N>", "<Axis C>"], ["<Axis H>", "<Axis W>"]
        """
        total_size = 1
        axes1 = []  # type: List[Axis]
        axes2 = list(v.order.axes)  # type: List[Axis]
        for size, a in zip(v.shape, v.order.axes):
            if total_size == logical_shape[0]:
                return axes1, axes2

            elif total_size > logical_shape[0]:
                raise ValueError

            axes1.append(a)
            axes2.remove(a)
            total_size *= size

    if v == A:
        A1, A2 = v_pair
        if transpose_A:  # A.shape = [M, K]
            axes_M, axes_K = decompose_logical_axes((M, K), A)

        else:  # A.shape = [K, M]
            axes_K, axes_M = decompose_logical_axes((K, M), A)

        if axis in axes_K:
            """
            before)

                A -{sgemm}- C

            after) In case `axis` is in `K`,

                A_0 -{sgemm}- C_0 -+
                                   +-{Add}- C
                A_1 -{sgemm}- C_1 -+
            """
            K1, K2 = K * s1 // (s1 + s2), K * s2 // (s1 + s2)

            # Factorize B's axes included in K into A's corresponding axes
            if transpose_B:  # B: [k_b1, k_b2, ..., N] -{reshape}-> [k_a1, k_a2, ..., N]
                B, = Reshape(None,
                             in_order=B.order,
                             out_order=Order(axes_K + [axis_N]),
                             out_shape=[A.shape_dict[a] for a in axes_K] + [N])(B)
            else:  # B: [N, k_b1, k_b2, ...] -{reshape}-> [N, k_a1, k_a2, ...]
                B, = Reshape(None,
                             in_order=B.order,
                             out_order=Order([axis_N] + axes_K),
                             out_shape=[N] + [A.shape_dict[a] for a in axes_K])(B)

            B1, B2 = SplitAxis(None, axis=axis, sections=[s1])(B)

            C1, = Sgemm(None, M=M, K=K1, N=N,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"],
                        out_order=op.parameters["out_order"])(A1, B1)

            C2, = Sgemm(None, M=M, K=K2, N=N,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"],
                        out_order=op.parameters["out_order"])(A2, B2)

            OptimizeRule.replace_variable(graph, C1 + C2, C)

        else:
            assert axis in axes_M
            """
            before)

                A -{sgemm}- C

            after) In case `axis` is in `M`,

                A_0 -{sgemm}- C_0 -+
                                   +-{Concat}- C
                A_1 -{sgemm}- C_1 -+
            """
            M1, M2 = M * s1 // (s1 + s2), M * s2 // (s1 + s2)

            c_tmp_order = Order(axes_M + [axis_N])
            c1_shape = [A1.shape_dict[a] for a in axes_M] + [N]
            c2_shape = [A2.shape_dict[a] for a in axes_M] + [N]

            C1, = Sgemm(None, M=M1, K=K, N=N,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=c1_shape,
                        out_order=c_tmp_order)(A1, B)

            C2, = Sgemm(None, M=M2, K=K, N=N,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=c2_shape,
                        out_order=c_tmp_order)(A2, B)

            C_new, = Concat(None, axis=axis)(C1, C2)
            C_new, = Reshape(None, in_order=c_tmp_order, out_order=C.order, out_shape=C.shape)(C_new)
            OptimizeRule.replace_variable(graph, C_new, C)

    elif v == B:
        B1, B2 = v_pair
        if transpose_B:  # B.shape = [K, N]
            axes_K, axes_N = decompose_logical_axes((K, N), B)

        else:  # B.shape = [N, K]
            axes_N, axes_K = decompose_logical_axes((N, K), B)

        if axis in axes_K:
            """
            before)
    
                B -{sgemm}- C
    
            after) In case `axis` is in `K`,
    
                B_0 -{sgemm}- C_0 -+
                                   +-{Add}- C
                B_1 -{sgemm}- C_1 -+
            """
            K1, K2 = K * s1 // (s1 + s2), K * s2 // (s1 + s2)

            # Factorize A's axes included in K into B's corresponding axes
            if transpose_A:  # A: [M, k_a1, k_a2, k_a3, ...] -{reshape}-> [M, k_b1, k_b2, ...]
                A, = Reshape(None,
                             in_order=A.order,
                             out_order=Order([axis_M] + axes_K),
                             out_shape=[M] + [B.shape_dict[a] for a in axes_K])(A)
            else:  # A: [k_a1, k_a2, k_a3, ..., M] -{reshape}-> [k_b1, k_b2, ..., M]
                A, = Reshape(None,
                             in_order=A.order,
                             out_order=Order(axes_K + [axis_M]),
                             out_shape=[B.shape_dict[a] for a in axes_K] + [M])(A)

            A1, A2 = SplitAxis(None, axis=axis, sections=[s1])(A)

            C1, = Sgemm(None, M=M, K=K1, N=N,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"],
                        out_order=op.parameters["out_order"])(A1, B1)

            C2, = Sgemm(None, M=M, K=K2, N=N,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"],
                        out_order=op.parameters["out_order"])(A2, B2)

            OptimizeRule.replace_variable(graph, C1 + C2, C)

        else:
            assert axis in axes_N
            """
            before)
    
                C[M, N] = A[M, K] @ B[K, N]
    
            after) In case `axis` is in `N`,
    
                C[M, N] = Concat(C1[M, N1], C2[M, N2])
                        = Concat(A[M, K] @ B1[K, N1], A[M, K] @ B2[K, N2]) 
            """
            N1, N2 = N * s1 // (s1 + s2), N * s2 // (s1 + s2)

            c_tmp_order = Order([axis_M] + axes_N)
            c1_shape = [M] + [B1.shape_dict[a] for a in axes_N]
            c2_shape = [M] + [B2.shape_dict[a] for a in axes_N]

            C1, = Sgemm(None, M=M, K=K, N=N1,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=c1_shape,
                        out_order=c_tmp_order)(A, B1)
            # C1.shape = [M, B.shape_dict[n1], B.shape_dict[n2], ..., B1.shape_dict[axis], ...]
            # C1.order = [axis_M, n1, n2, ..., axis, ...]

            C2, = Sgemm(None, M=M, K=K, N=N2,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=c2_shape,
                        out_order=c_tmp_order)(A, B2)

            C_new, = Concat(None, axis=axis)(C1, C2)
            # C_new.shape = [M, B.shape_dict[n1], B.shape_dict[n2], ..., B1.shape_dict[axis]+B2.shape_dict[axis], ...]
            # C_new.order = [axis_M, n1, n2, ..., axis, ...]

            C_new, = Reshape(None, in_order=c_tmp_order, out_order=C.order, out_shape=C.shape)(C_new)
            OptimizeRule.replace_variable(graph, C_new, C)

    elif v == C:
        """
        before)

            C[M, N] = A[M, K] @ B[K, N]

        after) In case `axis` is in `N`,

            C1[M, N1] = A[M, K] @ B1[K, N1]
            C2[M, N2] = A[M, K] @ B2[K, N2]
        """
        raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

    else:
        raise UnexpectedAndPleaseReportError
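Both rewrites rest on standard block-matrix identities: splitting along K sums the partial products, splitting along M or N concatenates them. A quick NumPy check:

import numpy as np

M, K, N = 4, 6, 5
A = np.random.rand(M, K)
B = np.random.rand(K, N)
K1, N1 = 2, 3

# split along K: partial products are summed (the {Add} case)
assert np.allclose(A @ B, A[:, :K1] @ B[:K1] + A[:, K1:] @ B[K1:])

# split along N: partial products are concatenated (the {Concat} case)
assert np.allclose(A @ B, np.concatenate([A @ B[:, :N1], A @ B[:, N1:]], axis=1))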
Example #21
    def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
        flag_changed = False
        for op in traverse.listup_operators(graph):
            if isinstance(op, Reshape):
                flag_changed |= _replace_input(op, "x",
                                               op.parameters["in_order"])
                flag_changed |= _replace_output(op, "y",
                                                op.parameters["out_order"])
                continue

            elif isinstance(op, (Convolution2D, MaxPooling2D, AveragePooling2D,
                                 Deconvolution2D, Space2Depth, Depth2Space)):
                flag_changed |= _replace_input(op, "x", OrderNHWC)
                flag_changed |= _replace_output(op, "y", OrderNHWC)
                continue

            elif isinstance(op, Softmax):
                x = op.inputs["x"]
                y = op.outputs["y"]
                target_axis = op.parameters["axis"]

                if not (x.ndim == 2
                        and x.order.axes_dict[target_axis] == x.ndim - 1):
                    """
                    Before)
                    | x   |              | y   |
                    |-----| -{softmax}-> |-----|
                    | XYZ |   axis=Y     | XYZ |
                    
                    After)
                    | x   |                | hx1 |              | hx2 |              | hy1 |              | hy2 |                | y   |
                    |-----| -{transpose}-> |-----| -{reshape}-> |-----| -{softmax}-> |-----| -{reshape}-> |-----| -{transpose}-> |-----|
                    | XYZ |                | XZY |              | NC  |   axis=C     | NC  |              | XZY |                | XYZ |
                                              :                    :
                                        order_nd = XZY       order_2d = NC
                    """
                    op.remove_all()

                    axes_nd = list(x.order.axes)
                    axes_nd.remove(target_axis)
                    axes_nd.append(target_axis)
                    order_nd = Order(axes_nd)
                    shape_nd = tuple([x.shape_dict[axis] for axis in axes_nd])

                    order_2d = OrderNC
                    shape_2d = tuple([
                        x.size // x.shape_dict[target_axis],
                        x.shape_dict[target_axis]
                    ])

                    if x.order == order_nd:
                        hx1 = x

                    else:
                        hx1, = Transpose(None)(x)
                        hx1.change_order(order_nd)
                        flag_changed = True

                    if hx1.order == order_2d and hx1.shape == shape_2d:
                        hx2 = hx1

                    else:
                        hx2, = Reshape(None,
                                       in_order=hx1.order,
                                       out_order=order_2d,
                                       out_shape=shape_2d)(hx1)
                        flag_changed = True

                    hy1, = Softmax(None, axis=Axis.C)(hx2)

                    if hy1.order == order_nd and hy1.shape == shape_nd:
                        hy2 = hy1

                    else:
                        hy2, = Reshape(None,
                                       in_order=hy1.order,
                                       out_order=order_nd,
                                       out_shape=shape_nd)(hy1)
                        flag_changed = True

                    if hy2.order == y.order:
                        y_dummy = hy2

                    else:
                        y_dummy, = Transpose(None)(hy2)
                        y_dummy.change_order(y.order)
                        flag_changed = True

                    y_dummy.replace(y)

                    continue

        return graph, flag_changed
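The transpose/reshape sandwich around Softmax is the usual reduction of an arbitrary-axis softmax to a row-wise 2D softmax; a NumPy sketch of the diagram above, with an illustrative softmax helper:

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

x = np.random.rand(2, 3, 4)         # order XYZ, target_axis = Y
hx1 = x.transpose(0, 2, 1)          # XYZ -> XZY: move the target axis last
hx2 = hx1.reshape(-1, 3)            # NC, with C = size of the target axis
hy2 = softmax(hx2, axis=1).reshape(2, 4, 3)   # softmax rows, undo the reshape
y = hy2.transpose(0, 2, 1)                    # undo the transpose
assert np.allclose(y, softmax(x, axis=1))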
Example #22
def _split_reshape(graph: Graph, op: Reshape, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    x = op.inputs["x"]
    y = op.outputs["y"]
    s1 = v_pair[0].shape_dict[axis]
    s2 = v_pair[1].shape_dict[axis]
    op.remove_all()

    if v == x:
        """
        Regard x's order as `[D1, D2]` and its shape as `[d1x, d2x]`, where the outermost axis in D2 is the split target axis.
        If y's shape can be expressed as `[d1x, d2x]` by merging some adjacent axes in y, the split can be performed.

        before)

            x -{reshape}- y

        after)

            x_0 -{reshape}- y_0 -+
                                 +-{concat[axis]}- y
            x_1 -{reshape}- y_1 -+
        """

        x_0, x_1 = v_pair
        d2x = mul(x.shape[x.order.axes_dict[axis]:])
        d2y = 1
        for axis_y in reversed(y.order.axes):
            d2y *= y.shape_dict[axis_y]

            if d2y == d2x:
                y_0_shape = [y.shape_dict[axis_y] * s1 // (s1 + s2) if a == axis_y else y.shape_dict[a] for a in y.order.axes]
                y_1_shape = [y.shape_dict[axis_y] * s2 // (s1 + s2) if a == axis_y else y.shape_dict[a] for a in y.order.axes]

                y_0 = x_0.reshape(y_0_shape, y.order)
                y_1 = x_1.reshape(y_1_shape, y.order)

                y_new, = Concat(None, axis=axis_y)(y_0, y_1)
                OptimizeRule.replace_variable(graph, y_new, y)
                break

            elif d2y > (s1 + s2) * d2x:
                raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

    elif v == y:
        """
        Same algorithm as in the case `v == x` (above).

        before)

            x -{reshape}- y

        after)

                       +- x_0 -{reshape}- y_0
            x -{split}-+
                       +- x_1 -{reshape}- y_1
        """

        y_0, y_1 = v_pair
        d2y = mul(y.shape[y.order.axes_dict[axis]:])
        d2x = 1
        for axis_x in reversed(x.order.axes):
            d2x *= x.shape_dict[axis_x]

            if d2x == d2y:
                x_0, x_1 = SplitAxis(None, axis=axis_x, sections=[x.shape_dict[axis_x] * s1 // (s1 + s2)])(x)

                OptimizeRule.replace_variable(graph, x_0.reshape_like(y_0), y_0)
                OptimizeRule.replace_variable(graph, x_1.reshape_like(y_1), y_1)
                break

            elif d2y > (s1 + s2) * d2x:
                raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

    else:
        raise UnexpectedAndPleaseReportError