def flatten_refinement_rule(n: Node):
    """
    Generates equality constraints between the dimensions of the input and
    output that are not involved in the flatten operation.

    ``flatten(input, start_dim, end_dim)`` merges the input dimensions
    ``start_dim..end_dim`` (inclusive) into a single output dimension.
    Input dimensions before ``start_dim`` keep their index in the output.
    Input dimensions after ``end_dim`` move to output index ``start_dim + 1``
    onwards, so they must be paired with shifted output indices.

    Args:
        n: the flatten node. ``n.args[0]`` is the input node; optional
            positional ``n.args[1]`` / ``n.args[2]`` are ``start_dim`` /
            ``end_dim`` (defaults 1 and -1).

    Returns:
        A list of ``Equality(output_dim, input_dim)`` constraints, or an empty
        list when the input or output type is not a ``TensorType``.
    """
    assert isinstance(n.args[0], Node)

    # NOTE(review): keyword-form start_dim/end_dim (n.kwargs) are not
    # consulted, same as before; confirm this agrees with the flatten
    # type-inference rule that computed n.type.
    start_dim = 1
    end_dim = -1
    if len(n.args) > 1:
        assert isinstance(n.args[1], int)
        start_dim = n.args[1]
    if len(n.args) > 2:
        assert isinstance(n.args[2], int)
        end_dim = n.args[2]

    arg_type = n.args[0].type
    if not (isinstance(n.type, TensorType) and isinstance(arg_type, TensorType)):
        return []

    in_dims = arg_type.__args__
    out_dims = n.type.__args__

    # Negative dims count from the end of the *input*, because that is the
    # tensor flatten's arguments refer to (not the lower-rank output).
    rank = len(in_dims)
    start = start_dim + rank if start_dim < 0 else start_dim
    end = end_dim + rank if end_dim < 0 else end_dim

    # Flattening a single dimension is the identity on it, so that dimension
    # can be equated as well.
    prefix_len = start + 1 if start == end else start

    eq_const = [Equality(t1, t2)
                for t1, t2 in zip(out_dims[:prefix_len], in_dims[:prefix_len])]
    # Input dims after the merged range land right after the merged output dim.
    eq_const += [Equality(t1, t2)
                 for t1, t2 in zip(out_dims[start + 1:], in_dims[end + 1:])]
    return eq_const
def test_symbolic_add_with_broadcast(self):
    """
    Refinement and symbolic inference for ``torch.add`` with broadcasting.

    ``x`` is annotated (1, 2, 3, Dyn) and ``y`` is annotated (2, 3, 4). Only the
    dimension pairs that are statically equal produce equality constraints.
    """
    class M(torch.nn.Module):
        def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((2, 3, 4))):
            return torch.add(x, y)

    module = M()
    symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
    tc = GraphTypeChecker({}, symbolic_traced)
    tc.type_check()
    infer_symbolic_types(symbolic_traced)

    r = Refine(symbolic_traced)
    r.refine()
    # note that there is no equality constraint between dyn and 4 because
    # dyn could be 4 or 1
    assert r.constraints == [Equality(1, 1), Equality(2, 2), Equality(3, 3)]

    infer_symbolic_types(symbolic_traced)

    # Expected type of every node in graph order: x, y, the add, the output.
    expected_types = [
        TensorType((1, 2, 3, sympy.symbols('~0'))),
        TensorType((2, 3, 4)),
        TensorType((1, 2, 3, sympy.symbols('~1'))),
        TensorType((1, 2, 3, sympy.symbols('~1'))),
    ]
    # Compare whole lists so a missing or extra node fails the assertion
    # instead of passing silently or escaping as StopIteration.
    assert [n.type for n in symbolic_traced.graph.nodes] == expected_types
def flatten_refinement_rule(n: Node):
    """
    Generates equality constraints between the dimensions of the input and
    output that are not involved in the flatten operation.

    ``flatten(input, start_dim, end_dim)`` merges the input dimensions
    ``start_dim..end_dim`` (inclusive) into a single output dimension.
    Input dimensions before ``start_dim`` keep their index in the output.
    Input dimensions after ``end_dim`` move to output index ``start_dim + 1``
    onwards, so they must be paired with shifted output indices.

    Args:
        n: the flatten node. ``n.args[0]`` is the input node; optional
            positional ``n.args[1]`` / ``n.args[2]`` are ``start_dim`` /
            ``end_dim`` (defaults 1 and -1).

    Returns:
        A list of ``Equality(output_dim, input_dim)`` constraints, or an empty
        list when the input or output type is not a ``TensorType``.
    """
    assert isinstance(n.args[0], Node)

    # NOTE(review): keyword-form start_dim/end_dim (n.kwargs) are not
    # consulted, same as before; confirm this agrees with the flatten
    # type-inference rule that computed n.type.
    start_dim = 1
    end_dim = -1
    if len(n.args) > 1:
        assert isinstance(n.args[1], int)
        start_dim = n.args[1]
    if len(n.args) > 2:
        assert isinstance(n.args[2], int)
        end_dim = n.args[2]

    arg_type = n.args[0].type
    if not (isinstance(n.type, TensorType) and isinstance(arg_type, TensorType)):
        return []

    in_dims = arg_type.__args__
    out_dims = n.type.__args__

    # Negative dims count from the end of the *input*, because that is the
    # tensor flatten's arguments refer to (not the lower-rank output).
    rank = len(in_dims)
    start = start_dim + rank if start_dim < 0 else start_dim
    end = end_dim + rank if end_dim < 0 else end_dim

    # Flattening a single dimension is the identity on it, so that dimension
    # can be equated as well.
    prefix_len = start + 1 if start == end else start

    eq_const = [Equality(t1, t2)
                for t1, t2 in zip(out_dims[:prefix_len], in_dims[:prefix_len])]
    # Input dims after the merged range land right after the merged output dim.
    eq_const += [Equality(t1, t2)
                 for t1, t2 in zip(out_dims[start + 1:], in_dims[end + 1:])]
    return eq_const
def first_two(n: Node):
    """
    Equate the leading two dimensions of the input tensor type with the
    leading two dimensions of the output tensor type.

    Returns an empty list unless both the input and output types are
    ``TensorType``.
    """
    assert isinstance(n.args[0], Node)
    in_type = n.args[0].type
    out_type = n.type
    if not (isinstance(in_type, TensorType) and isinstance(out_type, TensorType)):
        return []
    in_dims, out_dims = in_type.__args__, out_type.__args__
    return [Equality(in_dims[i], out_dims[i]) for i in (0, 1)]
def first_two_eq(n: Node):
    """
    Constraint rule for operations whose input and output shapes agree in
    their first two dimensions.

    Each of the two leading input dimensions is equated with the output
    dimension at the same position. Nothing is generated when either side
    lacks a ``TensorType``.
    """
    assert isinstance(n.args[0], Node)
    source = n.args[0].type
    both_tensors = isinstance(source, TensorType) and isinstance(n.type, TensorType)
    if both_tensors:
        lhs = source.__args__
        rhs = n.type.__args__
        return [Equality(lhs[k], rhs[k]) for k in range(2)]
    return []
def element_wise_eq(n: Node):
    """
    For element-wise operations, handling broadcasting.

    After broadcasting the two operand types against each other, a dimension
    position where both operands are symbolically equal cannot have been
    broadcast. In that case we can establish equality between that dimension
    and the corresponding output dimension.

    Note that it takes two iterations for this result. One iteration
    establishes equality between certain dimensions of the operands, which
    requires the whole solver including unification. Another iteration
    establishes equality between the operands and the resulting type, which
    requires another round of constraint generation and unification.

    Returns:
        A list of ``Equality(operand_dim, output_dim)`` constraints. It is empty
        when the node lacks two positional ``Node`` operands, or when any
        involved type is not a ``TensorType``.
    """
    # An operand passed by keyword (e.g. torch.add(x, other=y)) leaves fewer
    # than two positional args; indexing n.args[1] would raise IndexError.
    if len(n.args) < 2:
        return []
    lhs, rhs = n.args[0], n.args[1]
    if not (isinstance(lhs, Node) and isinstance(rhs, Node)):
        return []

    arg_type1 = lhs.type
    arg_type2 = rhs.type
    if not (isinstance(arg_type1, TensorType)
            and isinstance(arg_type2, TensorType)
            and isinstance(n.type, TensorType)):
        return []

    args1, args2 = broadcast_types(arg_type1, arg_type2)
    # by this point, we know that args1 and args2 are the same size.
    a1 = args1.__args__
    a2 = args2.__args__
    a3 = n.type.__args__
    # Only positions where both broadcast operands agree are known not to
    # have been broadcast, so only those are tied to the output dimension.
    return [Equality(x, z) for x, y, z in zip(a1, a2, a3) if x == y]
def linear_refinement_rule(n: Node):
    """
    Equate the first dimension of the input tensor type with the first
    dimension of the output tensor type.

    Produces no constraint unless both types are ``TensorType``.
    """
    assert isinstance(n.args[0], Node)
    src = n.args[0].type
    dst = n.type
    if isinstance(src, TensorType) and isinstance(dst, TensorType):
        return [Equality(src.__args__[0], dst.__args__[0])]
    return []
def first_one(n: Node):
    """
    Constrain the first dimension of the input tensor type to equal the
    first dimension of the output tensor type.

    Returns an empty list when either type is not a ``TensorType``.
    """
    assert isinstance(n.args[0], Node)
    in_type = n.args[0].type
    if not isinstance(in_type, TensorType) or not isinstance(n.type, TensorType):
        return []
    lead_in = in_type.__args__[0]
    lead_out = n.type.__args__[0]
    return [Equality(lead_in, lead_out)]
def all_eq(n: Node):
    """
    For operations where the input shape is equal to the output shape.

    Each input dimension is equated with the output dimension at the same
    position. Pairing is done with ``zip`` so that mismatched ranks produce
    constraints for the common prefix instead of raising ``IndexError``.

    Returns:
        A list of ``Equality(input_dim, output_dim)`` constraints, or an empty
        list when either type is not a ``TensorType``.
    """
    assert isinstance(n.args[0], Node)
    arg_type = n.args[0].type
    if not (isinstance(arg_type, TensorType) and isinstance(n.type, TensorType)):
        return []
    return [Equality(a, b) for a, b in zip(arg_type.__args__, n.type.__args__)]
def linear_refinement_rule(n: Node):
    """
    Refinement rule relating only the leading dimension of input and output.

    When both the input and output types are ``TensorType``, a single
    constraint equating their first dimensions is returned. Otherwise the
    result is empty.
    """
    assert isinstance(n.args[0], Node)
    input_type, output_type = n.args[0].type, n.type
    if not all(isinstance(t, TensorType) for t in (input_type, output_type)):
        return []
    constraint = Equality(input_type.__args__[0], output_type.__args__[0])
    return [constraint]
def all_eq(n: Node):
    """
    For operations where the input shape is equal to the output shape.

    Each input dimension is equated with the output dimension at the same
    position. Pairing is done with ``zip`` so that mismatched ranks produce
    constraints for the common prefix instead of raising ``IndexError``.

    Returns:
        A list of ``Equality(input_dim, output_dim)`` constraints, or an empty
        list when either type is not a ``TensorType``.
    """
    assert isinstance(n.args[0], Node)
    arg_type = n.args[0].type
    if not (isinstance(arg_type, TensorType) and isinstance(n.type, TensorType)):
        return []
    return [Equality(a, b) for a, b in zip(arg_type.__args__, n.type.__args__)]
def test_type_check_flatten3(self):
    """
    Flattening dims 1..3 of a (2, 3, 4, 5) input is typed as (2, 60), and
    refinement yields exactly one constraint, ``Equality(2, 2)``.
    """
    class M(torch.nn.Module):
        def forward(self, x: TensorType((2, 3, 4, 5))):
            return torch.flatten(x, start_dim=1, end_dim=3)

    traced: torch.fx.GraphModule = symbolic_trace(M())
    checker = GraphTypeChecker({}, traced)
    checker.type_check()

    output_nodes = [node for node in traced.graph.nodes if node.op == 'output']
    for node in output_nodes:
        assert node.type == TensorType((2, 60))

    refiner = Refine(traced)
    refiner.refine()
    assert refiner.constraints == [Equality(2, 2)]
def add_eq(n: Node):
    """
    Refinement rule for broadcasting addition.

    The two operand types are broadcast against each other. Every position
    where the broadcast operands are equal is known not to have been
    broadcast, so the operand dimension there is equated with the output
    dimension at the same position.

    Returns:
        A list of ``Equality(operand_dim, output_dim)`` constraints. It is empty
        when the node lacks two positional ``Node`` operands, or when any
        involved type is not a ``TensorType``.
    """
    # An operand passed by keyword (e.g. torch.add(x, other=y)) leaves fewer
    # than two positional args; indexing n.args[1] would raise IndexError.
    if len(n.args) < 2:
        return []
    lhs, rhs = n.args[0], n.args[1]
    if not (isinstance(lhs, Node) and isinstance(rhs, Node)):
        return []

    arg_type1 = lhs.type
    arg_type2 = rhs.type
    if not (isinstance(arg_type1, TensorType)
            and isinstance(arg_type2, TensorType)
            and isinstance(n.type, TensorType)):
        return []

    args1, args2 = broadcast_types(arg_type1, arg_type2)
    # by this point, we know for sure that args1 and args2 are the same size.
    a1 = args1.__args__
    a2 = args2.__args__
    a3 = n.type.__args__
    return [Equality(x, z) for x, y, z in zip(a1, a2, a3) if x == y]