Exemplo n.º 1
0
    def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
        """Replace each Convolution2D with a Tensordot-based equivalent.

        One of three lowerings is chosen from the operator's parameters:

        * Projection -- 1x1 window, stride (1, 1), padding (0, 0): the input
          is used directly as the column tensor.
        * Global convolution -- window equal to the input's spatial size,
          padding (0, 0) and dilation (1, 1): the input's H/W axes are
          reinterpreted as the kernel axes and contracted in one step.
        * General convolution -- anything else: Im2Col builds the column
          tensor first.

        In every case the column tensor is then contracted with the weight
        by a single Tensordot, and the result replaces the original output.

        Args:
            graph: Graph to rewrite; operators are replaced in place.

        Returns:
            ``(graph, flag_changed)`` where ``flag_changed`` is True when at
            least one Convolution2D was replaced.
        """
        flag_changed = False
        for op in traverse.filter_nodes(traverse.listup_operators(graph),
                                        Convolution2D):  # type: Convolution2D
            x = op.inputs["x"]
            w = op.inputs["w"]
            y = op.outputs["y"]
            flag_changed = True
            # x, w and y are captured above; the operator's parameters
            # (WH, stride, padding, ...) are still read after this call.
            op.remove_all()

            # Fresh axes used to pair the weight's (kh, kw, in-channel) dims
            # with the matching dims of the column tensor built below.
            a_filter, a_kh, a_kw = Axis(), Axis(), Axis()
            w, = ReinterpretAxis(None,
                                 in_order=OrderNHWC,
                                 out_order=Order(
                                     [Axis.C, a_kh, a_kw, a_filter]))(w)
            w_axes = [a_kh, a_kw, a_filter]

            if op.WH == 1 and op.WW == 1 and op.stride == (
                    1, 1) and op.padding == (0, 0):
                # Projection: only the input's channel axis is contracted
                # against the weight's (kh, kw, in-channel) axes.
                col, = ReinterpretAxis(
                    None,
                    in_order=OrderNHWC,
                    out_order=Order([Axis.N, Axis.H, Axis.W, a_filter]))(x)
                col_axes = [a_filter]

            elif (op.WH == x.shape_dict[Axis.H]
                  and op.WW == x.shape_dict[Axis.W]
                  and op.padding == (0, 0)
                  and op.dilation_rate == (1, 1)):
                # Global convolution: the input's H/W axes are reinterpreted
                # as the kernel axes. That pairing is one-to-one only without
                # dilation; dilated convolutions fall through to Im2Col,
                # which receives dilation_rate explicitly.
                col, = ReinterpretAxis(None,
                                       in_order=OrderNHWC,
                                       out_order=Order(
                                           [Axis.N, a_kh, a_kw, a_filter]))(x)
                col_axes = [a_kh, a_kw, a_filter]

            else:
                # General convolution: Im2Col gathers each window into the
                # channel axis, which is then contracted with the weight.
                col, = Im2Col(None,
                              ksize=op.ksize,
                              stride=op.stride,
                              padding=op.padding,
                              dilation_rate=op.dilation_rate)(x)
                col, = ReinterpretAxis(
                    None,
                    in_order=OrderNHWC,
                    out_order=Order([Axis.N, Axis.H, Axis.W, a_filter]))(col)
                col_axes = [a_filter]

            # Tensordot(None, [axes_of_col, axes_of_w]): exactly one flat list
            # of Axis objects per operand, identical in form for all branches.
            new_y, = Tensordot(None, [col_axes, w_axes])(col, w)

            # NOTE(review): in the global branch the H/W axes were contracted
            # away, so this relies on transpose() re-introducing size-1 axes
            # from y.order -- confirm against Variable.transpose.
            new_y = new_y.transpose(y.order)
            OptimizeRule.replace_variable(graph, new_y, y)

        return graph, flag_changed
Exemplo n.º 2
0
    def optimize(self, graph: Graph) -> Tuple[Graph, bool]:
        """Rewrite every Deconvolution2D as a Tensordot followed by Col2Im.

        For each deconvolution the input's channel axis is contracted with the
        weight's input-channel axis, the product is laid out as
        (N, H, W, kh * kw * C) columns, and Col2Im scatters those columns back
        into an image using the operator's ksize/stride/padding.

        Args:
            graph: Graph to rewrite; operators are replaced in place.

        Returns:
            ``(graph, changed)`` where ``changed`` is True when at least one
            Deconvolution2D was replaced.
        """
        changed = False
        targets = traverse.filter_nodes(traverse.listup_operators(graph),
                                        Deconvolution2D)
        for deconv in targets:  # type: Deconvolution2D
            src, kernel = deconv.inputs["x"], deconv.inputs["w"]
            dst = deconv.outputs["y"]
            changed = True
            deconv.remove_all()

            # Anonymous axes shared by the reinterpreted weight and input so
            # Tensordot can match the input-channel dimension on both sides.
            ax_in, ax_kh, ax_kw = Axis(), Axis(), Axis()
            kernel_order = Order([Axis.C, ax_kh, ax_kw, ax_in])
            kernel, = ReinterpretAxis(None, in_order=OrderNHWC,
                                      out_order=kernel_order)(kernel)
            src_order = Order([Axis.N, Axis.H, Axis.W, ax_in])
            src, = ReinterpretAxis(None, in_order=OrderNHWC,
                                   out_order=src_order)(src)

            # After contracting ax_in, the transpose below lists the axes the
            # product is expected to carry: N, H, W from the input and
            # kh, kw, C from the weight.
            patches, = Tensordot(None, axes=ax_in)(src, kernel)
            patches = patches.transpose(
                Order([Axis.N, Axis.H, Axis.W, ax_kh, ax_kw, Axis.C]))
            leading = patches.shape[0:3]
            window_size = mul(patches.shape[3:6])
            patches = patches.reshape(shape=[*leading, window_size],
                                      order=OrderNHWC)

            image, = Col2Im(None, ksize=deconv.ksize, stride=deconv.stride,
                            padding=deconv.padding)(patches)
            OptimizeRule.replace_variable(graph, image.transpose_like(dst), dst)

        return graph, changed