コード例 #1
0
ファイル: split_variable.py プロジェクト: TarrySingh/webdnn
def _listup_splittable_axis(v: Variable, op: Operator) -> List[Axis]:
    """List the axes of `v` that are reported as splittable with respect to `op`.

    The answer depends on the type of `op`:

    - `Concat`, `SplitAxis`: every axis of `v`.
    - `Reshape`: every axis of `v` whose trailing size (the axis itself and
      all axes after it in `v.order`) matches the accumulated size of some
      trailing run of axes of the variable on the other side of the reshape.
    - `Im2Col`: for an output variable, `N`, `H`, `W`, plus `C` when the size
      of `C` is a multiple of `ksize[0] * ksize[1]`; nothing for any other
      variable.
    - `PartialIm2Col`: nothing for an output variable, `[op.axis]` otherwise.
    - `Sgemm`, `Tensordot`: nothing for output `"C"`, every axis of `v`
      otherwise.
    - any other operator: the axes of its `Tensorwise` attributes.

    Args:
        v: The variable to inspect.
        op: The operator `v` is attached to.

    Returns:
        The splittable axes. Each axis appears at most once.
    """
    if isinstance(op, (Concat, SplitAxis)):
        return list(v.order.axes)

    if isinstance(op, Reshape):
        # For more detail of this condition check, please see the comment
        # document of `_split_reshape`.
        splittable_axes = []  # type: List[Axis]
        v1 = v
        # v2 is the variable on the opposite side of the reshape from v.
        v2 = op.outputs["y"] if v == op.inputs["x"] else op.inputs["x"]

        for a1 in v1.order.axes:
            # NOTE(review): `mul` is presumably the product of the given
            # sizes -- confirm against its definition.
            d1 = mul(v1.shape[v1.order.axes_dict[a1]:])
            d2 = 1
            for a2 in reversed(v2.order.axes):
                d2 *= v2.shape_dict[a2]

                if d2 == d1:
                    splittable_axes.append(a1)
                    # Stop at the first match: further axes of size 1 keep
                    # d2 equal to d1 and would append a1 again.
                    break

                if d2 > d1:
                    # Assuming positive axis sizes, d2 never shrinks, so no
                    # later axis can produce a match.
                    break

        return splittable_axes

    if isinstance(op, Im2Col):
        if v not in op.outputs.values():
            return []

        if v.shape_dict[Axis.C] % (op.ksize[0] * op.ksize[1]) == 0:
            return [Axis.N, Axis.H, Axis.W, Axis.C]

        return [Axis.N, Axis.H, Axis.W]

    if isinstance(op, PartialIm2Col):
        if v in op.outputs.values():
            return []

        return [op.axis]

    if isinstance(op, (Sgemm, Tensordot)):
        # Both operators name their output "C"; only the inputs are reported
        # as splittable.
        if v == op.outputs["C"]:
            return []

        return list(v.order.axes)

    return [attr.axis for attr in op.get_attribute(Tensorwise)]
コード例 #2
0
    def __init__(self, op: Operator):
        """Set up the expression delegate for `op`.

        `self.delegate` starts out as the identity function. `self.has_inline`
        stores the result of `traverse.check_attribute_match(op, PostInlineInplace)`.
        When that result is truthy and the first `PostInlineInplace` attribute
        of `op` has a non-None `injected` member, the delegate is replaced by
        that member's `injector`.
        """
        self.delegate = lambda exp: exp  # type: Callable[[str], str]
        self.has_inline = traverse.check_attribute_match(op, PostInlineInplace)
        if not self.has_inline:
            return

        # Only the first PostInlineInplace attribute is consulted.
        inline_attr = op.get_attribute(PostInlineInplace)[0]  # type: PostInlineInplace
        injected = inline_attr.injected
        if injected is not None:
            self.delegate = injected.injector
コード例 #3
0
ファイル: tensorwise.py プロジェクト: zhangaz1/webdnn
 def check_splittable(op: Operator, axis: Axis):
     """Return True when any Tensorwise attribute of `op` has the given `axis`."""
     for attr in op.get_attribute(Tensorwise):
         if attr.axis == axis:
             return True
     return False
コード例 #4
0
ファイル: tensorwise.py プロジェクト: fossabot/hash2face
 def check_splittable(op: Operator, axis: Axis):
     """Tell whether `axis` is the axis of one of `op`'s Tensorwise attributes."""
     tensorwise_attrs = op.get_attribute(Tensorwise)
     for tensorwise in tensorwise_attrs:
         if tensorwise.axis == axis:
             return True
     return False
コード例 #5
0
 def check_splittable(cls, op: Operator, axis: Axis):
     """Return True when `op` has an attribute of type `cls` whose axis equals `axis`."""
     for candidate in op.get_attribute(cls):
         if candidate.axis == axis:
             return True
     return False