コード例 #1
0
ファイル: concat_test.py プロジェクト: VislaLabs/webdnn-1
def test_mix_order():
    """Kernel test: concatenate along H four inputs stored in four different orders.

    All source arrays are generated in NHWC layout and the expected output is
    their concatenation along axis 1 (H). Inputs 2-4 are then switched to
    another data order, and the matching numpy array is permuted the same way
    so that each input buffer agrees with its variable's order.
    """
    shape = (2, 3, 4, 5)
    vx1, vx2, vx3, vx4 = [np.random.rand(*shape) for _ in range(4)]
    vy = np.concatenate((vx1, vx2, vx3, vx4), 1)

    x1, x2, x3, x4 = [
        Variable(values.shape, order=OrderNHWC)
        for values in (vx1, vx2, vx3, vx4)
    ]

    # NHWC -> CNHW
    x2.change_order(OrderCNHW)
    vx2 = np.transpose(vx2, (3, 0, 1, 2))

    # NHWC -> CHWN
    x3.change_order(OrderCHWN)
    vx3 = np.transpose(vx3, (3, 1, 2, 0))

    # NHWC -> NCHW
    x4.change_order(OrderNCHW)
    vx4 = np.transpose(vx4, (0, 3, 1, 2))

    concat = Concat(None, axis=Axis.H)
    y, = concat(x1, x2, x3, x4)
    # The expected array `vy` is NHWC, so the output is compared in NHWC.
    y.change_order(OrderNHWC)

    input_values = {x1: vx1, x2: vx2, x3: vx3, x4: vx4}
    generate_kernel_test_case(description=f"concat_mix_order",
                              graph=Graph([x1, x2, x3, x4], [y]),
                              inputs=input_values,
                              expected={y: vy})
コード例 #2
0
    def optimize_pair(self, graph: Graph, op1: Concat, op2: ElementwiseMul):
        """Move a constant multiplier from after a Concat onto the Concat's inputs.

        The rewrite applies only when one operand of `op2` is a constant whose
        order consists of exactly the concat axis. The constant's data is cut
        at the size of `x0` along that axis, each piece is multiplied into the
        corresponding concat input, and the products are concatenated again.

        NOTE(review): only the "x0" and "x1" inputs of `op1` are read, so this
        presumably expects a two-input Concat - confirm against the caller.

        Returns:
            True if the graph was rewritten, False if the pattern did not match.
        """
        x0, x1 = op1.inputs["x0"], op1.inputs["x1"]
        c, _ = _get_constant_and_variable(op2, "x0", "x1")
        if c is None or c.order != Order([op1.axis]):
            return False

        concat_axis = op1.axis
        boundary = x0.shape_dict[concat_axis]
        y2 = op2.outputs["y"]

        # Slice the multiplier into the part belonging to x0 and the remainder.
        c0 = ConstantVariable(c.data[:boundary], c.order)
        c1 = ConstantVariable(c.data[boundary:], c.order)

        op1.remove_all()
        op2.remove_all()

        # The Concat operator is created before the products, as in the
        # original expression, so operators are constructed in the same order.
        new_concat = Concat(None, axis=concat_axis)
        y, = new_concat(x0 * c0, x1 * c1)
        reordered = y.change_order(y2.order)
        OptimizeRule.replace_variable(graph, y2, reordered)
        return True
コード例 #3
0
ファイル: split_variable.py プロジェクト: zhangaz1/webdnn
def _split_concat(graph: Graph, op: Concat, v: Variable,
                  v_pair: Sequence[Variable], axis: Axis):
    """Rebuild the graph around a Concat after one of its variables was split.

    `v` is either one of the inputs of `op` or its output "y", and `v_pair`
    holds the two pieces that replace `v` along `axis`. The operator is
    detached from all of its variables and new Concat / SplitAxis operators
    are created so that the pieces in `v_pair` take the place of `v`.

    Args:
        graph: graph passed through to `OptimizeRule.replace_variable`.
        op: the Concat operator to rewrite. Its inputs are the entries of
            `op.inputs` whose key starts with "x".
        v: the variable that was split (an input or the output of `op`).
        v_pair: the two pieces of `v`; the size of `v_pair[0]` along `axis`
            is used as the split position.
        axis: the axis along which `v` was split.

    Raises:
        UnexpectedAndPleaseReportError: if `v` is neither an input nor the
            output of `op`, or if one half of a split output ends up with no
            source variables.
    """
    s1 = v_pair[0].shape_dict[axis]

    # Input keys look like "x0", "x1", ..., "x10". A plain lexicographic sort
    # would put "x10" before "x2" and silently permute the operands of concats
    # with more than ten inputs, so sort by (length, text), which matches
    # numeric order for such keys.
    input_keys = sorted((key for key in op.inputs.keys() if key.startswith("x")),
                        key=lambda key: (len(key), key))
    xs = [op.inputs[key] for key in input_keys]
    y = op.outputs["y"]
    op.remove_all()

    if v in xs:
        # Unpacking also checks that v_pair has exactly two elements.
        x_0, x_1 = v_pair

        if axis == op.axis:
            """
            before)
                x1 -+
                    |
                x2 -+-{concat}- y
                    |
                x3 -+

            after)
                x1 ---+
                      |
                x2_0 -+
                      +-{concat}- y
                x2_1 -+
                      |
                x3 ---+
            """
            # Replace `v` in the operand list by its two pieces, in place.
            i = xs.index(v)
            xs.pop(i)
            xs.insert(i + 0, x_0)
            xs.insert(i + 1, x_1)

            y_new, = Concat(None, axis=axis)(*xs)
            OptimizeRule.replace_variable(graph, y, y_new)

        else:
            """
            before)
                x1 -+
                    |
                x2 -+-{concat[op.axis]}- y
                    |
                x3 -+

            after)
                                  +- x1_0 -+
                x1 -{split[axis]}-+        |
                                  +- x1_1 -|-+
                                           | |
                x2_0 ----------------------+---{concat[op.axis]}- y_0 -+
                                           | |                         +-{concat[axis]}- y
                x2_1 ----------------------|-+-{concat[op.axis]}- y_1 -+
                                           | |
                                  +- x3_0 -+ |
                x3 -{split[axis]}-+          |
                                  +- x3_1 ---+
            """
            # Every other input is split at the same position as `v`; `v`
            # itself contributes its already existing pieces.
            xs_0, xs_1 = zip(*[
                v_pair if x == v
                else SplitAxis(None, axis=axis, sections=[s1])(x)
                for x in xs
            ])
            y_0, = Concat(None, axis=op.axis)(*xs_0)
            y_1, = Concat(None, axis=op.axis)(*xs_1)
            y_new, = Concat(None, axis=axis)(y_0, y_1)
            OptimizeRule.replace_variable(graph, y_new, y)

    elif v == y:
        y_0, y_1 = v_pair

        if axis == op.axis:
            """
            before)
                x1 -+
                    |
                x2 -+-{concat[axis=op.axis]}- y
                    |
                x3 -+

            after)
                x1 ------------------------------+
                                                 +-{concat[axis=axis]}- y_0
                                       +- y_0_1 -+
                x2 -{split[axis=axis]}-+
                                       +- y_1_0 -+
                                                 +-{concat[axis=axis]}- y_1
                x3 ------------------------------+
            """
            # find input variable which should be split

            total_size = 0
            xs_0 = []  # type: List[Variable]
            xs_1 = list(xs)  # type: List[Variable]
            for x in xs:
                # Move inputs from the second half to the first half until the
                # accumulated size reaches the split position `s1`.
                xs_1.remove(x)
                xs_0.append(x)
                total_size += x.shape_dict[axis]

                if total_size == s1:
                    # splitting is not needed
                    #
                    # x0, x1, ..., xn, | xn+1, ..., xs[-1]
                    # <--------------> | <--------------->
                    #       y_0        |       y_1
                    break

                elif total_size > s1:
                    # this `x` must be split
                    #
                    #  x0, x1, ..., xn, ..., xs[-1]
                    # <-------------><------------->
                    #       y_0           y_1

                    # The cut position inside `x` is `s1` minus the total size
                    # of the inputs that precede `x`.
                    xn_0, xn_1 = SplitAxis(
                        None,
                        axis=axis,
                        sections=[s1 - (total_size - x.shape_dict[axis])])(x)
                    xs_0.remove(x)
                    xs_0.append(xn_0)
                    xs_1.insert(0, xn_1)
                    break

            if len(xs_0) > 1:
                y_0, = Concat(None, axis=axis)(*xs_0)
                y_0.change_order(v_pair[0].order)

            elif len(xs_0) == 1:
                # A single source needs no concat.
                y_0 = xs_0[0]

            else:
                raise UnexpectedAndPleaseReportError

            if len(xs_1) > 1:
                y_1, = Concat(None, axis=axis)(*xs_1)
                y_1.change_order(v_pair[1].order)

            elif len(xs_1) == 1:
                y_1 = xs_1[0]

            else:
                raise UnexpectedAndPleaseReportError

            OptimizeRule.replace_variable(graph, y_0, v_pair[0])
            OptimizeRule.replace_variable(graph, y_1, v_pair[1])

        else:
            """
            before)
                x1 -+
                    |
                x2 -+-{concat[op.axis]}- y
                    |
                x3 -+

            after)
                                  +- x1_0 -+
                x1 -{split[axis]}-+        |
                                  +- x1_1 ---+
                                           | |
                                  +- x2_0 -+-|-{concat[op.axis]}- y_0
                x2 -{split[axis]}-+        | |
                                  +- x2_1 ---+-{concat[op.axis]}- y_1
                                           | |
                                  +- x3_0 -+ |
                x3 -{split[axis]}-+          |
                                  +- x3_1 ---+

            """
            # Split every input at `s1` and concatenate the halves separately.
            xs_0, xs_1 = zip(
                *[SplitAxis(None, axis=axis, sections=[s1])(x) for x in xs])

            y_new_0, = Concat(None, axis=op.axis)(*xs_0)
            y_new_1, = Concat(None, axis=op.axis)(*xs_1)

            OptimizeRule.replace_variable(graph, y_new_0, y_0)
            OptimizeRule.replace_variable(graph, y_new_1, y_1)

    else:
        raise UnexpectedAndPleaseReportError
コード例 #4
0
ファイル: split_variable.py プロジェクト: fossabot/hash2face
def _split_splitaxis(graph: Graph, op: SplitAxis, v: Variable, v_pair: Sequence[Variable], axis: Axis):
    """Rebuild the graph around a SplitAxis after one of its variables was split.

    `v` is either the input "x" of `op` or one of its outputs "y<i>", and
    `v_pair` holds the two pieces that replace `v` along `axis`. The operator
    is detached from all of its variables and new SplitAxis / Concat operators
    are created so that the pieces in `v_pair` take the place of `v`.

    Args:
        graph: graph passed through to `OptimizeRule.replace_variable`.
        op: the SplitAxis operator to rewrite.
        v: the variable that was split (the input or an output of `op`).
        v_pair: the two pieces of `v`; the size of `v_pair[0]` along `axis`
            is used as the split position.
        axis: the axis along which `v` was split.

    Raises:
        UnexpectedAndPleaseReportError: if `v` is neither the input nor an
            output of `op`, or if one half of a split input ends up with no
            output variables.
    """
    s1 = v_pair[0].shape_dict[axis]
    x = op.inputs["x"]
    ys = [op.outputs[f"y{i}"] for i in range(len(op.outputs))]
    sections = op.parameters["sections"]
    op.remove_all()

    if v == x:
        x_0, x_1 = v_pair
        if axis == op.axis:
            """
            before)
                                      +-y1
                                      |
                x -{split[axis=axis]}-+-y2
                                      |
                                      +-y3

            after)
                                        +- h1 ------------------------ y1
                x_0 -{split[axis=axis]}-+
                                        +- h2_0 -+
                                                 +-{concat[axis=axis]}- y2
                                        +- h2_1 -+
                x_1 -{split[axis=axis]}-+
                                        +- h3 ------------------------- y3
            """
            # find output variable which should be split ("y2" in above figure)

            total_size = 0
            ys_0 = []  # type: List[Variable]
            ys_1 = list(ys)  # type: List[Variable]
            for y in ys:
                # Move outputs from the second half to the first half until
                # the accumulated size reaches the split position `s1`.
                ys_1.remove(y)
                ys_0.append(y)
                total_size += y.shape_dict[axis]

                if total_size == s1:
                    # splitting is not needed
                    #
                    #       x_0        |       x_1
                    # <--------------> | <--------------->
                    # h0, h1, ..., hn, | hn+1, ..., hs[-1]
                    # y0, y1, ..., yn, | yn+1, ..., ys[-1]
                    break

                elif total_size > s1:
                    # this `y` must be split
                    #
                    #         x_0           |         x_1
                    # <-------------------> | <----------------->
                    #  h0, h1, ..., | hn_0, | hn_1, | ..., hs[-1]
                    #               | <-----------> |
                    #  y0, y1, ..., |     yn      , | ..., ys[-1]

                    # `size_in_x1` is the part of `y` beyond the split point;
                    # the rest of `y` lies in `x_0`. The two sizes must add up
                    # to the size of `y`, so the first one is derived from `y`
                    # itself (not from `x_0`).
                    size_in_x1 = total_size - s1
                    size_in_x0 = y.shape_dict[axis] - size_in_x1
                    hn_0 = Variable([size_in_x0 if a == axis else y.shape_dict[a] for a in y.order.axes],
                                    y.order)
                    hn_1 = Variable([size_in_x1 if a == axis else y.shape_dict[a] for a in y.order.axes], y.order)
                    yn_new, = Concat(None, axis=axis)(hn_0, hn_1)
                    yn_new.change_order(y.order)
                    OptimizeRule.replace_variable(graph, yn_new, y)
                    ys_0.remove(y)
                    ys_0.append(hn_0)
                    ys_1.insert(0, hn_1)
                    break

            if len(ys_0) > 1:
                # Build the split positions as running sums of the piece
                # sizes, then drop the leading 0 and the trailing total.
                sections_0 = [0]
                for h in ys_0:
                    sections_0.append(sections_0[-1] + h.shape_dict[axis])
                sections_0.pop(0)
                sections_0.pop()

                for y_new, y in zip(SplitAxis(None, axis=axis, sections=sections_0)(x_0), ys_0):
                    y_new.change_order(y.order)
                    OptimizeRule.replace_variable(graph, y_new, y)

            elif len(ys_0) == 1:
                # The whole of `x_0` corresponds to a single output.
                OptimizeRule.replace_variable(graph, ys_0[0], x_0)

            else:
                raise UnexpectedAndPleaseReportError

            if len(ys_1) > 1:
                sections_1 = [0]
                for h in ys_1:
                    sections_1.append(sections_1[-1] + h.shape_dict[axis])
                sections_1.pop(0)
                sections_1.pop()

                for y_new, y in zip(SplitAxis(None, axis=axis, sections=sections_1)(x_1), ys_1):
                    y_new.change_order(y.order)
                    OptimizeRule.replace_variable(graph, y_new, y)

            elif len(ys_1) == 1:
                OptimizeRule.replace_variable(graph, ys_1[0], x_1)

            else:
                raise UnexpectedAndPleaseReportError

        else:
            """
            before)
                                         +- y1
                                         |
                x -{split[axis=op.axis]}-+- y2
                                         |
                                         +- y3

            after)
                                               +--- y1_0 -+
                                               |          +-{concat[axis=axis]}- y1
                                               | +- y1_1 -+
                                               | |
                    x_0 -{split[axis=op.axis]}-+-|- y2_0 -+
                                               | |        +-{concat[axis=axis]}- y2
                    x_1 -{split[axis=op.axis]}---+- y2_1 -+
                                               | |
                                               +-|- y3_0 -+
                                                 |        +-{concat[axis=axis]}- y3
                                                 +- y3_1 -+
            """
            # Apply the original split to both halves of the input, then
            # rejoin the matching pieces along `axis`.
            ys_0 = SplitAxis(None, axis=op.axis, sections=op.sections)(x_0)
            ys_1 = SplitAxis(None, axis=op.axis, sections=op.sections)(x_1)

            for y, y_0, y_1 in zip(ys, ys_0, ys_1):
                y_new, = Concat(None, axis=axis)(y_0, y_1)
                OptimizeRule.replace_variable(graph, y_new, y)

    elif v in ys:
        # `op` was already detached from its variables at the top of the
        # function, so no further removal is needed here.

        if axis == op.axis:
            """
            before)
                           +- y1
                           |
                x -{split}-+- y2
                           |
                           +- y3

            after)
                           +- y1
                           |
                           +- y2_0
                x -{split}-+
                           +- y2_1
                           |
                           +- y3
            """
            target_i = ys.index(v)

            # The extra split position is the start of the target output
            # (0 for the first output) plus the size of its first piece.
            s_insert = (0 if target_i == 0 else sections[target_i - 1]) + s1
            new_sections = list(sections)
            new_sections.insert(target_i, s_insert)

            new_ys = SplitAxis(None, axis=axis, sections=new_sections)(x)
            for i, new_y in enumerate(new_ys):
                if i == target_i:
                    # Discard `v` from the queue; its two pieces are used
                    # for this output and the next one.
                    ys.pop(0)
                    y = v_pair[0]
                    new_y.change_order(y.order)
                    OptimizeRule.replace_variable(graph, y, new_y)

                elif i == target_i + 1:
                    y = v_pair[1]
                    new_y.change_order(y.order)
                    OptimizeRule.replace_variable(graph, y, new_y)

                else:
                    y = ys.pop(0)
                    new_y.change_order(y.order)
                    OptimizeRule.replace_variable(graph, y, new_y)

        else:
            """
            before)

                 y1 y2 y3      y1   y2   y3
                +--+--+--+    +--+ +--+ +--+
                |  :  :  |    |  | |  | |  |
                |  :  :  | => |  | |  | |  |
                |  :  :  |    |  | |  | |  |
                +--+--+--+    +--+ +--+ +--+

                                    +- y1
                                    |
                x -{split[op.axis]}-+- y2
                                    |
                                    +- y3

            after) split y2 into y2_0 and y2_1

                                                y1_0 y2_0 y3_0         y2_0
                                  +--+--+--+    +--+ +--+ +--+     y1  +--+  y3
              0 +--+--+--+    x_0 |  :  :  |    |  | |  | |  |    +--+ |  | +--+
                |  :  :  |        +--+--+--+    +--+ +--+ +--+    |  | +--+ |  |
             s1 +  +  +  + =>                =>                => +  +      +  +
                |  :  :  |        +--+--+--+    +--+ +--+ +--+    |  | +--+ |  |
                +--+--+--+    x_1 |  :  :  |    |  | |  | |  |    +--+ |  | +--+
                                  +--+--+--+    +--+ +--+ +--+         +--+
                    x                           y1_1 y2_1 y3_1         y2_1

                                                          +--- y1_0 -+
                                                          |          +-{concat[axis]}- y1
                                                          | +- y1_1 -+
                                                          | |
                                 +- x_0 -{split[op.axis]}-+-|------------------------- y2_0
                x -{split[axis]}-+                        | |
                                 +- x_1 -{split[op.axis]}---+------------------------- y2_1
                                                          | |
                                                          +-|- y3_0 -+
                                                            |        +-{concat[axis]}- y3
                                                            +- y3_1 -+
            """
            x_0, x_1 = SplitAxis(None, axis=axis, sections=[s1])(x)
            # Keep the full sequence of outputs: they are iterated below, so
            # they must not be unpacked as a one-element tuple.
            ys_0 = SplitAxis(None, axis=op.axis, sections=op.sections)(x_0)
            ys_1 = SplitAxis(None, axis=op.axis, sections=op.sections)(x_1)
            for y, y_0, y_1 in zip(ys, ys_0, ys_1):
                if y == v:
                    # The split output is served directly by the two pieces.
                    OptimizeRule.replace_variable(graph, y_0, v_pair[0])
                    OptimizeRule.replace_variable(graph, y_1, v_pair[1])

                else:
                    # Untouched outputs are reassembled from their halves.
                    y_new, = Concat(None, axis=axis)(y_0, y_1)
                    OptimizeRule.replace_variable(graph, y_new, y)

    else:
        raise UnexpectedAndPleaseReportError