def _scalar_broadcast_constraints(n, symbols, counter, node_arg, scalar, op_code):
    """Constraints for ``node_arg <op> scalar`` where ``scalar`` is an int/float literal.

    Used for both operand orders; the node operand is always placed first in
    the generated dimension expression.
    """
    operand = symbols[node_arg]
    if isinstance(operand, TVar):
        # a tensor combined with a scalar keeps the tensor's type
        my_output, counter = gen_tvar(counter)
        symbols[n] = my_output
        return [BinConstraintT(my_output, operand, op_eq)], counter
    if isinstance(operand, DVar):
        my_output, counter = gen_dvar(counter)
        symbols[n] = my_output
        # we will propagate the runtime value here since this is regular addition
        c = Conj([BinConstraintD(my_output, BinConstraintD(operand, scalar, op_code), op_eq),
                  BinConstraintD(0, my_output, op_leq)])
        return [c], counter
    raise NotImplementedError('Method not yet implemented')


def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for element-wise addition / multiplication.

    Supported operand combinations:
      * ``Node op Node`` where both operands are tensor variables: the output
        is constrained by the broadcasting rules.
      * ``Node op scalar`` or ``scalar op Node`` (scalar is an int or float):
        a tensor operand gives an output of the same type; a dimension operand
        gives a non-negative dimension equal to the computed runtime value.

    Args:
        n: the node being typed; ``symbols[n]`` is set to a fresh variable.
        symbols: mapping from nodes to their constraint variables.
        constraints: unused; kept for the uniform rule signature.
        counter: fresh-variable counter.

    Returns:
        A ``(constraints, counter)`` pair.

    Raises:
        NotImplementedError: for unsupported operand combinations.
    """
    # operator used when propagating concrete dimension values
    op_code = None
    if n.target == operator.add or n.target == torch.add:
        op_code = op_add
    elif n.target == operator.mul or n.target == torch.mul:
        op_code = op_mul

    lhs, rhs = n.args[0], n.args[1]

    if isinstance(lhs, Node) and isinstance(rhs, Node):
        if isinstance(symbols[lhs], TVar) and isinstance(symbols[rhs], TVar):
            my_output, counter = gen_tvar(counter)
            symbols[n] = my_output
            e1 = symbols[lhs]
            e2 = symbols[rhs]
            return gen_broadcasting_constraints(e1, e2, symbols, counter, my_output)
        raise NotImplementedError('Method not yet implemented')

    if isinstance(lhs, Node) and isinstance(rhs, (int, float)):
        return _scalar_broadcast_constraints(n, symbols, counter, lhs, rhs, op_code)

    # scalar on the left: the scalar type check must look at args[0]
    if isinstance(rhs, Node) and isinstance(lhs, (int, float)):
        return _scalar_broadcast_constraints(n, symbols, counter, rhs, lhs, op_code)

    # TODO generate add constraints for scalar addition
    raise NotImplementedError('Addition not yet implemented')
def view_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for ``Tensor.view``.

    Treated like a reshape: the output equals the requested shape, the source
    must be reshapeable into it, and every requested dimension must be static
    (not ``Dyn``). A ``-1`` entry becomes a fresh dimension variable. The extra
    stride condition of ``view`` is not modelled yet (see TODO below).

    Returns:
        A ``(constraints, counter)`` pair.
    """
    assert isinstance(n.args[0], Node)

    # fresh tensor variable for the result
    my_view, counter = gen_tvar(counter)
    symbols[n] = my_view

    src_var = symbols[n.args[0]]

    # requested shape, with symbolic entries resolved to their variables
    requested = [symbols[arg] if isinstance(arg, Node) else arg for arg in n.args[1:]]

    target_dims = []
    static_dims = []
    for dim in requested:
        if dim == -1:
            # the inferred dimension becomes a new unknown
            dim, counter = gen_dvar(counter)
        target_dims.append(dim)
        static_dims.append(BinConstraintD(dim, Dyn, op_neq))

    target_type = TensorType(target_dims)
    c1 = BinConstraintT(my_view, target_type, op_eq)
    c2 = CanReshape(src_var, target_type)

    # TODO: add the extra check mentioned here:
    # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view
    return [c1, c2] + static_dims, counter  # type: ignore[operator]
def arange_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for the single-argument form ``torch.arange(end)``.

    With ``start = 0`` and ``step = 1``, the result is a rank-1 tensor whose only
    dimension ``d1`` is either ``Dyn`` (when one of end/start/step is ``Dyn``)
    or equal to ``(end - start) / step`` (when all of them are known).

    ``end`` may be a node (its constraint variable is used) or a literal
    number, which has no entry in ``symbols``.

    Returns:
        A ``(constraints, counter)`` pair.

    Raises:
        NotImplementedError: when more than one positional argument is given.
    """
    start = 0
    step = 1

    if len(n.args) != 1:
        raise NotImplementedError('Not yet implemented')

    end_arg = n.args[0]
    # a literal end (e.g. torch.arange(5)) is used directly
    end = symbols[end_arg] if isinstance(end_arg, Node) else end_arg

    # int((end - start) / step)
    d1, counter = gen_dvar(counter)
    size_constraint = BinConstraintD(d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq)
    arange, counter = gen_tvar(counter)
    symbols[n] = arange

    # either one of the parameters is Dyn (then the size is Dyn) ...
    c1 = Disj([BinConstraintD(end, Dyn, op_eq),
               BinConstraintD(start, Dyn, op_eq),
               BinConstraintD(step, Dyn, op_eq)])
    c2 = BinConstraintD(d1, Dyn, op_eq)
    both_dyn = Conj([c1, c2])

    # ... or all of them are numbers and the size is computed from them
    c11 = Conj([BinConstraintD(end, Dyn, op_neq),
                BinConstraintD(start, Dyn, op_neq),
                BinConstraintD(step, Dyn, op_neq)])
    c22 = BinConstraintD(d1, Dyn, op_neq)
    both_numbers = Conj([c11, c22, size_constraint])

    return [BinConstraintT(arange, TensorType([d1]), op_eq), Disj([both_dyn, both_numbers])], counter
def _gt_flow_constraint(lhs, rhs, counter):
    """Bind a fresh boolean variable to ``lhs > rhs`` (used for flow analysis only)."""
    gt_constraint = BinConstraintD(lhs, rhs, op_gt)
    my_gt, counter = gen_bvar(counter)
    equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq)
    return [equality_constraint], counter


def gt_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for ``a > b``.

    Cases:
      * tensor > tensor: the result is a fresh tensor constrained by broadcasting.
      * dimension > dimension, dimension > int, int > dimension: a fresh
        boolean variable is bound to the comparison (flow analysis only).
      * node > int where the node was assumed to be a tensor: the assumption
        is corrected to a dimension variable and a warning is emitted.

    Raises:
        RuntimeError: for two node operands of different sorts.
        NotImplementedError: for any other combination.
    """
    assert isinstance(n.args[0], Node) or isinstance(n.args[0], int)
    assert isinstance(n.args[1], Node) or isinstance(n.args[1], int)

    # We make sure this node will not be used again. We do not
    # generate a constraint about that node. Only about the operands.
    lhs_is_node = isinstance(n.args[0], Node)
    rhs_is_node = isinstance(n.args[1], Node)

    e1 = symbols[n.args[0]] if lhs_is_node else n.args[0]
    e2 = symbols[n.args[1]] if rhs_is_node else n.args[1]

    if lhs_is_node and rhs_is_node:
        if isinstance(e1, TVar) and isinstance(e2, TVar):
            gt_tensor, counter = gen_tvar(counter)
            symbols[n] = gt_tensor
            return gen_broadcasting_constraints(e1, e2, symbols, counter, gt_tensor)
        if isinstance(e1, DVar) and isinstance(e2, DVar):
            return _gt_flow_constraint(e1, e2, counter)
        raise RuntimeError('Sort Mismatch')

    if lhs_is_node:
        if isinstance(e1, DVar):
            return _gt_flow_constraint(e1, e2, counter)
        if isinstance(e1, TVar) and isinstance(e2, int):
            # then we made the wrong assumption about the argument being a tensor
            # so we should fix the assumption
            warnings.warn(f'Made the wrong assumption for node {n}. Correctness not guaranteed.')
            new_e1, counter = gen_dvar(counter)
            symbols[n.args[0]] = new_e1
            return _gt_flow_constraint(new_e1, e2, counter)
        raise NotImplementedError('Method not yet implemented')

    if rhs_is_node and isinstance(e2, DVar):
        # int > dimension: mirror of the dimension > int case
        return _gt_flow_constraint(e1, e2, counter)

    raise NotImplementedError('Method not yet implemented')
def bmm_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints that match the input to a size 3 tensor
    and switch the dimensions according to the rules
    of batch multiplication.

    ``bmm(a, b)`` with ``a: [B, N, M]`` and ``b: [B, M, P]`` yields ``[B, N, P]``.
    Four alternatives are generated:
      * both inputs Dyn -> output Dyn;
      * only the first input Dyn -> output ``[B2, Dyn, P]``;
      * only the second input Dyn -> output ``[B1, N, Dyn]``;
      * both inputs rank-3 tensors whose batch and contracted dimensions are
        consistent -> output ``[B, N, P]`` with ``B`` the greatest upper bound
        of the two batch dimensions.

    Returns:
        A ``(constraints, counter)`` pair.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], Node)

    bmm_output, counter = gen_tvar(counter)
    symbols[n] = bmm_output

    bmm_input1 = symbols[n.args[0]]
    bmm_input2 = symbols[n.args[1]]

    dims_input1, counter = gen_tensor_dims(3, counter)
    dims_input2, counter = gen_tensor_dims(3, counter)

    inputs_dyn = Conj([
        BinConstraintT(bmm_input1, Dyn, op_eq),
        BinConstraintT(bmm_input2, Dyn, op_eq),
        BinConstraintT(bmm_output, Dyn, op_eq)
    ])

    input1_dyn = Conj([
        BinConstraintT(bmm_input1, Dyn, op_eq),
        BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
        BinConstraintT(bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq)
    ])

    input2_dyn = Conj([
        BinConstraintT(bmm_input2, Dyn, op_eq),
        BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
        BinConstraintT(bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq)
    ])

    consistency_constraints = [
        # batch dimensions must agree
        BinConstraintD(dims_input1[0], dims_input2[0], op_consistency),
        # contracted dimension: columns of input1 must match rows of input2
        BinConstraintD(dims_input1[2], dims_input2[1], op_consistency),
    ]

    batch_size, counter = gen_dvar(counter)

    inputs_are_tensors = Conj([
        BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq),
        BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq),
        BinConstraintT(bmm_output, TensorType([batch_size, dims_input1[1], dims_input2[2]]), op_eq),
        *consistency_constraints,
        DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0])
    ])

    return [Disj([inputs_dyn, input1_dyn, input2_dyn, inputs_are_tensors])], counter
def generate_calc_product(constraint, counter):
    """
    Transform a flatten constraint.

    The dimensions ``dims_to_flatten[start:end]`` are collapsed into one. For
    every Dyn/non-Dyn assignment of those middle dimensions, one alternative
    is produced. If any middle dimension is Dyn, the collapsed dimension is Dyn.
    Otherwise it is a fresh static variable equal to their product.
    Alternatives whose resulting rank exceeds 4 are ``F()``. The boundary
    check on ``start``/``end`` is decided immediately.

    Returns:
        A ``(constraint, counter)`` pair.
    """
    start = constraint.start
    end = constraint.end
    dims = constraint.dims_to_flatten
    flattened = constraint.flattened
    n = len(dims)

    # this will be evaluated right here
    c_boundary = T() if 0 <= start < end <= n else F()

    prefix = dims[0:start]
    suffix = dims[end:]
    middle = dims[start:end]

    alternatives = []
    for possibility in generate_all_int_dyn_dim_possibilities(middle):
        possibility = list(possibility)

        # a constraint whose op is not op_neq marks a dynamic middle dimension
        has_dyn = not all(c.op == op_neq for c in possibility)

        if has_dyn:
            collapsed = [Dyn]
            product_constraints = []
        else:
            new_var, counter = gen_dvar(counter)
            collapsed = [new_var]
            product_constraints = [Conj([
                BinConstraintD(new_var, Prod(middle), op_eq),
                BinConstraintD(new_var, Dyn, op_neq)
            ])]

        result_dims = prefix + collapsed + suffix
        if len(result_dims) > 4:
            alternatives.append(F())
            continue

        alternatives.append(Conj(
            [BinConstraintT(flattened, TensorType(result_dims), op_eq)]
            + product_constraints
            + possibility
        ))

    return Conj([Disj(alternatives), c_boundary]), counter
def _adaptive_output_hw(output_size, in_h, in_w):
    """Resolve an adaptive-pooling ``output_size`` into (height, width) entries."""
    if isinstance(output_size, int):
        # a single int H requests a square H x H output
        return output_size, output_size
    out_h, out_w = output_size[0], output_size[1]
    # None means the output keeps the corresponding input size
    return (in_h if out_h is None else out_h,
            in_w if out_w is None else out_w)


def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Generate constraints for a 2D adaptive pooling module.

    The input must match a rank-4 shape ``[d1, d2, d3, d4]`` (or be Dyn).
    The output keeps ``d1, d2``; its last two dimensions come from
    ``module_instance.output_size``, which may be an int (square output) or
    a pair whose ``None`` entries keep the input's size.

    Returns:
        A ``(constraints, counter)`` pair.
    """
    assert isinstance(n.args[0], Node)

    avg_pool, counter = gen_tvar(counter)
    symbols[n] = avg_pool
    input_var = symbols[n.args[0]]

    # dim vars
    d1, counter = gen_dvar(counter)
    d2, counter = gen_dvar(counter)
    d3, counter = gen_dvar(counter)
    d4, counter = gen_dvar(counter)
    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])

    out_h, out_w = _adaptive_output_hw(module_instance.output_size, d3, d4)

    c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
    c2 = BinConstraintT(avg_pool, TensorType([d1, d2, out_h, out_w]), op_eq)

    return [c1, c2, *nat_constraints], counter
def getitem_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for ``x[index]`` where ``x`` is a tensor variable.

    * ``int`` index: the result is a dimension variable. Either input and
      output are both Dyn, or one ``GetItem`` alternative per rank
      1..MAX_TENSOR_RANK holds and the result is a natural number.
    * ``tuple`` index: the result is a tensor variable. Either input and output
      are both Dyn, or one ``GetItemTensor`` alternative per rank holds.

    Returns:
        A ``(constraints, counter)`` pair.

    Raises:
        RuntimeError: for any other index type.
    """
    assert isinstance(n.args[0], Node)
    index = n.args[1]
    ranks = range(1, MAX_TENSOR_RANK + 1)

    if isinstance(index, int):
        # dimension output case
        dim_out, counter = gen_dvar(counter)
        symbols[n] = dim_out

        # retrieve arg variables
        source = symbols[n.args[0]]
        assert isinstance(source, TVar)

        # a dynamic input accepts any index and yields a dynamic dimension
        both_dyn = Conj([BinConstraintT(source, Dyn, op_eq),
                         BinConstraintD(dim_out, Dyn, op_eq)])

        # for a tensor input, GetItem is expanded according to its rank
        per_rank = [GetItem(rank, index, dim_out, source) for rank in ranks]

        # the output is a dimension, so it must be a natural number
        natural = BinConstraintD(0, dim_out, op_leq)

        return [Disj([both_dyn, Conj([Disj(per_rank), natural])])], counter

    if isinstance(index, tuple):
        # tensor output case
        tensor_out, counter = gen_tvar(counter)
        symbols[n] = tensor_out

        # retrieve arg variables
        source = symbols[n.args[0]]
        assert isinstance(source, TVar)

        both_dyn = Conj([BinConstraintT(source, Dyn, op_eq),
                         BinConstraintT(tensor_out, Dyn, op_eq)])

        per_rank = [GetItemTensor(rank, index, tensor_out, source) for rank in ranks]  # type: ignore[misc]

        return [Disj([both_dyn, *per_rank])], counter

    raise RuntimeError('Method not yet implemented')
def batchnorm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Generate constraints for a batch-norm module applied to ``n.args[0]``.

    The input must match a rank-4 shape ``[d1, d2, d3, d4]`` (or be Dyn), the
    output has exactly the input's type, and the four dimensions are natural
    numbers. ``module_instance`` is not consulted.

    Returns:
        A ``(constraints, counter)`` pair.
    """
    assert isinstance(n.args[0], Node)

    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var
    in_var = symbols[n.args[0]]

    dims = []
    for _ in range(4):
        dim, counter = gen_dvar(counter)
        dims.append(dim)

    natural_dims = gen_nat_constraints(list(dims))
    shape_matches = BinConstraintT(in_var, TensorType(list(dims)), op_matching)
    same_type = BinConstraintT(in_var, out_var, op_eq)

    return [shape_matches, same_type, *natural_dims], counter
def torch_dim_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for a rank query on a tensor (its number of dimensions).

    The result is a dimension variable. Either the input and the result are
    both Dyn, or for some rank r in 1..MAX_TENSOR_RANK the input is a rank-r
    tensor and the result equals r.

    Returns:
        A ``(constraints, counter)`` pair.
    """
    assert isinstance(n.args[0], Node)

    rank_var, counter = gen_dvar(counter)
    symbols[n] = rank_var
    tensor_var = symbols[n.args[0]]

    both_dyn = Conj([BinConstraintT(tensor_var, Dyn, op_eq),
                     BinConstraintD(rank_var, Dyn, op_eq)])

    per_rank = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        dims, counter = gen_tensor_dims(rank, counter)
        per_rank.append(Conj([
            BinConstraintT(tensor_var, TensorType(dims), op_eq),
            BinConstraintD(rank_var, rank, op_eq),
        ]))

    return [Disj([both_dyn, Disj(per_rank)])], counter
def add_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for addition.

    * tensor + tensor: the output is constrained by broadcasting.
    * node + int: a tensor operand gives an output of the same type; a
      dimension operand gives a non-negative dimension equal to the sum.

    Returns:
        A ``(constraints, counter)`` pair.

    Raises:
        NotImplementedError: for unsupported operand combinations.
    """
    lhs, rhs = n.args[0], n.args[1]

    if isinstance(lhs, Node) and isinstance(rhs, Node):
        if not isinstance(symbols[lhs], TVar) or not isinstance(symbols[rhs], TVar):
            raise NotImplementedError('Method not yet implemented')
        my_add, counter = gen_tvar(counter)
        symbols[n] = my_add
        return gen_broadcasting_constraints(symbols[lhs], symbols[rhs], symbols, counter, my_add)

    if isinstance(lhs, Node) and isinstance(rhs, int):
        operand = symbols[lhs]

        if isinstance(operand, TVar):
            my_add, counter = gen_tvar(counter)
            symbols[n] = my_add
            return [BinConstraintT(my_add, operand, op_eq)], counter

        if isinstance(operand, DVar):
            my_add, counter = gen_dvar(counter)
            symbols[n] = my_add
            # we will propagate the runtime value here since this is regular addition
            summed = BinConstraintD(operand, rhs, op_add)
            c = Conj([BinConstraintD(my_add, summed, op_eq),
                      BinConstraintD(0, my_add, op_leq)])
            return [c], counter

        raise NotImplementedError('Method not yet implemented')

    # TODO generate add constraints for scalar addition
    raise NotImplementedError('Addition not yet implemented')
def size_inference_rule(n: Node, symbols, constraints, counter):
    """
    Generate constraints for ``Tensor.size``.

    * ``x.size()``: the constraint is just lhs = rhs, e.g.
      ``size = input_ids.size()``.
    * ``x.size(i)`` with an int ``i``: the result is a dimension variable.
      Either the input and result are both Dyn, or one ``GetItem``
      alternative per rank holds and the result is a natural number.

    Returns:
        A ``(constraints, counter)`` pair.

    Raises:
        NotImplementedError: for any other argument form.
    """
    arg_count = len(n.args)

    if arg_count == 1:
        size, counter = gen_tvar(counter)
        symbols[n] = size
        return [BinConstraintT(symbols[n.args[0]], size, op_eq)], counter

    if arg_count != 2 or not isinstance(n.args[1], int):
        raise NotImplementedError

    # TODO: review this rule; should input = dyn; output = dyn be included here?
    index = n.args[1]
    size_index, counter = gen_dvar(counter)
    symbols[n] = size_index
    tensor_var = symbols[n.args[0]]

    per_rank = [GetItem(rank, index, size_index, tensor_var)
                for rank in range(1, MAX_TENSOR_RANK + 1)]
    natural = BinConstraintD(0, size_index, op_leq)
    both_dyn = Conj([BinConstraintT(tensor_var, Dyn, op_eq),
                     BinConstraintD(size_index, Dyn, op_eq)])

    return [Disj([both_dyn, Conj([Disj(per_rank), natural])])], counter
def generate_binconstraint_t(constraint, counter):
    """
    Transform binary constraints for tensors.

    Operators handled:
      * ``op_precision``: Dyn lhs -> ``T()``; a fully static tensor lhs ->
        equality; a partially dynamic tensor lhs -> per-dimension precision
        constraints on fresh variables (each at least 1).
      * ``op_matching``: lhs and the four rhs dimensions are all Dyn, or lhs
        equals the rank-4 rhs tensor.
      * ``op_consistency``: one side is Dyn, or one of the generated per-rank
        consistency alternatives holds.
      * ``op_leq``: lhs is a tensor of rank 1..rhs.

    Any other constraint, including a precision constraint whose lhs is
    neither Dyn nor a TensorType, is returned unchanged.

    Returns:
        A ``(constraint, counter)`` pair.
    """
    # precision constraints
    if constraint.op == op_precision:
        if constraint.lhs == Dyn:
            return T(), counter
        if isinstance(constraint.lhs, TensorType):
            old_dims = constraint.lhs.__args__
            is_fully_static = all(d != Dyn for d in old_dims)
            if is_fully_static:
                return BinConstraintT(constraint.lhs, constraint.rhs, op_eq), counter

            new_dims = []
            for _ in range(len(old_dims)):
                dim, counter = gen_dvar(counter)
                new_dims.append(dim)

            new_dim_constraints = [BinConstraintD(old_dim, new_dim, op_precision)
                                   for new_dim, old_dim in zip(new_dims, old_dims)] + \
                                  [BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq)] + \
                                  [BinConstraintD(1, new_dim, op_leq) for new_dim in new_dims]
            return Conj(new_dim_constraints), counter

        # lhs cannot be refined yet: keep the constraint as-is so callers always
        # receive a (constraint, counter) pair, like the catch-all branch below
        return constraint, counter

    # matching
    if constraint.op == op_matching:
        assert isinstance(constraint.rhs, TensorType)
        d1 = constraint.rhs.__args__[0]
        d2 = constraint.rhs.__args__[1]
        d3 = constraint.rhs.__args__[2]
        d4 = constraint.rhs.__args__[3]

        conj = [BinConstraintT(constraint.lhs, Dyn, op_eq),
                BinConstraintD(d1, Dyn, op_eq),
                BinConstraintD(d2, Dyn, op_eq),
                BinConstraintD(d3, Dyn, op_eq),
                BinConstraintD(d4, Dyn, op_eq)]
        return Disj([Conj(conj),
                     BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq)]), counter

    if constraint.op == op_consistency:
        c_dyn = Disj([BinConstraintT(constraint.lhs, Dyn, op_eq),
                      BinConstraintT(constraint.rhs, Dyn, op_eq)])
        [c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4], counter = gen_consistency_constraints(constraint, counter)

        return Disj([c_dyn, c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4]), counter

    if constraint.op == op_leq:
        assert isinstance(constraint.rhs, int)
        disj = []
        for rank in range(1, constraint.rhs + 1):
            dims = []
            for _ in range(rank):
                dim_var, counter = gen_dvar(counter)
                dims.append(dim_var)
            disj.append(BinConstraintT(constraint.lhs, TensorType(dims), op_eq))
        return Disj(disj), counter

    return constraint, counter
def neq_inference_rule(n: Node, symbols, constraints, counter):
    """
    Translates to inconsistent in gradual types.
    To prove inequality, we should prove that
    tensors are either different sizes or
    disagree on at least one dimension

    This is a WIP (works when the condition
    is false. We are working on making this operation work
    when the condition is true as well)

    Only target shapes of length 3 and 4 are supported. The result is a
    fresh boolean variable equal to the constraint: "the input has the target
    rank and some pair of corresponding dimensions are both static and unequal".

    Returns:
        A ``(constraints, counter)`` pair.

    Raises:
        NotImplementedError: for target shapes of any other length.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], tuple)

    target = n.args[1]

    # implementing for size 3 and 4
    if len(target) not in (3, 4):
        raise NotImplementedError('Method not yet implemented')

    for elem in target:
        assert isinstance(elem, Node) or isinstance(elem, int)

    lhs = symbols[n.args[0]]
    b, counter = gen_tensor_dims(len(target), counter)
    input_has_rank = BinConstraintT(lhs, TensorType(list(b)), op_eq)

    target_dims = [d if isinstance(d, int) else symbols[d] for d in target]

    # dimensions inconsistent: both static and different, pairing each
    # target dimension with the input dimension at the same position
    dims_inconsistent = Disj([
        Conj([BinConstraintD(d, Dyn, op_neq),
              BinConstraintD(b_i, Dyn, op_neq),
              BinConstraintD(d, b_i, op_neq)])
        for d, b_i in zip(target_dims, b)
    ])

    # we are covering size 3 and 4 only for now
    ne_constraint = Conj([input_has_rank, dims_inconsistent])

    my_ne, counter = gen_bvar(counter)
    equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq)

    return [equality_constraint], counter