def gen_embedding_rules(n: Node, symbols, embedding_dim, counter):
    """
    Generate the constraints for an embedding lookup.

    Either the input and the output are both Dyn, or the input is a tensor of
    rank ``r`` with ``1 <= r < MAX_TENSOR_RANK`` and the output is the same
    shape with ``embedding_dim`` appended as the last dimension. The input rank
    stops one short of MAX_TENSOR_RANK, which keeps the output rank (input
    rank + 1) at most MAX_TENSOR_RANK.

    Args:
        n: the node being typed; a fresh output variable is stored in symbols[n]
        symbols: mapping from nodes to their type variables
        embedding_dim: size of the appended last dimension
        counter: variable counter used to generate fresh variables

    Returns:
        A single-element list with the disjunction of all cases, and the
        updated counter.
    """
    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var
    in_var = symbols[n.args[0]]

    both_dyn = Conj([BinConstraintT(in_var, Dyn, op_eq),
                     BinConstraintT(out_var, Dyn, op_eq)])

    tensor_cases = []
    for rank in range(1, MAX_TENSOR_RANK):
        dims, counter = gen_tensor_dims(rank, counter)
        shape_eqs = [
            BinConstraintT(in_var, TensorType(dims), op_eq),
            # embedding_dim is appended to the end of the output shape
            BinConstraintT(out_var, TensorType(dims + [embedding_dim]), op_eq),
        ]
        tensor_cases.append(Conj(shape_eqs + gen_nat_constraints(dims)))

    return [Disj([both_dyn, Disj(tensor_cases)])], counter
def broadcast_types(t1, t2):
    """
    Apply broadcasting rules to two types.

    If either operand is Dyn or a type variable (Var), both are returned
    unchanged. For two TensorTypes, the lower-rank one is left-padded with 1s
    to the same rank. Then every dimension equal to 1 is replaced by the
    corresponding dimension of the other type.

    Returns:
        A tuple ``(t1', t2')`` of the broadcast types.

    Raises:
        TypeError: if the operands are not Dyn/Var and not both TensorTypes.
    """
    if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var):
        return t1, t2

    if not (isinstance(t1, TensorType) and isinstance(t2, TensorType)):
        raise TypeError(f'Cannot broadcast types {t1} and {t2}')

    dims1 = list(t1.__args__)
    dims2 = list(t2.__args__)

    # make the tensors the same length by left-padding with 1s
    gap = len(dims1) - len(dims2)
    if gap > 0:
        dims2 = [1] * gap + dims2
    elif gap < 0:
        dims1 = [1] * (-gap) + dims1

    for idx in range(len(dims1)):
        a, b = dims1[idx], dims2[idx]
        if a == 1:
            dims1[idx] = b
        elif b == 1:
            dims2[idx] = a

    return TensorType(tuple(dims1)), TensorType(tuple(dims2))
def test_type_check_conv2D(self):
    """
    A block whose conv output is annotated as (2, 2, Dyn, 4) and whose input
    is Dyn: after type checking, the call_module node keeps the annotated type
    and placeholder / call_function / output nodes are (Dyn, Dyn, Dyn, Dyn).
    """
    class BasicBlock(torch.nn.Module):
        def __init__(self, inplanes, planes, stride=1, norm_layer=None):
            super(BasicBlock, self).__init__()
            if norm_layer is None:
                norm_layer = torch.nn.BatchNorm2d
            self.conv1 = conv3x3(inplanes, planes, stride)
            self.bn1 = norm_layer(planes)

        def forward(self, x: Dyn):
            identity = x
            out: TensorType((2, 2, Dyn, 4)) = self.conv1(x)
            out += identity
            return out

    block = BasicBlock(2, 2)
    tracer = RewritingTracer()
    graph = tracer.trace(block)
    traced = GraphModule(tracer.root, graph, "gm")
    checker = GraphTypeChecker({}, traced)
    checker.type_check()

    fully_dyn = TensorType((Dyn, Dyn, Dyn, Dyn))
    expected_by_op = {
        'placeholder': fully_dyn,
        'call_function': fully_dyn,
        'output': fully_dyn,
        'call_module': TensorType((2, 2, Dyn, 4)),
    }
    for node in graph.nodes:
        if node.op in expected_by_op:
            assert node.type == expected_by_op[node.op]
def equality_inference_rule(n: Node, symbols, constraints, counter):
    """
    We generate the constraint: input = output.

    Handled argument forms:
      * ``n.args[0]`` is a node mapped to a tensor variable (TVar): the output
        is constrained to equal it;
      * ``n.args[0]`` is a node not mapped to a TVar: every argument must map
        to a dimension variable (DVar), and their symbols form the output size;
      * ``n.args[0]`` is a tuple of at most 4 entries: the symbols of its
        entries form the output size.

    Raises:
        NotImplementedError: for any other kind of first argument.
    """
    output, counter = gen_tvar(counter)
    symbols[n] = output
    first = n.args[0]

    if isinstance(first, Node):
        input_var = symbols[first]
        if isinstance(input_var, TVar):
            return [BinConstraintT(input_var, output, op_eq)], counter
        # then we have dimension variables
        for arg in n.args:
            assert isinstance(symbols[arg], DVar)
        size_args = n.args
    elif isinstance(first, tuple):
        # then the tuple is the size
        assert len(first) <= 4
        size_args = first
    else:
        raise NotImplementedError('Method not yet implemented')

    size = [symbols[arg] for arg in size_args]
    return [BinConstraintT(output, TensorType(size), op_eq)], counter
def test_resnet50(self):
    """
    Compare statically inferred types for resnet50 with the shapes ShapeProp
    records on a concrete (1, 3, 224, 224) input. First check consistency
    with fully dynamic types, then exact equality with an annotated input.
    """
    gm_run = symbolic_trace(resnet50())
    sample_input = torch.randn(1, 3, 224, 224)
    # run our nodes to record the runtime shapes
    ShapeProp(gm_run).propagate(sample_input)

    # fully dynamic: clear all types, then check consistency with runtime shapes
    gm_static = symbolic_trace(resnet50())
    for node in gm_static.graph.nodes:
        node.type = None
    GraphTypeChecker({}, gm_static).type_check()
    for static_node, run_node in zip(gm_static.graph.nodes, gm_run.graph.nodes):
        runtime_type = TensorType(run_node.meta['tensor_meta'].shape)
        assert is_consistent(static_node.type, runtime_type)

    # give the placeholder the same input shape as at runtime
    gm_static_with_types = symbolic_trace(resnet50())
    for node in gm_static_with_types.graph.nodes:
        if node.op == 'placeholder':
            node.type = TensorType((1, 3, 224, 224))
    GraphTypeChecker({}, gm_static_with_types).type_check()
    for typed_node, run_node in zip(gm_static_with_types.graph.nodes, gm_run.graph.nodes):
        assert typed_node.type == TensorType(run_node.meta['tensor_meta'].shape)
def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Input and output sizes should be the same except for the last dimension.
    If the input is Dyn, then so should the output.

    For each rank 1..MAX_TENSOR_RANK, input and output are equated with fresh
    dimension lists of that rank. ``add_linear_constraints`` relates the two
    lists using the module instance.
    """
    assert isinstance(n.args[0], Node)
    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var
    in_var = symbols[n.args[0]]

    dyn_case = Conj([BinConstraintT(in_var, Dyn, op_eq),
                     BinConstraintT(out_var, Dyn, op_eq)])

    rank_cases = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        in_dims, counter = gen_tensor_dims(rank, counter)
        out_dims, counter = gen_tensor_dims(rank, counter)
        nat = gen_nat_constraints(in_dims + out_dims)
        shape_eqs = [BinConstraintT(in_var, TensorType(in_dims), op_eq),
                     BinConstraintT(out_var, TensorType(out_dims), op_eq)]
        linear_rel = add_linear_constraints(in_dims, out_dims, module_instance)
        rank_cases.append(Conj(shape_eqs + linear_rel + nat))

    return [Disj([dyn_case, Disj(rank_cases)])], counter
def broadcast_types(t1, t2):
    """
    Broadcast two tensor types against each other.

    Dyn operands are returned unchanged. Two TensorTypes may differ in rank by
    at most one, and neither may have rank 0. The lower-rank type is padded at
    the front with the leading dimension of the other type. Then every
    dimension equal to 1 is replaced by the corresponding dimension of the
    other type. At most one of the two types may change shape.

    Returns:
        A tuple of the two broadcast TensorTypes.

    Raises:
        TypeError: if the ranks cannot be reconciled, if both types would
            change shape, or if the operands are not both TensorTypes.
    """
    if t1 == Dyn or t2 == Dyn:
        return t1, t2

    if not (isinstance(t1, TensorType) and isinstance(t2, TensorType)):
        raise TypeError(f'Cannot broadcast types {t1} and {t2}')

    orig_t1 = list(t1.__args__)
    orig_t2 = list(t2.__args__)
    s1 = len(orig_t1)
    s2 = len(orig_t2)

    if abs(s1 - s2) > 1 or s1 == 0 or s2 == 0:
        raise TypeError(f'Cannot broadcast the tensors {t1} and {t2}')

    new_t1 = list(orig_t1)
    new_t2 = list(orig_t2)
    if s1 > s2:
        new_t2.insert(0, orig_t1[0])
    elif s2 > s1:
        new_t1.insert(0, orig_t2[0])

    for i, (x, y) in enumerate(zip(new_t1, new_t2)):
        if x == 1:
            new_t1[i] = y
        elif y == 1:
            new_t2[i] = x

    # Compare as lists: TensorType may store its dimensions as a list or a
    # tuple, and a tuple never equals a list, which made the previous
    # tuple-vs-__args__ comparison report a change even when there was none.
    if new_t1 != orig_t1 and new_t2 != orig_t2:
        raise TypeError('In-place operations cannot change shape')

    return TensorType(tuple(new_t1)), TensorType(tuple(new_t2))
def gen_consistency_constraints(constraint: Constraint, counter: int):
    """
    Args:
        constraint: Consistency constraint on tensors
        counter: for variable tracking

    Returns:
        A list with one conjunction per rank 1..MAX_TENSOR_RANK. Each one
        equates both sides to fresh dimension lists of that rank and requires
        the dimensions to be pairwise consistent. The updated counter is
        returned alongside.
    """
    all_constraints = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        lhs_dims, counter = gen_tensor_dims(rank, counter)
        rhs_dims, counter = gen_tensor_dims(rank, counter)
        nat_constraints = gen_nat_constraints(lhs_dims + rhs_dims)

        shape_eqs = [BinConstraintT(constraint.lhs, TensorType(lhs_dims), op_eq),
                     BinConstraintT(constraint.rhs, TensorType(rhs_dims), op_eq)]
        pairwise = [BinConstraintD(a, b, op_consistency)
                    for a, b in zip(lhs_dims, rhs_dims)]
        all_constraints.append(Conj(shape_eqs + pairwise + nat_constraints))

    return all_constraints, counter
def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    The output shape differs from the input shape in the last dimension.

    Either input and output are both Dyn, or the input has rank ``r`` with
    ``1 <= r < MAX_TENSOR_RANK`` and the output is the input shape with the
    module's ``embedding_dim`` appended.
    """
    assert isinstance(n.args[0], Node)
    embedding_dim = module_instance.embedding_dim  # number

    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var
    in_var = symbols[n.args[0]]

    dyn_case = Conj([BinConstraintT(in_var, Dyn, op_eq),
                     BinConstraintT(out_var, Dyn, op_eq)])

    rank_cases = []
    for rank in range(1, MAX_TENSOR_RANK):
        dims, counter = gen_tensor_dims(rank, counter)
        nat_constraints = gen_nat_constraints(dims)
        # append embedding_dim to the end of the output dimension in all cases
        in_eq = BinConstraintT(in_var, TensorType(dims), op_eq)
        out_eq = BinConstraintT(out_var, TensorType(dims + [embedding_dim]), op_eq)
        rank_cases.append(Conj([in_eq, out_eq] + nat_constraints))

    return [Disj([dyn_case, Disj(rank_cases)])], counter
def test_symbolic_add_with_broadcast(self):
    """
    Refine the types of a broadcast add and check both the generated equality
    constraints and the symbolic node types.
    """
    class M(torch.nn.Module):
        def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((2, 3, 4))):
            return torch.add(x, y)

    traced: torch.fx.GraphModule = symbolic_trace(M())
    GraphTypeChecker({}, traced).type_check()
    infer_symbolic_types(traced)

    refiner = Refine(traced)
    refiner.refine()
    # no equality constraint between Dyn and 4 is expected, because Dyn
    # could be 4 or 1
    assert refiner.constraints == [Equality(1, 1), Equality(2, 2), Equality(3, 3)]

    infer_symbolic_types(traced)

    expected_types = iter([
        TensorType((1, 2, 3, sympy.symbols('~0'))),
        TensorType((2, 3, 4)),
        TensorType((1, 2, 3, sympy.symbols('~1'))),
        TensorType((1, 2, 3, sympy.symbols('~1'))),
    ])
    for node in traced.graph.nodes:
        assert node.type == next(expected_types)
def cumsum_inference_rule(n: Node, symbols, constraints, counter):
    """
    Input and output shapes should be equal.
    We should verify that the index is valid.

    The dimension comes from the second positional argument or, if absent,
    from the ``dim`` keyword. For each rank 1..MAX_TENSOR_RANK, the input and
    output are equated with the same fresh dimension list, and
    ``range_check(dim, rank)`` is required.
    """
    assert isinstance(n.args[0], Node)
    dim = n.args[1] if len(n.args) > 1 else n.kwargs["dim"]
    assert isinstance(dim, int)

    output, counter = gen_tvar(counter)
    symbols[n] = output
    input_var = symbols[n.args[0]]

    both_dyn = Conj([BinConstraintT(input_var, Dyn, op_eq),
                     BinConstraintT(output, Dyn, op_eq)])

    per_rank = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        dims, counter = gen_tensor_dims(rank, counter)
        nat_constraints = gen_nat_constraints(dims)
        per_rank.append(Conj([BinConstraintT(input_var, TensorType(dims), op_eq),
                              BinConstraintT(output, TensorType(dims), op_eq),
                              range_check(dim, rank)] + nat_constraints))

    return [Disj([both_dyn, Disj(per_rank)])], counter
def expand_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate the same constraints as for tensor addition (broadcasting), but
    constrain the output rank to equal ``len(n.args[1:])``, so that only those
    cases are considered for the output.

    The size arguments may be ints or nodes. A node must map to a dimension
    variable, which is constrained to be non-negative.
    """
    assert isinstance(n.args[0], Node)

    # define the output for expand
    expand, counter = gen_tvar(counter)
    symbols[n] = expand

    # there is only one tensor operand, so build an auxiliary tensor variable
    # for the requested size and broadcast against it
    e1 = symbols[n.args[0]]
    e2, counter = gen_tvar(counter)

    size_args = n.args[1:]
    e2_nat_constraints = []
    for arg in size_args:
        assert isinstance(arg, (Node, int))
        if isinstance(arg, Node):
            assert isinstance(symbols[arg], DVar)
            e2_nat_constraints.append(BinConstraintD(0, symbols[arg], op_leq))

    e2_size = [arg if isinstance(arg, int) else symbols[arg] for arg in size_args]
    e2_constraint = BinConstraintT(e2, TensorType(e2_size), op_eq)

    constraints, counter = gen_broadcasting_constraints(e1, e2, symbols, counter, expand)

    # constrain the output size
    dims, counter = gen_tensor_dims(len(size_args), counter)
    nat_constraints = gen_nat_constraints(dims)
    constraints += [BinConstraintT(expand, TensorType(dims), op_eq),
                    *nat_constraints,
                    e2_constraint,
                    *e2_nat_constraints]
    return constraints, counter
def create_equality_constraints_for_broadcasting(e1: TVar,
                                                 e2: TVar,
                                                 e11: TVar,
                                                 e12: TVar,
                                                 d1: List[DVar],
                                                 d2: List[DVar],
                                                 d11: List[DVar],
                                                 d12: List[DVar]):
    """
    Create equality constraints for when no broadcasting occurs.

    Each tensor variable is equated with a TensorType built from its list of
    dimension variables.

    Args:
        e1: Input 1
        e2: Input 2
        e11: Broadcasted input 1
        e12: Broadcasted input 2
        d1: Variables that store dimensions for e1
        d2: Variables that store dimensions for e2
        d11: Variables that store dimensions for e11
        d12: Variables that store dimensions for e12

    Returns:
        Four equality constraints, in the order e1, e11, e2, e12
    """
    var_dims = [(e1, d1), (e11, d11), (e2, d2), (e12, d12)]
    return [BinConstraintT(var, TensorType(dims), op_eq) for var, dims in var_dims]
def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Input and output shapes should be equal.
    Input should be consistent with the normalized_shape.

    Either both are Dyn, or for some rank 1..MAX_TENSOR_RANK both equal the
    same fresh dimension list, subject to ``add_layer_norm_constraints``
    against the module's ``normalized_shape``.
    """
    assert isinstance(n.args[0], Node)
    output, counter = gen_tvar(counter)
    symbols[n] = output
    input_var = symbols[n.args[0]]

    dyn_case = Conj([BinConstraintT(input_var, Dyn, op_eq),
                     BinConstraintT(output, Dyn, op_eq)])

    rank_cases = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        dims, counter = gen_tensor_dims(rank, counter)
        nat_constraints = gen_nat_constraints(dims)
        same_shape = [BinConstraintT(input_var, TensorType(dims), op_eq),
                      BinConstraintT(output, TensorType(dims), op_eq)]
        norm_checks = add_layer_norm_constraints(dims, list(module_instance.normalized_shape))
        rank_cases.append(Conj(same_shape + norm_checks + nat_constraints))

    return [Disj([dyn_case, Disj(rank_cases)])], counter
def transform_get_item_tensor(constraint, counter):
    """
    When the index is a tuple, then the output will be a tensor.
    TODO: we have to check if this is the case for all HF models

    The cases we are covering here are a tuple with any of:
      - slice with default arguments, ``slice(None, None, None)``, which does
        not change the rank
      - None, which puts a dimension of size 1 at its position, so each
        occurrence of None increases the rank by 1

    Args:
        constraint: get-item constraint exposing ``index_tuple``,
            ``tensor_size``, ``input_var`` and ``res``
        counter: variable tracking

    Returns:
        The simplified constraint and the updated counter. The constraint is
        F() when the index has more slices than the input has dimensions, or
        when the result would have rank greater than 4.

    Raises:
        NotImplementedError: if the index tuple contains anything other than
            None or a default slice.
    """
    index_tuple = constraint.index_tuple
    assert isinstance(index_tuple, tuple)

    # generate a result tensor of the expected size
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    nat_constraints = gen_nat_constraints(dims)

    for idx in index_tuple:
        if idx is None or idx == slice(None, None, None):
            continue
        raise NotImplementedError('Method not yet implemented')

    none_count = index_tuple.count(None)
    result_rank = none_count + len(dims)

    # Having more index entries than result positions means there are more
    # slices than input dimensions. That index is invalid, so the constraint
    # is unsatisfiable. Previously this case could crash with an IndexError
    # when a None appeared past the end of the result list.
    if len(index_tuple) > result_rank:
        return F(), counter

    # place 1 wherever the index has None ...
    resulting_tensor_dims = result_rank * [None]
    for i, idx in enumerate(index_tuple):
        if idx is None:
            resulting_tensor_dims[i] = 1

    # ... and fill the remaining positions, in order, with the input dims
    dim_index = 0
    for i, d in enumerate(resulting_tensor_dims):
        if d is None:
            resulting_tensor_dims[i] = dims[dim_index]
            dim_index += 1

    # check if the index is valid
    is_valid_index = valid_index_tensor(index_tuple, dims)

    # check if the resulting tensor is within bounds
    if len(resulting_tensor_dims) > 4:
        return F(), counter

    constraints = [
        BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
        BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq),
        *nat_constraints,
        is_valid_index,
    ]
    return Conj(constraints), counter
def generate_calc_product(constraint, counter):
    """
    Transform flatten constraints.

    The dimensions ``dims_to_flatten[start:end]`` are collapsed into a single
    dimension of ``flattened``. Each possibility produced by
    ``generate_all_int_dyn_dim_possibilities`` is handled separately:
      * if not every constraint in the possibility is an ``op_neq`` constraint
        (there is a dynamic variable), the collapsed dimension is Dyn;
      * otherwise, the collapsed dimension is a fresh non-Dyn variable equal
        to the product of the collapsed dimensions.
    Any case whose resulting shape has more than 4 dimensions becomes F().
    The boundary check ``0 <= start < end <= len(dims)`` is evaluated here
    into T()/F().
    """
    start = constraint.start
    end = constraint.end
    dims = constraint.dims_to_flatten
    flattened = constraint.flattened
    n = len(dims)

    # this will be evaluated right here
    in_bounds = 0 <= start and start < end and end <= n
    c_boundary = T() if in_bounds else F()

    lhs = dims[0:start]
    mid = dims[start:end]
    rhs = dims[end:]

    all_constraints = []
    for possibility in generate_all_int_dyn_dim_possibilities(mid):
        possibility = list(possibility)
        # this tells us there is a dynamic variable
        contains_dyn = not all(c.op == op_neq for c in possibility)

        if contains_dyn:
            new_shape = lhs + [Dyn] + rhs
            if len(new_shape) > 4:
                all_constraints.append(F())
            else:
                shape_eq = BinConstraintT(flattened, TensorType(new_shape), op_eq)
                all_constraints.append(Conj([shape_eq] + possibility))
            continue

        new_var, counter = gen_dvar(counter)
        mid_eq_prod = Conj([BinConstraintD(new_var, Prod(mid), op_eq),
                            BinConstraintD(new_var, Dyn, op_neq)])
        new_shape = lhs + [new_var] + rhs
        if len(new_shape) > 4:
            all_constraints.append(F())
        else:
            shape_eq = BinConstraintT(flattened, TensorType(new_shape), op_eq)
            all_constraints.append(Conj([shape_eq, mid_eq_prod] + possibility))

    return Conj([Disj(all_constraints), c_boundary]), counter
def test_precision(self):
    """
    Test the precision relation (``is_more_precise``).
    """
    more_precise_pairs = [
        (TensorType((1, 2, 3)), TensorType((1, Dyn, 3))),
        (int, Dyn),
        (int, int),
    ]
    not_more_precise_pairs = [
        (TensorType((1, 2, 3)), TensorType((1, 2, 3, 5))),
        (TensorType((1, 2, 3)), int),
    ]
    for lhs, rhs in more_precise_pairs:
        self.assertTrue(is_more_precise(lhs, rhs))
    for lhs, rhs in not_more_precise_pairs:
        self.assertFalse(is_more_precise(lhs, rhs))
def test_flatten_fully_static(self):
    """
    Flatten inputs under several placeholder annotations and start/end dims,
    and check that the inferred output type is consistent with the size of
    the eagerly computed result.
    """
    annotation_list = [
        Dyn,
        TensorType((2, 5, 6, 9)),
        TensorType((10, 15, 13, 14)),
        TensorType((10, Dyn, 13, 14)),
        TensorType((Dyn, Dyn, Dyn, 10)),
    ]
    input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14), (10, 15, 13, 14), (2, 2, 10, 10)]
    start_dim = [1, 2, 1, 2, 0]
    end_dim = [1, 3, 3, 3, -2]

    class BasicBlock(torch.nn.Module):
        def __init__(self, start, end):
            super(BasicBlock, self).__init__()
            self.start = start
            self.end = end

        def forward(self, x):
            out = torch.flatten(x, self.start, self.end)
            return out

    for annotation, input_shape, start, end in zip(annotation_list, input_list, start_dim, end_dim):
        block = BasicBlock(start, end)
        tracer = RewritingTracer()
        graph = tracer.trace(block)
        traced = GraphModule(tracer.root, graph, "gm")

        # annotate our argument
        for node in graph.nodes:
            if node.op == 'placeholder':
                node.type = annotation

        eager_result = block.forward(torch.rand(input_shape))
        GraphTypeChecker({}, traced).type_check()

        for node in graph.nodes:
            if node.op == 'output':
                assert is_consistent(node.type, TensorType(eager_result.size()))
def index_select_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for index_select.

    ``n.args[0]`` is the input node, ``n.args[1]`` an int dimension and
    ``n.args[2]`` the index node. The index is constrained to be either a
    rank-1 tensor or Dyn. For each case, a disjunction of ``IndexSelect``
    constraints is emitted, one per candidate input rank 1..MAX_TENSOR_RANK.
    It passes the vector length (or Dyn) and the dimension. Presumably the
    output is the input shape with that dimension replaced by the vector
    length — TODO confirm against IndexSelect's definition.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], int)
    assert isinstance(n.args[2], Node)

    index_select, counter = gen_tvar(counter)
    symbols[n] = index_select

    dims, counter = gen_tensor_dims(1, counter)

    input_var = symbols[n.args[0]]
    index_var = symbols[n.args[2]]
    dim = n.args[1]
    ranks = range(1, MAX_TENSOR_RANK + 1)

    # equality constraint
    is_size_1 = BinConstraintT(index_var, TensorType(dims), op_eq)
    is_dyn = BinConstraintT(index_var, Dyn, op_eq)

    vector_case = Conj([is_size_1,
                        Disj([IndexSelect(r, input_var, dims[0], dim, index_select) for r in ranks])])
    dyn_case = Conj([is_dyn,
                     Disj([IndexSelect(r, input_var, Dyn, dim, index_select) for r in ranks])])

    return [Disj([vector_case, dyn_case])], counter
def transpose_inference_rule(n: Node):
    """
    We check that dimensions for the transpose operation are within range of
    the tensor type of the node's first argument.

    Only ``torch.transpose`` targets are handled; for other targets nothing
    happens and None is returned. A Dyn input gives a Dyn output. Otherwise
    the node type becomes the greatest upper bound of its current type and
    the input type with the two dimensions swapped.

    Raises:
        TypeError: if the input is neither Dyn nor a TensorType, or if either
            dimension is out of range.
    """
    if n.target != torch.transpose:
        return None

    assert isinstance(n.args[0], Node)
    t = n.args[0].type
    assert isinstance(n.args[1], int)
    assert isinstance(n.args[2], int)
    dim1, dim2 = n.args[1], n.args[2]

    if t == Dyn:
        n.type = Dyn
        return n.type

    if not isinstance(t, TensorType):
        raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')

    rank = len(t.__args__)
    if not (0 <= dim1 < rank and 0 <= dim2 < rank):
        raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')

    swapped = list(t.__args__)
    swapped[dim1], swapped[dim2] = swapped[dim2], swapped[dim1]
    n.type = get_greatest_upper_bound(n.type, TensorType(swapped))
    return n.type
def view_inference_rule(n: Node, symbols, constraints, counter):
    """
    Similar to reshape but with an extra condition on the strides.

    The target shape is built from ``n.args[1:]``. Node arguments are replaced
    by their symbols, and each ``-1`` is replaced by a fresh dimension
    variable. Every target dimension is constrained to be non-Dyn. The view
    equals the target type, and the source must be reshapeable to it.
    """
    assert isinstance(n.args[0], Node)

    # generate the new variable
    my_view, counter = gen_tvar(counter)
    symbols[n] = my_view

    src_var = symbols[n.args[0]]
    requested = [symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:]]

    target_dims = []
    num_constraints = []
    for dim in requested:
        if dim == -1:
            dim, counter = gen_dvar(counter)
        num_constraints.append(BinConstraintD(dim, Dyn, op_neq))
        target_dims.append(dim)

    target_type = TensorType(target_dims)
    c1 = BinConstraintT(my_view, target_type, op_eq)
    c2 = CanReshape(src_var, target_type)

    # TODO: add the extra check mentioned here:
    # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view
    return [c1, c2] + num_constraints, counter  # type: ignore[operator]
def transform_get_item(constraint, counter):
    """
    Generate an equality of the form t = [a1, ..., an], then generate
    constraints that check whether the given index is valid for this
    particular tensor size. If the index is valid, also generate a constraint
    that gets the item.

    Note that the Dyn input case has already been handled in a previous step.

    Args:
        constraint: GetItem which assumes we are getting an item from a tensor (not Dyn)
        counter: variable tracking

    Returns:
        simplified constraints for GetItem
    """
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    nat_constraints = gen_nat_constraints(dims)
    index_ok = valid_index(constraint.index, dims)

    all_constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq)]
    all_constraints.extend(nat_constraints)
    all_constraints.append(index_ok)

    # only a valid index lets us read the dimension; with a wrong index the
    # clause is already UNSAT
    if index_ok == T():
        all_constraints.append(BinConstraintD(constraint.res, dims[constraint.index], op_eq))

    return Conj(all_constraints), counter
def transpose_inference_rule(n: Node):
    """
    Infer the type of a ``torch.transpose`` node.

    For other targets, nothing happens and None is returned. A Dyn input
    gives a Dyn output. A TensorType input gives the same dimensions with
    ``n.args[1]`` and ``n.args[2]`` swapped; this directly replaces the node
    type.

    Raises:
        TypeError: if the input is neither Dyn nor a TensorType, or if either
            dimension is out of range.
    """
    if n.target != torch.transpose:
        return None

    assert isinstance(n.args[0], Node)
    t = n.args[0].type
    assert isinstance(n.args[1], int)
    assert isinstance(n.args[2], int)
    dim1, dim2 = n.args[1], n.args[2]

    if t == Dyn:
        n.type = Dyn
        return n.type

    if not isinstance(t, TensorType):
        raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')

    rank = len(t.__args__)
    if not (0 <= dim1 < rank and 0 <= dim2 < rank):
        raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}')

    swapped = list(t.__args__)
    swapped[dim1], swapped[dim2] = swapped[dim2], swapped[dim1]
    n.type = TensorType(swapped)
    return n.type
def generate_calc_conv(constraint, counter):
    """
    Transform a convolution calculation constraint.

    The result is a rank-4 tensor of fresh dimensions d. Its second dimension
    equals ``c_out`` and is non-Dyn. Its first dimension equals the first
    entry of ``matching_constraint``. ``calc_last_two_dims`` supplies the
    constraints on the last two dimensions, and all four dimensions are
    non-negative.
    """
    d, counter = gen_tensor_dims(4, counter)

    # the convolution result is a tensor of rank 4
    result_is_4d = BinConstraintT(constraint.conv_result, TensorType(list(d)), op_eq)

    # the second dimension of the output is equal to the output channels
    out_channels = Conj([BinConstraintD(d[1], constraint.c_out, op_eq),
                         BinConstraintD(d[1], Dyn, op_neq)])

    # the input corresponds to the output in the first dimension of the convolution
    batch_match = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq)

    last_two_h, last_two_w = calc_last_two_dims(constraint, d)

    non_negative = Conj([BinConstraintD(0, dim, op_leq) for dim in d])

    return Conj([result_is_4d, out_channels, batch_match,
                 last_two_h, last_two_w, non_negative]), counter
def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Generate constraints for a 2D convolution module.

    The input must match a rank-MAX_TENSOR_RANK tensor (d1, d2, d3, d4), and
    d2 must be consistent with the module's ``in_channels``. The output is
    described by a ``CalcConv`` constraint built from the module's
    out_channels, kernel_size, padding, stride and dilation. All dimensions
    are non-negative.
    """
    assert isinstance(n.args[0], Node)

    my_conv, counter = gen_tvar(counter)
    symbols[n] = my_conv
    input_var = symbols[n.args[0]]

    # dim vars
    (d1, d2, d3, d4), counter = gen_tensor_dims(MAX_TENSOR_RANK, counter)

    matches_input = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
    in_channels_ok = BinConstraintD(module_instance.in_channels, d2, op_consistency)
    conv_calc = CalcConv(my_conv,
                         input_var,
                         module_instance.out_channels,
                         module_instance.kernel_size,
                         module_instance.padding,
                         module_instance.stride,
                         module_instance.dilation,
                         [d1, d2, d3, d4])
    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])

    return [matches_input, in_channels_ok, conv_calc, *nat_constraints], counter
def conv2d_inference_rule(n: Node, module_instance):
    """
    Given a Conv2D instance and a node, check the following conditions:
    - the input type can be expanded to a size 4 tensor: t = (x_1, x_2, H, W)
    - the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4')
    - x_2 is consistent with the module's in_channels
    - let o = (x_1, out_channels, H_out, W_out); then the output is the
      greatest upper bound of o and the existing node type t'.

    Note that the argument node's type is overwritten with its expanded form.

    Raises:
        TypeError: if x_2 is not consistent with ``in_channels``.
    """
    assert isinstance(n.args[0], Node)
    n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4)
    arg_type = n.args[0].type
    curr_node_type = expand_to_tensor_dim(n.type, 4)

    if not is_consistent(arg_type.__args__[1], module_instance.in_channels):
        raise TypeError(f'Cannot apply {module_instance} with input type { arg_type} and existing type {n.type} on {n}')

    in_dims = arg_type.__args__
    h_out = calculate_out_dimension(in_dims[2], module_instance, 0)
    w_out = calculate_out_dimension(in_dims[3], module_instance, 1)
    candidate = TensorType((in_dims[0], module_instance.out_channels, h_out, w_out))

    n.type = get_greatest_upper_bound(candidate, curr_node_type)
    return n.type
def test_type_check_add_with_broadcast(self):
    """
    Type-check ``torch.add`` on a (1, 2, 3, Dyn) and a (2, 3, 4) input. The
    expected node types are (1, 2, 3, Dyn), (1, 2, 3, 4), (1, 2, 3, Dyn) and
    (1, 2, 3, Dyn), in graph order.
    """
    class M(torch.nn.Module):
        def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((2, 3, 4))):
            return torch.add(x, y)

    traced: torch.fx.GraphModule = symbolic_trace(M())
    GraphTypeChecker({}, traced).type_check()

    expected_types = iter([
        TensorType((1, 2, 3, Dyn)),
        TensorType((1, 2, 3, 4)),
        TensorType((1, 2, 3, Dyn)),
        TensorType((1, 2, 3, Dyn)),
    ])
    for node in traced.graph.nodes:
        assert node.type == next(expected_types)
def arange_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for ``arange`` called with a single (end) argument.

    The result is a rank-1 tensor of size d1, with start = 0 and step = 1.
    Either one of end/start/step is Dyn and d1 is Dyn, or none of them is Dyn,
    d1 is non-Dyn, and d1 = (end - start) / step.

    Raises:
        NotImplementedError: unless exactly one positional argument is given.
    """
    start = 0
    step = 1
    if len(n.args) != 1:
        raise NotImplementedError('Not yet implemented')
    end = symbols[n.args[0]]

    # int((end - start) / step)
    d1, counter = gen_dvar(counter)
    span = BinConstraintD(end, start, op_sub)
    size_constraint = BinConstraintD(d1, BinConstraintD(span, step, op_div), op_eq)

    arange, counter = gen_tvar(counter)
    symbols[n] = arange

    params = (end, start, step)
    # either a parameter is Dyn (and so is the size) or all are numbers
    some_dyn = Disj([BinConstraintD(p, Dyn, op_eq) for p in params])
    both_dyn = Conj([some_dyn, BinConstraintD(d1, Dyn, op_eq)])

    none_dyn = Conj([BinConstraintD(p, Dyn, op_neq) for p in params])
    both_numbers = Conj([none_dyn, BinConstraintD(d1, Dyn, op_neq), size_constraint])

    return [BinConstraintT(arange, TensorType([d1]), op_eq),
            Disj([both_dyn, both_numbers])], counter
def substitute_solution_one_type(mapping, t):
    """
    Apply the most general unifier to a type.

    - A type variable (Var) is replaced by its mapping entry, if it has one.
    - A TensorType has each dimension substituted one level deep (not
      recursively), and the result is rebuilt with a tuple.
    - Lists and tuples are substituted element-wise and recursively, and keep
      their container kind (list or tuple).
    - Anything else is returned unchanged.
    """
    if isinstance(t, Var):
        return mapping[t] if t in mapping.keys() else t

    if isinstance(t, TensorType):
        return TensorType(tuple(mapping[d] if d in mapping.keys() else d
                                for d in t.__args__))

    if isinstance(t, (list, tuple)):
        substituted = [substitute_solution_one_type(mapping, item) for item in t]
        return substituted if isinstance(t, list) else tuple(substituted)

    return t
def test_type_check_transpose_true(self):
    """
    Transposing dims 0 and 1 of a (1, 2, 3, 5) tensor type-checks. The
    call_function and output nodes become (2, 1, 3, 5), and the input
    placeholder keeps its annotated (1, 2, 3, 5) type.
    """
    class M(torch.nn.Module):
        def forward(self, x: TensorType((1, 2, 3, 5))):
            return torch.transpose(x, 0, 1)

    module = M()
    symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
    tc = GraphTypeChecker({}, symbolic_traced)
    self.assertTrue(tc.type_check())

    for n in symbolic_traced.graph.nodes:
        if n.op == 'call_function':
            assert n.type == TensorType([2, 1, 3, 5])
        if n.op == 'output':
            assert n.type == TensorType([2, 1, 3, 5])
        # Check the input placeholder's type. The op name is 'placeholder' and
        # the type lives in n.type. A condition on n.op == 'x' would never
        # match, so the check would silently never run.
        if n.op == 'placeholder':
            assert n.type == TensorType([1, 2, 3, 5])