Exemplo n.º 1
0
def test_mix_order():
    """Kernel test: SplitAxis whose outputs are re-ordered after the split.

    Four random (2, 3, 4, 5) arrays are concatenated along dimension 1, the
    resulting NHWC variable is split along H, and three of the four outputs
    are switched to other orders. The expected arrays are permuted so that
    their dimensions follow the order assigned to the matching output.
    """
    vx1, vx2, vx3, vx4 = (np.random.rand(2, 3, 4, 5) for _ in range(4))
    vy = np.concatenate([vx1, vx2, vx3, vx4], 1)

    y = Variable(vy.shape, order=OrderNHWC)
    x1, x2, x3, x4 = SplitAxis(None, axis=Axis.H, sections=[3, 6, 9])(y)

    # x1 keeps the NHWC order; the other outputs get a new order each.
    x2.change_order(OrderCNHW)
    x3.change_order(OrderCHWN)
    x4.change_order(OrderNCHW)

    # Permute the reference data from NHWC into the order chosen above.
    vx2 = np.transpose(vx2, (3, 0, 1, 2))  # NHWC -> CNHW
    vx3 = np.transpose(vx3, (3, 1, 2, 0))  # NHWC -> CHWN
    vx4 = np.transpose(vx4, (0, 3, 1, 2))  # NHWC -> NCHW

    expected = {
        x1: vx1,
        x2: vx2,
        x3: vx3,
        x4: vx4,
    }
    generate_kernel_test_case(description=f"SplitAxis with mix order",
                              graph=Graph([y], [x1, x2, x3, x4]),
                              inputs={y: vy},
                              expected=expected)
Exemplo n.º 2
0
def _convert_split_axis(converter: ChainerConverter, c_op: "chainer.functions.SplitAxis"):
    """Convert a Chainer ``SplitAxis`` function into a ``SplitAxis`` operator.

    Args:
        converter: converter used to look up the input variable and to
            register the converted outputs.
        c_op: the Chainer function node being converted.

    Raises:
        NotImplementedError: if ``c_op.indices_or_sections`` is an ``int``.
            In Chainer an integer means "number of equal sections"; only an
            explicit sequence of split indices is handled here.
    """
    x = converter.get_variable(c_op.inputs[0])

    if isinstance(c_op.indices_or_sections, int):
        # An int requests N equal sections (not split indices), so the message
        # must name "sections" as the unsupported form.
        raise NotImplementedError("[ChainerConverter] SplitAxis with sections are not supported.")

    # `c_op.axis` is a positional index; translate it to the Axis object of x's order.
    ys = SplitAxis(None, sections=c_op.indices_or_sections, axis=x.order.axes[c_op.axis])(x)
    for i, y in enumerate(ys):
        # Each entry of `c_op.outputs` is called to obtain the object used as the registration key.
        converter.set_variable(c_op.outputs[i](), y)
Exemplo n.º 3
0
def _split_tensorwise(graph: Graph, op: Operator, v: Variable,
                      v_pair: Sequence[Variable], axis: Axis):
    """Replace `op` by two copies, each working on one part of the split variable.

    `v` has been split into `v_pair` along `axis`. Every other input that has
    `axis` is split at the same position; inputs without `axis` are shared by
    both copies. Every other output that has `axis` is replaced by the
    concatenation of two new variables produced by the copies.

    Raises:
        UnexpectedAndPleaseReportError: if an output other than `v` does not
            have `axis` in its order.
    """
    size_0 = v_pair[0].shape_dict[axis]
    size_1 = v_pair[1].shape_dict[axis]
    inputs = dict(op.inputs)
    outputs = dict(op.outputs)
    op.remove_all()

    op_0 = op.copy()
    op_1 = op.copy()

    for name, src in inputs.items():
        if src == v:
            src_0, src_1 = v_pair

        elif axis in src.order.axes:
            src_0, src_1 = SplitAxis(None, axis=axis, sections=[size_0])(src)

        else:
            # no split happens for this input; both copies read it as is
            src_0 = src_1 = src

        op_0.append_input(name, src_0)
        op_1.append_input(name, src_1)

    for name, dst in outputs.items():
        if dst == v:
            dst_0, dst_1 = v_pair

        elif axis in dst.order.axes:
            # TODO (Kiikurage)
            # Attributes attached to `dst` are copied to neither `dst_0` nor `dst_1`
            dst_axes = dst.order.axes
            shape_0 = [size_0 if a == axis else dst.shape_dict[a] for a in dst_axes]
            dst_0 = Variable(shape_0, dst.order)
            shape_1 = [size_1 if a == axis else dst.shape_dict[a] for a in dst_axes]
            dst_1 = Variable(shape_1, dst.order)
            merged, = Concat(None, axis=axis)(dst_0, dst_1)
            OptimizeRule.replace_variable(graph, dst, merged)

        else:
            raise UnexpectedAndPleaseReportError

        op_0.append_output(name, dst_0)
        op_1.append_output(name, dst_1)
Exemplo n.º 4
0
    def optimize(self, graph):
        """Merge directly chained SplitAxis operators that split along the same axis.

        For each `SplitAxis -> Variable -> SplitAxis` chain found in `graph`,
        both operators are removed and a single SplitAxis is applied to the
        first operator's input, producing all leaf outputs at once.

        Returns:
            The tuple ``(graph, flag_changed)``; ``flag_changed`` is True when
            at least one pair of operators was merged.
        """
        flag_changed = False
        matches = traverse.search_sub_structure(
            graph, [SplitAxis, Variable, SplitAxis])

        while matches:
            op1, h, op2 = matches.pop()  # type: SplitAxis, Variable, SplitAxis

            if len(h.input_to) > 1:
                # `h` will be removed by this optimization, so it must have no other consumer
                continue

            if op1.axis != op2.axis:
                # These operations cannot be merged.
                continue

            flag_changed = True
            x = op1.inputs["x"]

            hs = [op1.outputs[f"y{i}"] for i in range(len(op1.outputs))]
            i_h = hs.index(h)

            original_ys = list(hs)
            # Work on a copy: inserting into `op1.sections` directly would
            # mutate the operator's own list (and any list sharing that object).
            new_sections = list(op1.sections)

            original_ys.remove(h)
            # Position along the axis where `h` begins inside `x`.
            section_offset = ([0] + new_sections)[i_h]
            op2_sections = [0] + op2.sections
            for i in range(len(op2.outputs)):
                original_ys.insert(i_h + i, op2.outputs[f"y{i}"])
                new_sections.insert(i_h + i, section_offset + op2_sections[i])

            # The start of `h` was inserted once more by the loop above (i == 0);
            # drop one occurrence so every boundary is listed a single time.
            new_sections.remove(section_offset)

            op1.remove_all()
            op2.remove_all()

            new_ys = SplitAxis(None, axis=op1.axis, sections=new_sections)(x)

            for original_y, new_y in zip(original_ys, new_ys):
                new_y.change_order(original_y.order)
                new_y.replace(original_y)

            # The graph changed, so earlier matches may be stale; search again.
            matches = traverse.search_sub_structure(
                graph, [SplitAxis, Variable, SplitAxis])

        return graph, flag_changed
Exemplo n.º 5
0
def _convert_split(converter: ONNXConverter, onnx_op: INodeProto):
    """Convert an ONNX ``Split`` node into a ``SplitAxis`` operator.

    Args:
        converter: converter used to look up the input variable and to
            register the converted outputs.
        onnx_op: the ONNX node being converted.

    Raises:
        NotImplementedError: if the node has no ``split`` attribute.
    """
    x = converter.get_variable(onnx_op.input[0])

    attrs = attribute_dict(onnx_op)

    # The ONNX specification gives `axis` a default of 0, so the attribute may
    # legitimately be missing from the node.
    axis_index = attrs["axis"].i if "axis" in attrs else 0
    axis = x.order.axes[axis_index]

    if "split" not in attrs:
        raise NotImplementedError(
            "[ONNXConverter] Operator \"Split\" without \"split\" parameter is not supported yet."
        )
    split = attrs["split"].ints
    # `split` holds the length of each piece; SplitAxis is given the cumulative
    # boundaries between pieces, so the final total is dropped.
    sections = np.cumsum(split).tolist()[:-1]

    ys = SplitAxis(None, axis=axis, sections=sections)(x)
    for i, y in enumerate(ys):
        converter.set_variable(onnx_op.output[i], y)
Exemplo n.º 6
0
def test_middle_axis():
    """Kernel test: SplitAxis along H of an NHWC variable.

    Four random (2, 3, 4, 5) arrays are concatenated along dimension 1 and the
    split is expected to give the four arrays back.
    """
    vx1, vx2, vx3, vx4 = (np.random.rand(2, 3, 4, 5) for _ in range(4))
    vy = np.concatenate([vx1, vx2, vx3, vx4], 1)

    y = Variable(vy.shape, order=OrderNHWC)
    split_op = SplitAxis(None, axis=Axis.H, sections=[3, 6, 9])
    x1, x2, x3, x4 = split_op(y)

    expected = {
        x1: vx1,
        x2: vx2,
        x3: vx3,
        x4: vx4,
    }
    generate_kernel_test_case(description=f"SplitAxis in middle axis",
                              graph=Graph([y], [x1, x2, x3, x4]),
                              inputs={y: vy},
                              expected=expected)
Exemplo n.º 7
0
def test_2d():
    """Kernel test: SplitAxis along N of a 2D (NC) variable.

    Four random (2, 3) arrays are concatenated along dimension 0 and the split
    is expected to give the four arrays back.
    """
    vx1, vx2, vx3, vx4 = (np.random.rand(2, 3) for _ in range(4))
    vy = np.concatenate([vx1, vx2, vx3, vx4], 0)

    y = Variable(vy.shape, order=OrderNC)
    split_op = SplitAxis(None, axis=Axis.N, sections=[2, 4, 6])
    x1, x2, x3, x4 = split_op(y)

    expected = {
        x1: vx1,
        x2: vx2,
        x3: vx3,
        x4: vx4,
    }
    generate_kernel_test_case(description=f"SplitAxis 2D",
                              graph=Graph([y], [x1, x2, x3, x4]),
                              inputs={y: vy},
                              expected=expected)
Exemplo n.º 8
0
def _split_tensorwise(graph: Graph, op: Operator, v: Variable,
                      v_pair: Sequence[Variable], axis: Axis):
    """Replace `op` by two copies, each working on one part of the split variable.

    `v` has been split into `v_pair` along `axis`. Other inputs are split at
    the same position unless they lack `axis` or have size 1 along it, in which
    case both copies share them. After both copies are executed, their outputs
    replace the parts of `v` (when `v` is an output) or are concatenated along
    `axis` to replace the original output.
    """
    first_size = v_pair[0].shape_dict[axis]
    inputs = dict(op.inputs)
    outputs = dict(op.outputs)
    op.remove_all()

    op_0 = op.copy()
    op_1 = op.copy()

    for name, src in inputs.items():
        if src == v:
            src_0, src_1 = v_pair

        elif axis in src.order.axes and not (src.shape_dict[axis] == 1):
            src_0, src_1 = SplitAxis(None, axis=axis, sections=[first_size])(src)

        else:
            # broadcasting: the same variable feeds both copies
            src_0 = src_1 = src

        op_0.append_input(name, src_0)
        op_1.append_input(name, src_1)

    op_0.exec()
    op_1.exec()

    for name, dst in outputs.items():
        if dst == v:
            part_0 = op_0.outputs[name].transpose_like(v_pair[0])
            OptimizeRule.replace_variable(graph, part_0, v_pair[0])
            part_1 = op_1.outputs[name].transpose_like(v_pair[1])
            OptimizeRule.replace_variable(graph, part_1, v_pair[1])

        else:
            # Rebuild the full output from the two partial results.
            merged, = Concat(None, axis=axis)(op_0.outputs[name],
                                              op_1.outputs[name])
            OptimizeRule.replace_variable(graph, merged.transpose_like(dst), dst)
Exemplo n.º 9
0
def test_minor_axis():
    """Kernel test: SplitAxis along C, the last axis of an NHWC variable.

    Four random (2, 3, 4, 5) arrays are concatenated along dimension 3 and the
    split is expected to give the four arrays back. Only the backends listed
    in the call are exercised.
    """
    vx1, vx2, vx3, vx4 = (np.random.rand(2, 3, 4, 5) for _ in range(4))
    vy = np.concatenate([vx1, vx2, vx3, vx4], 3)

    y = Variable(vy.shape, order=OrderNHWC)
    split_op = SplitAxis(None, axis=Axis.C, sections=[5, 10, 15])
    x1, x2, x3, x4 = split_op(y)

    expected = {
        x1: vx1,
        x2: vx2,
        x3: vx3,
        x4: vx4,
    }
    generate_kernel_test_case(
        description=f"SplitAxis in minor axis",
        backend=["webgpu", "webassembly", "fallback"],
        graph=Graph([y], [x1, x2, x3, x4]),
        inputs={y: vy},
        expected=expected,
    )
Exemplo n.º 10
0
def _convert_split_axis(converter: ChainerConverter,
                        c_op: "chainer.functions.SplitAxis"):
    """Convert a Chainer ``SplitAxis`` function into a ``SplitAxis`` operator.

    The split indices are read from ``c_op.indices`` when the Chainer major
    version is 4 or later, and from ``c_op.indices_or_sections`` otherwise.

    Raises:
        NotImplementedError: if no explicit indices are available
            (``indices`` is None, or ``indices_or_sections`` is an int).
    """
    x = converter.get_variable(c_op.inputs[0])

    major, _minor, _patch = semver(chainer.__version__)
    if major >= 4:
        # Internal data structure changed
        # https://github.com/chainer/chainer/commit/906a8e9b0837cd9a4e5ee6f1dbda26431a1e12d1#diff-9e610d281c820d44c4a0cbf0ca6263fd
        indices = c_op.indices
        if indices is None:
            raise NotImplementedError(
                "[ChainerConverter] SplitAxis with sections are not supported."
            )
    else:
        indices = c_op.indices_or_sections
        if isinstance(indices, int):
            raise NotImplementedError(
                "[ChainerConverter] SplitAxis with sections are not supported."
            )

    # `c_op.axis` is a positional index; translate it to the Axis object of x's order.
    split_axis = x.order.axes[c_op.axis]
    outputs = SplitAxis(None, sections=indices, axis=split_axis)(x)
    for position, out_var in enumerate(outputs):
        converter.set_variable(c_op.outputs[position](), out_var)
Exemplo n.º 11
0
def _split_tensordot(graph: Graph, op: Tensordot, v: Variable,
                     v_pair: Sequence[Variable], axis: Axis):
    """Rebuild a Tensordot operator after one of its inputs was split along `axis`.

    `v` (input "A" or "B" of `op`) has been split into `v_pair`. `op` is
    removed and replaced by two Tensordot operators:

    - if `axis` is one of the reduced axes, the other input is split at the
      proportional position and the two partial products are added;
    - otherwise the two partial products are concatenated along `axis`.

    Raises:
        NotImplementedError: if `v` is the output "C" (not supported).
        UnexpectedAndPleaseReportError: if `v` is neither "A", "B" nor "C".
    """
    s1 = v_pair[0].shape_dict[axis]
    s2 = v_pair[1].shape_dict[axis]
    A = op.inputs["A"]
    B = op.inputs["B"]
    C = op.outputs["C"]
    # Axes of A / B that are not reduced by the operator.
    axes_M = tuple(filter(lambda a: a not in op.axes[0], A.order.axes))
    axes_N = tuple(filter(lambda a: a not in op.axes[1], B.order.axes))

    axes_K_A, axes_K_B = op.axes

    K = mul(A.shape_dict[a] for a in axes_K_A)
    M = A.size // K
    N = B.size // K

    shape_M = [A.shape_dict[a] for a in axes_M]
    shape_N = [B.shape_dict[a] for a in axes_N]

    op.remove_all()

    if v == A:
        A1, A2 = v_pair

        if axis in axes_K_A:
            split_axis_A = axis

            # B is split at `size * s1 // (s1 + s2)`; that is only exact when
            # the product is divisible by (s1 + s2).
            if (B.shape_dict[axes_K_B[0]] * s1) % (s1 + s2) == 0:
                split_axis_B = axes_K_B[0]

            else:
                # Factorize B's axes consisting to K into A's corresponding axes
                B = B.transpose(Order(axes_N + axes_K_B))
                B = B.reshape(order=Order((Axis(), ) + axes_K_A),
                              shape=[N] + [A.shape_dict[a] for a in axes_K_A])
                split_axis_B = split_axis_A
                axes_K_B = axes_K_A

            B1, B2 = SplitAxis(None,
                               axis=split_axis_B,
                               sections=[(B.shape_dict[split_axis_B] * s1) //
                                         (s1 + s2)])(B)

            C1, = Tensordot(None, [axes_K_A, axes_K_B])(A1, B1)
            C2, = Tensordot(None, [axes_K_A, axes_K_B])(A2, B2)
            OptimizeRule.replace_variable(graph, (C1 + C2).reshape(
                shape_M + shape_N, Order(axes_M + axes_N)).transpose_like(C),
                                          C)

        else:
            C1, = Tensordot(None, op.axes)(A1, B)
            C2, = Tensordot(None, op.axes)(A2, B)

            # Make both partial outputs use the same axes, except the split one.
            for a1, a2 in zip(C1.order.axes, C2.order.axes):
                if a1 == a2 == axis:
                    continue
                a1.unify(a2)

            C_new, = Concat(None, axis=axis)(C1, C2)
            OptimizeRule.replace_variable(graph, C_new, C)

    elif v == B:
        B1, B2 = v_pair

        if axis in axes_K_B:
            split_axis_B = axis

            # Same divisibility guard as in the `v == A` branch: A is split at
            # `size * s1 // (s1 + s2)` below, so that product must be a
            # multiple of (s1 + s2). Otherwise fall back to the factorization.
            if (A.shape_dict[axes_K_A[0]] * s1) % (s1 + s2) == 0:
                split_axis_A = axes_K_A[0]

            else:
                # Factorize A's axes consisting to K into B's corresponding axes
                A = A.transpose(Order(axes_M + axes_K_A))
                A = A.reshape(order=Order((Axis(), ) + axes_K_B),
                              shape=[M] + [B.shape_dict[a] for a in axes_K_B])
                split_axis_A = split_axis_B
                axes_K_A = axes_K_B

            A1, A2 = SplitAxis(None,
                               axis=split_axis_A,
                               sections=[(A.shape_dict[split_axis_A] * s1) //
                                         (s1 + s2)])(A)

            C1, = Tensordot(None, [axes_K_A, axes_K_B])(A1, B1)
            C2, = Tensordot(None, [axes_K_A, axes_K_B])(A2, B2)
            OptimizeRule.replace_variable(graph, (C1 + C2).reshape(
                shape_M + shape_N, Order(axes_M + axes_N)).transpose_like(C),
                                          C)

        else:
            C1, = Tensordot(None, op.axes)(A, B1)
            C2, = Tensordot(None, op.axes)(A, B2)

            # Make both partial outputs use the same axes, except the split one.
            for a1, a2 in zip(C1.order.axes, C2.order.axes):
                if a1 == a2 == axis:
                    continue
                a1.unify(a2)

            C_new, = Concat(None, axis=axis)(C1, C2)
            OptimizeRule.replace_variable(graph, C_new, C)

    elif v == C:
        """
        before)

            C[M, N] = A[M, K] @ B[K, N]

        after) In case `axis` is in `N`,

            C[M, N1] = Concat(A[M, K] @ B1[K, N1])
            C[M, N2] = Concat(A[M, K] @ B2[K, N2])
        """
        raise NotImplementedError(
            f"Variable is too large to handle in WebGL backend: {v}")

    else:
        raise UnexpectedAndPleaseReportError
Exemplo n.º 12
0
def _split_reshape(graph: Graph, op: Reshape, v: Variable,
                   v_pair: Sequence[Variable], axis: Axis):
    """Rebuild a Reshape operator after its input or output was split along `axis`.

    `v` (input "x" or output "y" of `op`) has been split into `v_pair`. `op`
    is removed and replaced by a combination of Reshape, Concat and SplitAxis
    operators, as drawn in the diagrams inside each branch.

    Raises:
        NotImplementedError: if no suitable pair of axes is found before the
            accumulated size exceeds the bound checked in each branch.
        UnexpectedAndPleaseReportError: if `v` is neither "x" nor "y".
    """
    x = op.inputs["x"]
    y = op.outputs["y"]
    s1 = v_pair[0].shape_dict[axis]
    s2 = v_pair[1].shape_dict[axis]
    op.remove_all()
    in_order = op.in_order
    out_order = op.out_order
    x_shape = [x.shape_dict[a] for a in in_order.axes]
    y_shape = [y.shape_dict[a] for a in out_order.axes]

    if v == x:
        """
        before)

            x -{reshape}- y

        after)

            x_0 -{reshape}- t_0 -+
                                 +-{concat[axis_k]}- t -{reshape}- y
            x_1 -{reshape}- t_1 -+

        shape and order is changed as follows:

                  x.shape = [dx_0, dx_1, ..., dx_m,   ..., dx_M-1]
                x_0.shape = [dx_0, dx_1, ..., dx_m/2, ..., dx_M-1]
            ---------------------------------------------------------------------------------
                t_0.shape = [dy_0, dy_1, ..., dy_n,   ..., dy_k/2, ..., dy_N-1]
                  t.shape = [dy_0, dy_1, ..., dy_n*2, ..., dy_k/2, ..., dy_N-1]
                  y.shape = [dy_0, dy_1, ..., dy_n,   ..., dy_k,   ..., dy_N-1]

            m: split target axis

            find axis_k and axis_n, which satisfies follow conditions

                dy_n * dy_n+1 * ... * dy_N-1 == dx_m * dx_m+1 * ... * dx_M-1
                dy_k % 2 == 0
                n <= k
        """

        x_0, x_1 = v_pair
        # Number of elements of x from the split axis to the last axis.
        dx_prod = mul(x_shape[in_order.axes_dict[axis]:])
        dy_prod = 1
        axis_k_candidate = []
        # Walk y's axes from the last one, accumulating sizes until they match dx_prod.
        for axis_n in reversed(out_order.axes):
            dy_prod *= y.shape_dict[axis_n]
            if y.shape_dict[axis_n] % 2 == 0:
                axis_k_candidate.append(axis_n)

            if dx_prod == dy_prod:
                # Split most large axis
                axis_k = (sorted(axis_k_candidate,
                                 key=lambda a: y.shape_dict[a],
                                 reverse=True))[0]

                t_0_shape = [y.shape_dict[a] for a in out_order.axes]
                t_0_shape[out_order.axes_dict[axis_k]] = t_0_shape[
                    out_order.axes_dict[axis_k]] // 2  # TODO
                t_0, = Reshape(None,
                               in_order=in_order,
                               out_order=out_order,
                               out_shape=t_0_shape)(x_0)

                t_1_shape = [y.shape_dict[a] for a in out_order.axes]
                t_1_shape[out_order.axes_dict[axis_k]] = t_1_shape[
                    out_order.axes_dict[axis_k]] // 2  # TODO
                t_1, = Reshape(None,
                               in_order=in_order,
                               out_order=out_order,
                               out_shape=t_1_shape)(x_1)

                t, = Concat(None, axis=axis_n)(t_0, t_1)
                y_new, = Reshape(None,
                                 in_order=out_order,
                                 out_order=out_order,
                                 out_shape=y_shape)(t)

                OptimizeRule.replace_variable(graph, y_new.transpose_like(y),
                                              y)
                break

            elif dy_prod > (s1 + s2) * dx_prod:
                raise NotImplementedError(
                    f"Variable is too large to handle in WebGL backend: {v}")

    elif v == y:
        """
        algorithm is almost same as the case `v == x` (above).

        before)

            x -{reshape}- y

        after)

                                    +- t_0 -{reshape}- y_0
            x -{reshape}- t-{split}-+
                                    +- t_1 -{reshape}- y_1

        shape and order is changed as follows:

                  x.shape = [dx_0, dx_1, ..., dx_m,   ..., dx_k,   ..., dx_M-1]
                  t.shape = [dx_0, dx_1, ..., dx_m*2, ..., dx_k/2, ..., dx_M-1]
                t_0.shape = [dx_0, dx_1, ..., dx_m,   ..., dx_k/2, ..., dx_M-1]
            ---------------------------------------------------------------------------------
                y_0.shape = [dy_0, dy_1, ..., dy_n/2, ..., dy_N-1]
                  y.shape = [dy_0, dy_1, ..., dy_n,   ..., dy_N-1]

            m: split target axis

            find axis_k and axis_m, which satisfies follow conditions

                dx_m * dx_m+1 * ... * dx_M-1 == dy_n * dy_n+1 * ... * dy_N-1
                dx_k % 2 == 0
                m <= k
        """

        y_0, y_1 = v_pair
        dx_prod = 1
        # Number of elements of y from the split axis to the last axis. The
        # index comes from `out_order`, so it must be applied to y's shape
        # (the previous code sliced `x_shape` with it).
        dy_prod = mul(y_shape[out_order.axes_dict[axis]:])
        axis_k_candidate = []
        # Walk x's axes from the last one, accumulating sizes until they match dy_prod.
        for axis_m in reversed(in_order.axes):
            dx_prod *= x.shape_dict[axis_m]
            if x.shape_dict[axis_m] % 2 == 0:
                axis_k_candidate.append(axis_m)

            if dx_prod == dy_prod:
                # Split most large axis
                axis_k = (sorted(axis_k_candidate,
                                 key=lambda a: x.shape_dict[a],
                                 reverse=True))[0]

                t_shape = [x.shape_dict[a] for a in in_order.axes]
                t_shape[in_order.axes_dict[axis_m]] = 2 * t_shape[
                    in_order.axes_dict[axis_m]]  # TODO
                t_shape[in_order.axes_dict[axis_k]] = t_shape[
                    in_order.axes_dict[axis_k]] // 2  # TODO
                t, = Reshape(None,
                             in_order=in_order,
                             out_order=in_order,
                             out_shape=t_shape)(x)

                t_0, t_1 = SplitAxis(None,
                                     axis=axis_m,
                                     sections=[t.shape_dict[axis_m] // 2
                                               ])(t)  # TODO

                y_0_new, = Reshape(None,
                                   in_order=in_order,
                                   out_order=out_order,
                                   out_shape=y_0.shape)(t_0)
                y_1_new, = Reshape(None,
                                   in_order=in_order,
                                   out_order=out_order,
                                   out_shape=y_1.shape)(t_1)

                OptimizeRule.replace_variable(graph, y_0_new.reshape_like(y_0),
                                              y_0)
                OptimizeRule.replace_variable(graph, y_1_new.reshape_like(y_1),
                                              y_1)
                break

            elif dx_prod > dy_prod:
                raise NotImplementedError(
                    f"Variable is too large to handle in WebGL backend: {v}")

    else:
        raise UnexpectedAndPleaseReportError
Exemplo n.º 13
0
def _split_splitaxis(graph: Graph, op: SplitAxis, v: Variable,
                     v_pair: Sequence[Variable], axis: Axis):
    """Rebuild a SplitAxis operator after its input or one output was split along `axis`.

    `v` (input "x" or one of the outputs "y{i}" of `op`) has been split into
    `v_pair`. `op` is removed and replaced by new SplitAxis / Concat operators,
    as drawn in the diagrams inside each branch.

    Raises:
        UnexpectedAndPleaseReportError: if `v` is neither the input nor an
            output of `op`, or if one side of the split receives no output.
    """
    s1 = v_pair[0].shape_dict[axis]
    x = op.inputs["x"]
    ys = [op.outputs[f"y{i}"] for i in range(len(op.outputs))]
    sections = op.parameters["sections"]
    op.remove_all()

    if v == x:
        x_0, x_1 = v_pair
        if axis == op.axis:
            """
            before)
                                      +-y1
                                      |
                x -{split[axis=axis]}-+-y2
                                      |
                                      +-y3

            after)
                                        +- y1
                x_0 -{split[axis=axis]}-+
                                        +- y2_0 -+
                                                 +-{concat[axis=axis]}- y2
                                        +- y2_1 -+
                x_1 -{split[axis=axis]}-+
                                        +- y3
            """
            # find output variable which should be split ("y2" in above figure)

            total_size = 0
            ys_0 = []  # type: List[Variable]
            ys_1 = list(ys)  # type: List[Variable]
            for y in ys:
                ys_1.remove(y)

                if total_size + y.shape_dict[axis] == s1:
                    # splitting is not needed
                    #
                    #       x_0        |       x_1
                    # <--------------> | <--------------->
                    # y0, y1, ..., yn, | yn+1, ..., ys[-1]
                    ys_0.append(y)
                    break

                elif total_size + y.shape_dict[axis] > s1:
                    # this `y` must be split
                    #
                    #         x_0         |         x_1
                    # <-----------------> | <----------------->
                    #  y0, y1, ..., yn_0, | yn_1, ..., ys[-1]
                    #               <----------->
                    #                     yn

                    yn_0 = Variable([
                        s1 - total_size if a == axis else y.shape_dict[a]
                        for a in y.order.axes
                    ], y.order)
                    yn_1 = Variable([
                        y.shape_dict[axis] -
                        (s1 - total_size) if a == axis else y.shape_dict[a]
                        for a in y.order.axes
                    ], y.order)
                    OptimizeRule.replace_variable(
                        graph,
                        Concat(None, axis=axis)(yn_0,
                                                yn_1)[0].change_order(y.order),
                        y)
                    ys_0.append(yn_0)
                    ys_1.insert(0, yn_1)
                    break

                else:
                    ys_0.append(y)
                    total_size += y.shape_dict[axis]

            if len(ys_0) > 1:
                # Cumulative boundaries between the outputs that belong to x_0.
                sections_0 = [0]
                for y in ys_0:
                    sections_0.append(sections_0[-1] + y.shape_dict[axis])
                sections_0.pop(0)
                sections_0.pop()

                for y_new, y in zip(
                        SplitAxis(None, axis=axis, sections=sections_0)(x_0),
                        ys_0):
                    y_new.change_order(y.order)
                    OptimizeRule.replace_variable(graph, y_new, y)

            elif len(ys_0) == 1:
                OptimizeRule.replace_variable(graph, ys_0[0], x_0)

            else:
                raise UnexpectedAndPleaseReportError

            if len(ys_1) > 1:
                # Cumulative boundaries between the outputs that belong to x_1.
                sections_1 = [0]
                for y in ys_1:
                    sections_1.append(sections_1[-1] + y.shape_dict[axis])
                sections_1.pop(0)
                sections_1.pop()

                for y_new, y in zip(
                        SplitAxis(None, axis=axis, sections=sections_1)(x_1),
                        ys_1):
                    y_new.change_order(y.order)
                    OptimizeRule.replace_variable(graph, y_new, y)

            elif len(ys_1) == 1:
                OptimizeRule.replace_variable(graph, ys_1[0], x_1)

            else:
                raise UnexpectedAndPleaseReportError

        else:
            """
            before)
                                         +- y1
                                         |
                x -{split[axis=op.axis]}-+- y2
                                         |
                                         +- y3

            after)
                                               +--- y1_0 -+
                                               |          +-{concat[axis=axis]}- y1
                                               | +- y1_1 -+
                                               | |
                    x_0 -{split[axis=op.axis]}-+-|- y2_0 -+
                                               | |        +-{concat[axis=axis]}- y2
                    x_1 -{split[axis=op.axis]}---+- y2_1 -+
                                               | |
                                               +-|- y3_0 -+
                                                 |        +-{concat[axis=axis]}- y3
                                                 +- y3_1 -+
            """
            ys_0 = SplitAxis(None, axis=op.axis, sections=op.sections)(x_0)
            ys_1 = SplitAxis(None, axis=op.axis, sections=op.sections)(x_1)

            for y, y_0, y_1 in zip(ys, ys_0, ys_1):
                y_new, = Concat(None, axis=axis)(y_0, y_1)
                OptimizeRule.replace_variable(graph, y_new, y)

    elif v in ys:
        # NOTE(review): `op` was already detached at the top of the function;
        # this second call looks redundant but is kept as in the original.
        op.remove_all()

        if axis == op.axis:
            """
            before)
                           +- y1
                           |
                x -{split}-+- y2
                           |
                           +- y3

            after)
                           +- y1
                           |
                           +- y2_0
                x -{split}-+
                           +- y2_1
                           |
                           +- y3
            """
            target_i = ys.index(v)

            # New boundary: start of the target output plus the size of its first part.
            s_insert = (0 if target_i == 0 else sections[target_i - 1]) + s1
            new_sections = list(sections)
            new_sections.insert(target_i, s_insert)

            new_ys = SplitAxis(None, axis=axis, sections=new_sections)(x)
            for i, new_y in enumerate(new_ys):
                if i == target_i:
                    # Drop `v` itself from the queue of original outputs.
                    ys.pop(0)
                    y = v_pair[0]
                    new_y.change_order(y.order)
                    OptimizeRule.replace_variable(graph, y, new_y)

                elif i == target_i + 1:
                    y = v_pair[1]
                    new_y.change_order(y.order)
                    OptimizeRule.replace_variable(graph, y, new_y)

                else:
                    y = ys.pop(0)
                    new_y.change_order(y.order)
                    OptimizeRule.replace_variable(graph, y, new_y)

        else:
            """
            before)

                 y1 y2 y3      y1   y2   y3
                +--+--+--+    +--+ +--+ +--+
                |  :  :  |    |  | |  | |  |
                |  :  :  | => |  | |  | |  |
                |  :  :  |    |  | |  | |  |
                +--+--+--+    +--+ +--+ +--+

                                    +- y1
                                    |
                x -{split[op.axis]}-+- y2
                                    |
                                    +- y3

            after) split y2 into y2_0 and y2_1

                                                y1_0 y2_0 y3_0         y2_0
                                  +--+--+--+    +--+ +--+ +--+     y1  +--+  y3
              0 +--+--+--+    x_0 |  :  :  |    |  | |  | |  |    +--+ |  | +--+
                |  :  :  |        +--+--+--+    +--+ +--+ +--+    |  | +--+ |  |
             s1 +  +  +  + =>                =>                => +  +      +  +
                |  :  :  |        +--+--+--+    +--+ +--+ +--+    |  | +--+ |  |
                +--+--+--+    x_1 |  :  :  |    |  | |  | |  |    +--+ |  | +--+
                                  +--+--+--+    +--+ +--+ +--+         +--+
                    x                           y1_1 y2_1 y3_1         y2_1

                                                          +--- y1_0 -+
                                                          |          +-{concat[axis]}- y1
                                                          | +- y1_1 -+
                                                          | |
                                 +- x_0 -{split[op.axis]}-+-|------------------------- y2_0
                x -{split[axis]}-+                        | |
                                 +- x_1 -{split[op.axis]}---+------------------------- y2_1
                                                          | |
                                                          +-|- y3_0 -+
                                                            |        +-{concat[axis]}- y3
                                                            +- y3_1 -+
            """
            x_0, x_1 = SplitAxis(None, axis=axis, sections=[s1])(x)
            # Keep the whole sequence of outputs. Unpacking it with a trailing
            # comma (`ys_0, = ...`) requires exactly one output and fails for
            # any real split; the loop below needs one entry per original output.
            ys_0 = SplitAxis(None, axis=op.axis, sections=op.sections)(x_0)
            ys_1 = SplitAxis(None, axis=op.axis, sections=op.sections)(x_1)
            for y, y_0, y_1 in zip(ys, ys_0, ys_1):
                if y == v:
                    OptimizeRule.replace_variable(graph, y_0, v_pair[0])
                    OptimizeRule.replace_variable(graph, y_1, v_pair[1])

                else:
                    y_new, = Concat(None, axis=axis)(y_0, y_1)
                    OptimizeRule.replace_variable(graph, y_new, y)

    else:
        raise UnexpectedAndPleaseReportError
Exemplo n.º 14
0
def _split_concat(graph: Graph, op: Concat, v: Variable,
                  v_pair: Sequence[Variable], axis: Axis):
    """
    Rebuild the graph around a Concat operator after one of its variables has been split in two.

    `op` is detached from the graph first (`op.remove_all()`); new SplitAxis / Concat operators are then created
    so that the two variables in `v_pair` are used in place of `v`.

    Args:
        graph: graph handed through to `OptimizeRule.replace_variable`
        op: the Concat operator to be replaced
        v: the variable which has been split; either one of `op`'s "x*" inputs or its output "y"
        v_pair: the two variables which replace `v`, in order along `axis`
        axis: the axis along which `v` has been split

    Raises:
        UnexpectedAndPleaseReportError: if `v` is neither an input nor the output of `op`, or if one half of the
            output would end up with no input variable at all.
    """

    def _input_position(key: str):
        # A plain string sort puts "x10" before "x2", which scrambles the concatenation order as soon as there
        # are more than ten inputs. Numeric suffixes are therefore compared as numbers.
        # NOTE(review): presumably Concat keys its inputs "x0", "x1", ... - confirm. Keys without a numeric
        # suffix keep their lexicographic order and are placed after the numbered ones.
        suffix = key[1:]
        if suffix.isdecimal():
            return 0, int(suffix), key
        return 1, 0, key

    s1 = v_pair[0].shape_dict[axis]
    xs = [
        op.inputs[key] for key in sorted(
            (key for key in op.inputs.keys() if key.startswith("x")),
            key=_input_position)
    ]
    y = op.outputs["y"]
    op.remove_all()

    if v in xs:
        x_0, x_1 = v_pair

        if axis == op.axis:
            """
            before)
                x1 -+
                    |
                x2 -+-{concat}- y
                    |
                x3 -+

            after)
                x1 ---+
                      |
                x2_0 -+
                      +-{concat}- y
                x2_1 -+
                      |
                x3 ---+
            """
            # put the two halves at the position where `v` was
            i = xs.index(v)
            xs[i:i + 1] = [x_0, x_1]

            y_new, = Concat(None, axis=axis)(*xs)
            # NOTE(review): the argument order here is (y, y_new), the reverse of the other branches in this
            # function - confirm this is intended.
            OptimizeRule.replace_variable(graph, y, y_new)

        else:
            """
            before)
                x1 -+
                    |
                x2 -+-{concat[op.axis]}- y
                    |
                x3 -+

            after)
                                  +- x1_0 -+
                x1 -{split[axis]}-+        |
                                  +- x1_1 -|-+
                                           | |
                x2_0 ----------------------+---{concat[op.axis]}- y_0 -+
                                           | |                         +-{concat[axis]}- y
                x2_1 ----------------------|-+-{concat[op.axis]}- y_1 -+
                                           | |
                                  +- x3_0 -+ |
                x3 -{split[axis]}-+          |
                                  +- x3_1 ---+
            """
            # `v` is already split (v_pair); every other input is split at the same position
            xs_0, xs_1 = zip(*[
                v_pair if x == v else SplitAxis(None, axis=axis, sections=[s1])(x)
                for x in xs
            ])
            y_0, = Concat(None, axis=op.axis)(*xs_0)
            y_1, = Concat(None, axis=op.axis)(*xs_1)
            y_new, = Concat(None, axis=axis)(y_0, y_1)
            OptimizeRule.replace_variable(graph, y_new, y)

    elif v == y:
        y_0, y_1 = v_pair

        if axis == op.axis:
            """
            before)
                x1 -+
                    |
                x2 -+-{concat[axis=op.axis]}- y
                    |
                x3 -+

            after)
                x1 ------------------------------+
                                                 +-{concat[axis=axis]}- y_0
                                       +- x2_0 --+
                x2 -{split[axis=axis]}-+
                                       +- x2_1 --+
                                                 +-{concat[axis=axis]}- y_1
                x3 ------------------------------+
            """
            # find input variable which should be split

            total_size = 0
            xs_0 = []  # type: List[Variable]
            xs_1 = list(xs)  # type: List[Variable]
            for x in xs:
                # move the current head of xs_1 to the tail of xs_0
                xs_1.pop(0)
                xs_0.append(x)
                total_size += x.shape_dict[axis]

                if total_size == s1:
                    # splitting is not needed
                    #
                    # x0, x1, ..., xn, | xn+1, ..., xs[-1]
                    # <--------------> | <--------------->
                    #       y_0        |       y_1
                    break

                elif total_size > s1:
                    # this `x` must be split
                    #
                    #  x0, x1, ..., xn, ..., xs[-1]
                    # <-------------><------------->
                    #       y_0           y_1

                    xn_0, xn_1 = SplitAxis(
                        None,
                        axis=axis,
                        sections=[s1 - (total_size - x.shape_dict[axis])])(x)
                    # Replace the entry which has just been appended. Removing by value would drop an earlier
                    # occurrence when the same variable is concatenated more than once.
                    xs_0[-1] = xn_0
                    xs_1.insert(0, xn_1)
                    break

            if len(xs_0) > 1:
                y_0, = Concat(None, axis=axis)(*xs_0)
                y_0.change_order(v_pair[0].order)

            elif len(xs_0) == 1:
                y_0 = xs_0[0]

            else:
                raise UnexpectedAndPleaseReportError

            if len(xs_1) > 1:
                y_1, = Concat(None, axis=axis)(*xs_1)
                y_1.change_order(v_pair[1].order)

            elif len(xs_1) == 1:
                y_1 = xs_1[0]

            else:
                # also reached when the inputs never add up to `s1`, i.e. nothing is left for the second half
                raise UnexpectedAndPleaseReportError

            OptimizeRule.replace_variable(graph, y_0, v_pair[0])
            OptimizeRule.replace_variable(graph, y_1, v_pair[1])

        else:
            """
            before)
                x1 -+
                    |
                x2 -+-{concat[op.axis]}- y
                    |
                x3 -+

            after)
                                  +- x1_0 -+
                x1 -{split[axis]}-+        |
                                  +- x1_1 ---+
                                           | |
                                  +- x2_0 -+-|-{concat[op.axis]}- y_0
                x2 -{split[axis]}-+        | |
                                  +- x2_1 ---+-{concat[op.axis]}- y_1
                                           | |
                                  +- x3_0 -+ |
                x3 -{split[axis]}-+          |
                                  +- x3_1 ---+

            """
            xs_0, xs_1 = zip(
                *[SplitAxis(None, axis=axis, sections=[s1])(x) for x in xs])

            y_new_0, = Concat(None, axis=op.axis)(*xs_0)
            y_new_1, = Concat(None, axis=op.axis)(*xs_1)

            OptimizeRule.replace_variable(graph, y_new_0, y_0)
            OptimizeRule.replace_variable(graph, y_new_1, y_1)

    else:
        raise UnexpectedAndPleaseReportError
Exemplo n.º 15
0
def _split_reshape(graph: Graph, op: Reshape, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    """
    Rebuild the graph around a Reshape operator after its input or output has been split in two.

    `op` is detached from the graph first (`op.remove_all()`). The function then looks for an axis on the
    opposite side of the reshape whose trailing block of elements has the same size as the block which starts at
    `axis` in `v`, and splits (or concatenates) along that axis.

    Args:
        graph: graph handed through to `OptimizeRule.replace_variable`
        op: the Reshape operator to be replaced
        v: the variable which has been split; either `op`'s input "x" or its output "y"
        v_pair: the two variables which replace `v`, in order along `axis`
        axis: the axis along which `v` has been split

    Raises:
        NotImplementedError: if no axis on the opposite side lines up with the split position. Without this the
            function would return with the reshape already removed and nothing connected in its place.
        UnexpectedAndPleaseReportError: if `v` is neither the input nor the output of `op`.
    """
    x = op.inputs["x"]
    y = op.outputs["y"]
    s1 = v_pair[0].shape_dict[axis]
    s2 = v_pair[1].shape_dict[axis]
    op.remove_all()

    if v == x:
        """
        Regard x's order as `[D1, D2]`, shape as `[d1x, d2x]`, where the most outside axis in D2 is the split target axis.
        If y's shape can be converted as `[d1x, d2x]` by merging some adjacent axes in y, split can be performed.

        before)

            x -{reshape}- y

        after)

            x_0 -{reshape}- y_0 -+
                                 +-{concat[axis]}- y
            x_1 -{reshape}- y_1 -+
        """

        x_0, x_1 = v_pair
        d2x = mul(x.shape[x.order.axes_dict[axis]:])
        d2y = 1
        # accumulate y's axes from the innermost one until the block size matches d2x
        for axis_y in reversed(y.order.axes):
            d2y *= y.shape_dict[axis_y]

            if d2y == d2x:
                # scale the matching axis of y by the ratio of each half
                y_0_shape = [y.shape_dict[axis_y] * s1 // (s1 + s2) if a == axis_y else y.shape_dict[a] for a in y.order.axes]
                y_1_shape = [y.shape_dict[axis_y] * s2 // (s1 + s2) if a == axis_y else y.shape_dict[a] for a in y.order.axes]

                y_0 = x_0.reshape(y_0_shape, y.order)
                y_1 = x_1.reshape(y_1_shape, y.order)

                y_new, = Concat(None, axis=axis_y)(y_0, y_1)
                OptimizeRule.replace_variable(graph, y_new, y)
                break

            elif d2y > (s1 + s2) * d2x:
                raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

        else:
            # The loop ended without finding a matching axis. `op` is already removed, so returning silently
            # would leave `y` without any producer.
            raise NotImplementedError(f"Reshape cannot be split along {axis}, no matching axis found in output: {v}")

    elif v == y:
        """
        Same algorithm as in case `v == x` (above), with the roles of x and y swapped.

        before)

            x -{reshape}- y

        after)

                       +- x_0 -{reshape}- y_0
            x -{split}-+
                       +- x_1 -{reshape}- y_1
        """

        y_0, y_1 = v_pair
        d2y = mul(y.shape[y.order.axes_dict[axis]:])
        d2x = 1
        # accumulate x's axes from the innermost one until the block size matches d2y
        for axis_x in reversed(x.order.axes):
            d2x *= x.shape_dict[axis_x]

            if d2x == d2y:
                x_0, x_1 = SplitAxis(None, axis=axis_x, sections=[x.shape_dict[axis_x] * s1 // (s1 + s2)])(x)

                OptimizeRule.replace_variable(graph, x_0.reshape_like(y_0), y_0)
                OptimizeRule.replace_variable(graph, x_1.reshape_like(y_1), y_1)
                break

            # NOTE(review): this compares the fixed `d2y` against the accumulating `d2x`, the textual copy of the
            # check in the `v == x` branch although the roles are swapped here - confirm the intended condition.
            elif d2y > (s1 + s2) * d2x:
                raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

        else:
            # The loop ended without finding a matching axis. `op` is already removed, so returning silently
            # would leave the two halves of `y` without any producer.
            raise NotImplementedError(f"Reshape cannot be split along {axis}, no matching axis found in input: {v}")

    else:
        raise UnexpectedAndPleaseReportError
Exemplo n.º 16
0
def _split_sgemm(graph: Graph, op: Sgemm, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    """
    Rebuild the graph around a Sgemm operator after one of its inputs has been split in two.

    `op` is detached from the graph first (`op.remove_all()`). Depending on whether `axis` belongs to the reduced
    dimension K or to the remaining dimension (M for "A", N for "B"), the product is rebuilt as the sum or as the
    concatenation of two smaller Sgemm operators.

    Args:
        graph: graph handed through to `OptimizeRule.replace_variable`
        op: the Sgemm operator to be replaced
        v: the variable which has been split; `op`'s input "A" or "B" (output "C" is not supported)
        v_pair: the two variables which replace `v`, in order along `axis`
        axis: the axis along which `v` has been split

    Raises:
        NotImplementedError: if `v` is the output "C" of `op`.
        ValueError: if the axes of the split input cannot be grouped into the two logical matrix dimensions.
        UnexpectedAndPleaseReportError: if `v` is neither an input nor the output of `op`.
    """
    s1 = v_pair[0].shape_dict[axis]
    s2 = v_pair[1].shape_dict[axis]
    A = op.inputs["A"]
    B = op.inputs["B"]
    C = op.outputs["C"]
    transpose_A, transpose_B = op.transpose_A, op.transpose_B
    M, K, N = op.M, op.K, op.N
    # fresh axes which stand for a whole logical matrix dimension in the temporary orders built below
    axis_M, axis_K, axis_N = Axis(None), Axis(None), Axis(None)

    op.remove_all()

    def decompose_logical_axes(logical_shape: Tuple[int, int], v: Variable):
        """
        Decompose logical axes into real axes

        The leading axes of `v` are collected until the product of their sizes equals `logical_shape[0]`; all
        remaining axes form the second group. Raises ValueError when no such prefix exists.

        Examples::

            A.order, A.shape
            >>> "NCHW", (1, 128, 8, 8)

            M = 128
            K = 64
            decompose_logical_axes([M, K], A)
            >>> ["<Axis N>", "<Axis C>"], ["<Axis H>", "<Axis W>"]
        """
        total_size = 1
        axes1 = []  # type: List[Axis]
        axes2 = list(v.order.axes)  # type: List[Axis]
        for size, a in zip(v.shape, v.order.axes):
            if total_size == logical_shape[0]:
                return axes1, axes2

            elif total_size > logical_shape[0]:
                raise ValueError

            axes1.append(a)
            axes2.remove(a)
            total_size *= size

        # The size is only examined at the top of each iteration, so the state after the last axis has to be
        # checked here. Otherwise a first logical dimension which spans all axes falls off the loop and `None`
        # is returned, which the callers cannot unpack.
        if total_size == logical_shape[0]:
            return axes1, axes2

        raise ValueError

    if v == A:
        A1, A2 = v_pair
        if transpose_A:  # A.shape = [M, K]
            axes_M, axes_K = decompose_logical_axes((M, K), A)

        else:  # A.shape = [K, M]
            axes_K, axes_M = decompose_logical_axes((K, M), A)

        if axis in axes_K:
            """
            before)

                A -{sgemm}- C

            after) In case `axis` is in `K`,

                A_0 -{sgemm}- C_0 -+
                                   +-{Add}- C
                A_1 -{sgemm}- C_1 -+
            """
            K1, K2 = K * s1 // (s1 + s2), K * s2 // (s1 + s2)

            # Factorize B's axes included in K into A's corresponding axes
            if transpose_B:  # B: [k_b1, k_b2, ..., N] -{reshape}-> [k_a1, k_a2, ..., N]
                B, = Reshape(None,
                             in_order=B.order,
                             out_order=Order(axes_K + [axis_N]),
                             out_shape=[A.shape_dict[a] for a in axes_K] + [N])(B)
            else:  # B: [N, k_b1, k_b2, ...] -{reshape}-> [N, k_a1, k_a2, ...]
                B, = Reshape(None,
                             in_order=B.order,
                             out_order=Order([axis_N] + axes_K),
                             out_shape=[N] + [A.shape_dict[a] for a in axes_K])(B)

            # B now carries `axis` as well, so it can be split at the same position as A
            B1, B2 = SplitAxis(None, axis=axis, sections=[s1])(B)

            C1, = Sgemm(None, M=M, K=K1, N=N,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"],
                        out_order=op.parameters["out_order"])(A1, B1)

            C2, = Sgemm(None, M=M, K=K2, N=N,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"],
                        out_order=op.parameters["out_order"])(A2, B2)

            OptimizeRule.replace_variable(graph, C1 + C2, C)

        else:
            assert axis in axes_M
            """
            before)

                A -{sgemm}- C

            after) In case `axis` is in `M`,

                A_0 -{sgemm}- C_0 -+
                                   +-{Concat}- C
                A_1 -{sgemm}- C_1 -+
            """
            M1, M2 = M * s1 // (s1 + s2), M * s2 // (s1 + s2)

            # The partial results keep A's real M axes so that they can be concatenated along `axis`
            c_tmp_order = Order(axes_M + [axis_N])
            c1_shape = [A1.shape_dict[a] for a in axes_M] + [N]
            c2_shape = [A2.shape_dict[a] for a in axes_M] + [N]

            C1, = Sgemm(None, M=M1, K=K, N=N,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=c1_shape,
                        out_order=c_tmp_order)(A1, B)

            C2, = Sgemm(None, M=M2, K=K, N=N,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=c2_shape,
                        out_order=c_tmp_order)(A2, B)

            C_new, = Concat(None, axis=axis)(C1, C2)
            # bring the concatenated result back to the order and shape of the original output
            C_new, = Reshape(None, in_order=c_tmp_order, out_order=C.order, out_shape=C.shape)(C_new)
            OptimizeRule.replace_variable(graph, C_new, C)

    elif v == B:
        B1, B2 = v_pair
        if transpose_B:  # B.shape = [K, N]
            axes_K, axes_N = decompose_logical_axes((K, N), B)

        else:  # B.shape = [N, K]
            axes_N, axes_K = decompose_logical_axes((N, K), B)

        if axis in axes_K:
            """
            before)

                B -{sgemm}- C

            after) In case `axis` is in `K`,

                B_0 -{sgemm}- C_0 -+
                                   +-{Add}- C
                B_1 -{sgemm}- C_1 -+
            """
            K1, K2 = K * s1 // (s1 + s2), K * s2 // (s1 + s2)

            # Factorize A's axes included in K into B's corresponding axes
            if transpose_A:  # A: [M, k_a1, k_a2, k_a3, ...] -{reshape}-> [M, k_b1, k_b2, ...]
                A, = Reshape(None,
                             in_order=A.order,
                             out_order=Order([axis_M] + axes_K),
                             out_shape=[M] + [B.shape_dict[a] for a in axes_K])(A)
            else:  # A: [k_a1, k_a2, k_a3, ..., M] -{reshape}-> [k_b1, k_b2, ..., M]
                A, = Reshape(None,
                             in_order=A.order,
                             out_order=Order(axes_K + [axis_M]),
                             out_shape=[B.shape_dict[a] for a in axes_K] + [M])(A)

            # A now carries `axis` as well, so it can be split at the same position as B
            A1, A2 = SplitAxis(None, axis=axis, sections=[s1])(A)

            C1, = Sgemm(None, M=M, K=K1, N=N,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"],
                        out_order=op.parameters["out_order"])(A1, B1)

            C2, = Sgemm(None, M=M, K=K2, N=N,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=op.parameters["out_shape"],
                        out_order=op.parameters["out_order"])(A2, B2)

            OptimizeRule.replace_variable(graph, C1 + C2, C)

        else:
            assert axis in axes_N
            """
            before)

                C[M, N] = A[M, K] @ B[K, N]

            after) In case `axis` is in `N`,

                C[M, N] = Concat(C1[M, N1], C2[M, N2])
                        = Concat(A[M, K] @ B1[K, N1], A[M, K] @ B2[K, N2])
            """
            N1, N2 = N * s1 // (s1 + s2), N * s2 // (s1 + s2)

            # The partial results keep B's real N axes so that they can be concatenated along `axis`
            c_tmp_order = Order([axis_M] + axes_N)
            c1_shape = [M] + [B1.shape_dict[a] for a in axes_N]
            c2_shape = [M] + [B2.shape_dict[a] for a in axes_N]

            C1, = Sgemm(None, M=M, K=K, N=N1,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=c1_shape,
                        out_order=c_tmp_order)(A, B1)
            # C1.shape = [M, B.shape_dict[n1], B.shape_dict[n2], ..., B1.shape_dict[axis], ...]
            # C1.order = [axis_M, n1, n2, ..., axis, ...]

            C2, = Sgemm(None, M=M, K=K, N=N2,
                        transpose_A=transpose_A,
                        transpose_B=transpose_B,
                        out_shape=c2_shape,
                        out_order=c_tmp_order)(A, B2)

            C_new, = Concat(None, axis=axis)(C1, C2)
            # C_new.shape = [M, B.shape_dict[n1], B.shape_dict[n2], ..., B1.shape_dict[axis]+B2.shape_dict[axis], ...]
            # C_new.order = [axis_M, n1, n2, ..., axis, ...]

            C_new, = Reshape(None, in_order=c_tmp_order, out_order=C.order, out_shape=C.shape)(C_new)
            OptimizeRule.replace_variable(graph, C_new, C)

    elif v == C:
        """
        Not implemented. The diagram below only sketches the rewrite for this case.

        before)

            C[M, N] = A[M, K] @ B[K, N]

        after) In case `axis` is in `N`,

            C[M, N1] = Concat(A[M, K] @ B1[K, N1])
            C[M, N2] = Concat(A[M, K] @ B2[K, N2])
        """
        raise NotImplementedError(f"Variable is too large to handle in WebGL backend: {v}")

    else:
        raise UnexpectedAndPleaseReportError