    def test_symbolic_add_with_broadcast(self):
        class M(torch.nn.Module):
            def forward(self, x: TensorType((1, 2, 3, Dyn)),
                        y: TensorType((2, 3, 4))):
                return torch.add(x, y)

        module = M()
        symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
        tc = GraphTypeChecker({}, symbolic_traced)
        tc.type_check()
        infer_symbolic_types(symbolic_traced)
        r = Refine(symbolic_traced)
        r.refine()

        assert r.constraints == [
            Equality(1, 1), Equality(2, 2),
            Equality(3, 3)
        ]
        # Note that there is no equality constraint between Dyn and 4,
        # because Dyn could be either 4 or 1

        # A second inference pass propagates the equalities established by
        # refinement (see the docstring of element_wise_eq below)
        infer_symbolic_types(symbolic_traced)

        expected_ph_types = [
            TensorType((1, 2, 3, sympy.symbols('~0'))),
            TensorType((2, 3, 4)),
            TensorType((1, 2, 3, sympy.symbols('~1'))),
            TensorType((1, 2, 3, sympy.symbols('~1')))
        ]
        expected_iter = iter(expected_ph_types)

        for n in symbolic_traced.graph.nodes:
            assert n.type == next(expected_iter)
def flatten_refinement_rule(n: Node):
    """
    Generates equality constraints between the dimensions of the input and output
    that will not be involved in the flatten operation
    """
    assert isinstance(n.args[0], Node)

    eq_const = []

    start_dim = 1
    end_dim = -1

    if len(n.args) > 1:
        assert isinstance(n.args[1], int)
        start_dim = n.args[1]

    if len(n.args) > 2:
        assert isinstance(n.args[2], int)
        end_dim = n.args[2]

    if isinstance(n.type, TensorType) and isinstance(n.args[0].type,
                                                     TensorType):
        l = len(n.type.__args__)
        arg_type = n.args[0].type
        # Resolve negative indices; end_dim becomes an exclusive upper bound
        start_dim = l if start_dim == -1 else start_dim
        end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1

        # Dimensions before start_dim are untouched by flatten, so the
        # input and output must agree on them
        for t1, t2 in zip(n.type.__args__[0:start_dim],
                          arg_type.__args__[0:start_dim]):
            eq_const.append(Equality(t1, t2))

        # Likewise for dimensions at or after end_dim
        for t1, t2 in zip(n.type.__args__[end_dim:],
                          arg_type.__args__[end_dim:]):
            eq_const.append(Equality(t1, t2))
    return eq_const
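# A worked trace (not from the original source) of the rule above. Suppose
# the node flattens an input of type TensorType((2, 3, 4, 5)) into an output
# of type TensorType((2, 60)), with start_dim=1 and end_dim=3 supplied as
# positional args:
#
#   l = len(output dims) = 2; end_dim=3 is non-negative, so it becomes the
#   exclusive bound 4
#   leading:  zip(output[0:1], input[0:1]) -> Equality(2, 2)
#   trailing: zip(output[4:],  input[4:])  -> both empty, no constraints
#
# so the rule returns [Equality(2, 2)], matching test_type_check_flatten3
# below.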
def first_two_eq(n: Node):
    """
    For operations where the first two dimensions of the input and output shape
    are equal
    """
    res = []
    assert isinstance(n.args[0], Node)
    arg_type = n.args[0].type
    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
        args1 = arg_type.__args__
        args2 = n.type.__args__
        res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])]
    return res
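# A small illustrative sketch (not from the original source): for an input of
# type TensorType((2, 3, 4)) and an output of type TensorType((2, 3, 10)),
# first_two_eq returns [Equality(2, 2), Equality(3, 3)]; any remaining
# dimensions are left unconstrained.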
def element_wise_eq(n: Node):
    """
    For element-wise operations and handles broadcasting.
    Note that after applying broadcasting to the arguments
    we are able to determine if certain dimensions have not been broadcast
    if they are symbolicallu equal.

    in this case, we can establish equality between those dimensions and the
    corresponding output dimensions.

    Note that it takes two iterations for this result. One iteration to establish
    equality between certain dimensions of the operands (requiring the whole solver
    including unification) and another iteration to establish equality between the operands
    and the resulting type, requiring another round of constraint generation and unificaiton.
    """
    res = []
    if isinstance(n.args[0], Node) and isinstance(n.args[1], Node):
        arg_type1 = n.args[0].type
        arg_type2 = n.args[1].type
        if isinstance(arg_type1, TensorType) and isinstance(
                arg_type2, TensorType) and isinstance(n.type, TensorType):
            args1, args2 = broadcast_types(arg_type1, arg_type2)
            # by this point, we know that args1 and args2 are the same size.
            a1 = args1.__args__
            a2 = args2.__args__
            a3 = n.type.__args__

            # we would be here in the second iteration where we establish equality
            # between operand type dimensions and the resulting type dimensions
            r = []
            for x, y, z in zip(a1, a2, a3):
                if x == y:
                    r.append(Equality(x, z))
            res = r
    return res
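# A worked sketch (not from the original source), using the shapes from
# test_symbolic_add_with_broadcast above: adding x of type
# TensorType((1, 2, 3, Dyn)) and y of type TensorType((2, 3, 4)).
# broadcast_types brings both types to a common rank, after which the first
# three dimension pairs agree (1, 2 and 3), so the rule ties each of them to
# the corresponding output dimension. The final pair does not agree -- Dyn
# could stand for 4 or for 1 -- so no constraint is generated for it, which
# is why that test asserts exactly
# [Equality(1, 1), Equality(2, 2), Equality(3, 3)].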
def linear_refinement_rule(n: Node):
    """
    The equality constraints are between the first dimension of
    the input and output
    """
    res = []
    assert isinstance(n.args[0], Node)
    arg_type = n.args[0].type
    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
        res = [Equality(arg_type.__args__[0], n.type.__args__[0])]
    return res
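# A small illustrative sketch (not from the original source): for an input of
# type TensorType((20, 10)) and an output of type TensorType((20, 5)),
# linear_refinement_rule returns [Equality(20, 20)]; the feature dimension is
# rewritten by the linear layer, so it is left unconstrained.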
def all_eq(n: Node):
    """
    For operations where the input shape is equal to the output shape
    """
    res = []
    assert isinstance(n.args[0], Node)
    arg_type = n.args[0].type
    if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType):
        args1 = arg_type.__args__
        args2 = n.type.__args__
        res = [Equality(args1[i], args2[i]) for i in range(len(args1))]
    return res
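# A small illustrative sketch (not from the original source): for a
# shape-preserving operation with input and output type TensorType((2, 3, 4)),
# all_eq returns [Equality(2, 2), Equality(3, 3), Equality(4, 4)], one
# constraint per dimension.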
    def test_type_check_flatten3(self):
        class M(torch.nn.Module):
            def forward(self, x: TensorType((2, 3, 4, 5))):
                return torch.flatten(x, start_dim=1, end_dim=3)

        module = M()
        symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
        tc = GraphTypeChecker({}, symbolic_traced)
        tc.type_check()
        for n in symbolic_traced.graph.nodes:
            if n.op == 'output':
                assert n.type == TensorType((2, 60))
        r = Refine(symbolic_traced)
        r.refine()
        c = r.constraints
        assert c == [Equality(2, 2)]
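# The constraint list above is consistent with flatten_refinement_rule:
# dimensions 1 through 3 of (2, 3, 4, 5) are collapsed into 3 * 4 * 5 = 60,
# so only the untouched leading dimension yields a constraint, Equality(2, 2).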