def gen_embedding_rules(n: Node, symbols, embedding_dim, counter):
    """
    Build the constraints linking an embedding node's input to its output.

    Either both input and output are Dyn, or the input is a tensor of rank r
    (1 <= r < MAX_TENSOR_RANK) and the output is that same tensor type with
    ``embedding_dim`` appended as one extra trailing dimension.

    Returns:
        A one-element list holding the disjunction, and the updated counter.
    """
    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var
    in_var = symbols[n.args[0]]

    dyn_case = Conj([
        BinConstraintT(in_var, Dyn, op_eq),
        BinConstraintT(out_var, Dyn, op_eq),
    ])

    tensor_cases = []
    # The output gains one dimension, so input ranks stop one short of
    # MAX_TENSOR_RANK to keep the output rank within MAX_TENSOR_RANK.
    for rank in range(1, MAX_TENSOR_RANK):
        dims, counter = gen_tensor_dims(rank, counter)
        nat = gen_nat_constraints(dims)
        shape_eqs = [
            BinConstraintT(in_var, TensorType(dims), op_eq),
            BinConstraintT(out_var, TensorType(dims + [embedding_dim]), op_eq),
        ]
        tensor_cases.append(Conj(shape_eqs + nat))

    return [Disj([dyn_case, Disj(tensor_cases)])], counter
def index_select_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for ``index_select(input, dim, index)``.

    ``n.args[0]`` is the input node, ``n.args[1]`` the integer dimension and
    ``n.args[2]`` the index node. The index is constrained to be either a
    rank-1 tensor (a vector of length ``dims[0]``) or Dyn. The output is the
    input with the size at position ``dim`` replaced by the vector's length
    (or by Dyn when the index is Dyn); one IndexSelect constraint is emitted
    per candidate input rank.

    Returns:
        A one-element list holding the disjunction, and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], int)
    assert isinstance(n.args[2], Node)

    index_select, counter = gen_tvar(counter)
    symbols[n] = index_select

    dims, counter = gen_tensor_dims(1, counter)

    # the index is either a vector of length dims[0] ...
    is_vector = BinConstraintT(symbols[n.args[2]], TensorType(dims), op_eq)
    # ... or Dyn
    is_dyn = BinConstraintT(symbols[n.args[2]], Dyn, op_eq)

    c2 = Conj([is_vector,
               Disj([IndexSelect(i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select)
                     for i in range(MAX_TENSOR_RANK)])])
    c3 = Conj([is_dyn,
               Disj([IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select)
                     for i in range(MAX_TENSOR_RANK)])])

    return [Disj([c2, c3])], counter
def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Constraints for a linear layer call.

    Either input and output are both Dyn, or both are tensors of the same
    rank whose dimensions are related by ``add_linear_constraints`` (input
    and output should agree on every dimension except the last).

    Returns:
        A one-element list holding the disjunction, and the updated counter.
    """
    assert isinstance(n.args[0], Node)

    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var
    in_var = symbols[n.args[0]]

    both_dyn = Conj([
        BinConstraintT(in_var, Dyn, op_eq),
        BinConstraintT(out_var, Dyn, op_eq),
    ])

    per_rank = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        in_dims, counter = gen_tensor_dims(rank, counter)
        out_dims, counter = gen_tensor_dims(rank, counter)
        nat = gen_nat_constraints(in_dims + out_dims)
        shape_eqs = [
            BinConstraintT(in_var, TensorType(in_dims), op_eq),
            BinConstraintT(out_var, TensorType(out_dims), op_eq),
        ]
        linear_rel = add_linear_constraints(in_dims, out_dims, module_instance)
        per_rank.append(Conj(shape_eqs + linear_rel + nat))

    return [Disj([both_dyn, Disj(per_rank)])], counter
def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Constraints for a layer-norm module call.

    Input and output shapes should be equal, and the input should be
    consistent with ``module_instance.normalized_shape``. The construction is
    exactly the one performed by ``gen_layer_norm_constraints``, so this rule
    delegates to it rather than duplicating the loop.

    Returns:
        A one-element list holding the disjunction, and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    return gen_layer_norm_constraints(n, module_instance.normalized_shape, symbols, counter)
def arange_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for ``arange(end)``; only the single-argument form is supported.

    The result is a rank-1 tensor ``[length]``. Either one of end/start/step
    is Dyn and ``length`` is Dyn, or none of them is Dyn and
    ``length == (end - start) / step``, with start fixed to 0 and step to 1.

    Raises:
        NotImplementedError: when called with anything but exactly one argument.
    """
    if len(n.args) != 1:
        raise NotImplementedError('Not yet implemented')

    start, step = 0, 1
    end = symbols[n.args[0]]

    # length stands for int((end - start) / step)
    length, counter = gen_dvar(counter)
    length_eq = BinConstraintD(
        length,
        BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div),
        op_eq)

    arange, counter = gen_tvar(counter)
    symbols[n] = arange

    params = (end, start, step)

    # some parameter is Dyn -> the length is Dyn
    some_dyn = Conj([
        Disj([BinConstraintD(p, Dyn, op_eq) for p in params]),
        BinConstraintD(length, Dyn, op_eq),
    ])
    # every parameter is a number -> the length is a number given by length_eq
    all_numbers = Conj([
        Conj([BinConstraintD(p, Dyn, op_neq) for p in params]),
        BinConstraintD(length, Dyn, op_neq),
        length_eq,
    ])

    return [BinConstraintT(arange, TensorType([length]), op_eq),
            Disj([some_dyn, all_numbers])], counter
def generate_calc_conv(constraint, counter):
    """
    Transform a convolution-calculation constraint into basic constraints.

    The convolution result is a rank-4 tensor ``[d0, d1, d2, d3]`` where
    ``d1`` equals ``constraint.c_out`` and is not Dyn, ``d0`` equals
    ``constraint.matching_constraint[0]``, the last two dimensions come from
    ``calc_last_two_dims`` and every dimension is non-negative.

    Returns:
        The conjunction of these constraints, and the updated counter.
    """
    d, counter = gen_tensor_dims(4, counter)

    result_shape = BinConstraintT(constraint.conv_result,
                                  TensorType([d[0], d[1], d[2], d[3]]),
                                  op_eq)

    # second output dimension is the number of output channels
    out_channels = Conj([
        BinConstraintD(d[1], constraint.c_out, op_eq),
        BinConstraintD(d[1], Dyn, op_neq),
    ])

    # first output dimension matches the input's first dimension
    first_dim = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq)

    third_dim, fourth_dim = calc_last_two_dims(constraint, d)

    non_negative = Conj([BinConstraintD(0, d[k], op_leq) for k in range(4)])

    return Conj([result_shape, out_channels, first_dim, third_dim, fourth_dim, non_negative]), counter
def cumsum_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for ``cumsum``: input and output shapes are equal.

    The dimension argument is taken positionally from ``n.args[1]`` or from
    ``n.kwargs["dim"]`` and checked with ``range_check`` for every candidate
    rank.

    Returns:
        A one-element list holding the disjunction, and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    dim = n.args[1] if len(n.args) > 1 else n.kwargs["dim"]
    assert isinstance(dim, int)

    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var
    in_var = symbols[n.args[0]]

    both_dyn = Conj([
        BinConstraintT(in_var, Dyn, op_eq),
        BinConstraintT(out_var, Dyn, op_eq),
    ])

    same_shape = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        dims, counter = gen_tensor_dims(rank, counter)
        nat = gen_nat_constraints(dims)
        same_shape.append(Conj([
            BinConstraintT(in_var, TensorType(dims), op_eq),
            BinConstraintT(out_var, TensorType(dims), op_eq),
            range_check(dim, rank),
            *nat,
        ]))

    return [Disj([both_dyn, Disj(same_shape)])], counter
def no_broadcast_dim_with_index(d1: List[DVar],
                                d2: List[DVar],
                                d3: List[DVar],
                                d4: List[DVar],
                                i: int):
    """
    Constraints for dimension index ``i`` when no broadcasting happens there.

    Either both input dimensions at ``i`` are 1 or neither is, and each
    simulated-broadcast dimension equals the corresponding input dimension.

    Args:
        d1: dimensions of input 1
        d2: dimensions of input 2
        d3: simulated broadcasting for input 1
        d4: simulated broadcasting for input 2
        i: index of the dimension being constrained

    Returns:
        Constraints for when no broadcasting occurs at index ``i``
    """
    both_one = Conj([
        BinConstraintD(d1[i], 1, op_eq),
        BinConstraintD(d2[i], 1, op_eq),
    ])
    neither_one = Conj([
        BinConstraintD(d1[i], 1, op_neq),
        BinConstraintD(d2[i], 1, op_neq),
    ])
    return Conj([
        Disj([both_one, neither_one]),
        BinConstraintD(d1[i], d3[i], op_eq),
        BinConstraintD(d2[i], d4[i], op_eq),
    ])
def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Constraints for an embedding module call.

    The output shape is the input shape with ``module_instance.embedding_dim``
    appended as an extra last dimension, or both are Dyn. The construction is
    exactly the one performed by ``gen_embedding_rules``, so this rule
    delegates to it instead of duplicating the loop.

    Returns:
        A one-element list holding the disjunction, and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    return gen_embedding_rules(n, symbols, module_instance.embedding_dim, counter)
def bmm_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for batch matrix multiplication.

    Both operands are matched against rank-3 tensors. Four cases:
      - both operands (and the output) are Dyn;
      - only the first operand is Dyn: output is [b2, Dyn, m2];
      - only the second operand is Dyn: output is [b1, n1, Dyn];
      - both are tensors: the batch sizes must be consistent, and the output
        is [batch, n1, m2] where batch is their greatest upper bound.

    Returns:
        A one-element list holding the disjunction, and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], Node)

    out, counter = gen_tvar(counter)
    symbols[n] = out
    lhs = symbols[n.args[0]]
    rhs = symbols[n.args[1]]

    lhs_dims, counter = gen_tensor_dims(3, counter)
    rhs_dims, counter = gen_tensor_dims(3, counter)
    batch_size, counter = gen_dvar(counter)

    all_dyn = Conj([
        BinConstraintT(lhs, Dyn, op_eq),
        BinConstraintT(rhs, Dyn, op_eq),
        BinConstraintT(out, Dyn, op_eq),
    ])

    lhs_dyn = Conj([
        BinConstraintT(lhs, Dyn, op_eq),
        BinConstraintT(rhs, TensorType(rhs_dims), op_eq),
        BinConstraintT(out, TensorType([rhs_dims[0], Dyn, rhs_dims[2]]), op_eq),
    ])

    rhs_dyn = Conj([
        BinConstraintT(rhs, Dyn, op_eq),
        BinConstraintT(lhs, TensorType(lhs_dims), op_eq),
        BinConstraintT(out, TensorType([lhs_dims[0], lhs_dims[1], Dyn]), op_eq),
    ])

    both_tensors = Conj([
        BinConstraintT(lhs, TensorType(lhs_dims), op_eq),
        BinConstraintT(rhs, TensorType(rhs_dims), op_eq),
        BinConstraintT(out, TensorType([batch_size, lhs_dims[1], rhs_dims[2]]), op_eq),
        BinConstraintD(lhs_dims[0], rhs_dims[0], op_consistency),
        DGreatestUpperBound(batch_size, lhs_dims[0], rhs_dims[0]),
    ])

    return [Disj([all_dyn, lhs_dyn, rhs_dyn, both_tensors])], counter
def broadcasting_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for element-wise addition (``operator.add`` / ``torch.add``)
    and multiplication (``operator.mul``).

    Supported argument combinations:
      - tensor op tensor: delegated to ``gen_broadcasting_constraints``;
      - tensor op number, or number op tensor: the output equals the tensor;
      - dimension op number, or number op dimension: the output is a fresh
        dimension variable equal to the arithmetic result, and non-negative.

    Raises:
        NotImplementedError: for any other combination of arguments.
    """
    op_code = None
    if n.target == operator.add or n.target == torch.add:
        op_code = op_add
    elif n.target == operator.mul:
        op_code = op_mul

    lhs, rhs = n.args[0], n.args[1]

    if isinstance(lhs, Node) and isinstance(rhs, Node):
        if isinstance(symbols[lhs], TVar) and isinstance(symbols[rhs], TVar):
            my_output, counter = gen_tvar(counter)
            symbols[n] = my_output
            e1 = symbols[lhs]
            e2 = symbols[rhs]
            return gen_broadcasting_constraints(e1, e2, symbols, counter, my_output)
        raise NotImplementedError('Method not yet implemented')

    if isinstance(lhs, Node) and isinstance(rhs, (int, float)):
        return _broadcast_node_with_number(n, lhs, rhs, op_code, symbols, counter)

    # Number first, e.g. ``2.0 + x``: the number is n.args[0], so the type
    # test must be applied to n.args[0] (n.args[1] is the Node here).
    if isinstance(rhs, Node) and isinstance(lhs, (int, float)):
        return _broadcast_node_with_number(n, rhs, lhs, op_code, symbols, counter)

    # TODO generate add constraints for scalar addition
    raise NotImplementedError('Addition not yet implemented')


def _broadcast_node_with_number(n: Node, node_arg: Node, number, op_code, symbols, counter):
    """Constraints for ``node_arg (op) number`` where ``node_arg`` maps to a TVar or a DVar."""
    var = symbols[node_arg]

    if isinstance(var, TVar):
        my_output, counter = gen_tvar(counter)
        symbols[n] = my_output
        # combining a tensor with a number keeps the tensor's type
        return [BinConstraintT(my_output, var, op_eq)], counter

    if isinstance(var, DVar):
        my_output, counter = gen_dvar(counter)
        symbols[n] = my_output
        # we propagate the runtime value here since this is regular arithmetic
        c = Conj([BinConstraintD(my_output, BinConstraintD(var, number, op_code), op_eq),
                  BinConstraintD(0, my_output, op_leq)])
        return [c], counter

    raise NotImplementedError('Method not yet implemented')
def generate_calc_product(constraint, counter):
    """
    Transform flatten (CalcProduct) constraints.

    The dimensions ``dims[start:end]`` are collapsed into a single dimension
    of ``flattened``. Every Dyn / non-Dyn assignment of those dimensions
    (from ``generate_all_int_dyn_dim_possibilities``) yields one disjunct:
      - if any collapsed dimension is Dyn, the collapsed dimension is Dyn;
      - otherwise it is a fresh non-Dyn variable equal to their product.
    A disjunct whose resulting rank exceeds MAX_TENSOR_RANK is F(). The bounds
    ``0 <= start < end <= len(dims)`` are checked eagerly.

    Returns:
        The conjunction of the disjunction and the boundary check, and the
        updated counter.
    """
    start = constraint.start
    end = constraint.end
    dims = constraint.dims_to_flatten
    flattened = constraint.flattened
    n = len(dims)

    # this is evaluated right here, at transformation time
    c_boundary = T() if 0 <= start < end <= n else F()

    lhs = dims[0:start]
    rhs = dims[end:]
    mid = dims[start:end]

    all_constraints = []
    for p in generate_all_int_dyn_dim_possibilities(mid):
        p = list(p)

        # p is a Dyn (op_eq) / non-Dyn (op_neq) constraint per collapsed dim
        contains_dyn = not all(c.op == op_neq for c in p)

        if contains_dyn:
            mid_var = Dyn
            mid_constraints = []
        else:
            new_var, counter = gen_dvar(counter)
            mid_var = new_var
            mid_constraints = [Conj([
                BinConstraintD(new_var, Prod(mid), op_eq),
                BinConstraintD(new_var, Dyn, op_neq),
            ])]

        result_dims = lhs + [mid_var] + rhs
        if len(result_dims) > MAX_TENSOR_RANK:
            all_constraints.append(F())
        else:
            all_constraints.append(Conj(
                [BinConstraintT(flattened, TensorType(result_dims), op_eq)]
                + mid_constraints
                + p))

    return Conj([Disj(all_constraints), c_boundary]), counter
def getitem_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for ``getitem`` on a tensor.

    An integer index produces a dimension (DVar); a tuple index produces a
    tensor (TVar). In both cases a Dyn input yields a Dyn output; otherwise a
    GetItem / GetItemTensor constraint is emitted for every candidate rank
    and expanded later.

    Raises:
        RuntimeError: for any other kind of index.
    """
    assert isinstance(n.args[0], Node)
    index = n.args[1]

    if isinstance(index, int):
        # dimension output case
        out_dim, counter = gen_dvar(counter)
        symbols[n] = out_dim

        source = symbols[n.args[0]]
        assert isinstance(source, TVar)

        # a Dyn input accepts any index and yields a Dyn dimension
        dyn_case = Conj([
            BinConstraintT(source, Dyn, op_eq),
            BinConstraintD(out_dim, Dyn, op_eq),
        ])
        per_rank = Disj([GetItem(rank, index, out_dim, source)
                         for rank in range(1, MAX_TENSOR_RANK + 1)])
        # the output is a dimension, so it must be a natural number
        non_negative = BinConstraintD(0, out_dim, op_leq)
        return [Disj([dyn_case, Conj([per_rank, non_negative])])], counter

    if isinstance(index, tuple):
        # tensor output case
        out_tensor, counter = gen_tvar(counter)
        symbols[n] = out_tensor

        source = symbols[n.args[0]]
        assert isinstance(source, TVar)

        dyn_case = Conj([
            BinConstraintT(source, Dyn, op_eq),
            BinConstraintT(out_tensor, Dyn, op_eq),
        ])
        per_rank = [
            GetItemTensor(rank, index, out_tensor, source)
            for rank in range(1, MAX_TENSOR_RANK + 1)
        ]  # type: ignore[misc]
        return [Disj([dyn_case, *per_rank])], counter

    raise RuntimeError('Method not yet implemented')
def generate_broadcasting(constraint, counter):
    """
    Transform broadcasting constraints.

    If either input is Dyn, each input equals its broadcast result.
    Otherwise, for every result rank from 1 to 4, ``gen_broadcasting_constraints``
    supplies a no-padding case and two padding cases (padding the first or
    the second argument). The rank-1 padding cases are not used. The natural
    number constraints for all generated dimensions are conjoined with the
    whole disjunction.

    Returns:
        The transformed constraint and the updated counter.
    """
    e11, e12 = constraint.res1, constraint.res2
    e1, e2 = constraint.input1, constraint.input2

    e1_equal_e11 = BinConstraintT(e1, e11, op_eq)
    e2_equal_e12 = BinConstraintT(e2, e12, op_eq)

    options = [
        Conj([BinConstraintT(e1, Dyn, op_eq), e1_equal_e11, e2_equal_e12]),
        Conj([BinConstraintT(e2, Dyn, op_eq), e1_equal_e11, e2_equal_e12]),
    ]
    nat_dims = []

    for rank in range(1, 5):
        no_padding, padding_arg1, padding_arg2, rank_nat_dims, counter = \
            gen_broadcasting_constraints(e1, e2, e11, e12, rank, counter)
        options.append(no_padding)
        if rank > 1:
            options.append(padding_arg1)
            options.append(padding_arg2)
        nat_dims.extend(rank_nat_dims)

    return Conj([Disj(options), *nat_dims]), counter
def transpose_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for ``transpose(input, dim0, dim1)``.

    A transpose can be viewed as a pair of index selects. Either input and
    output are both Dyn, or a Transpose constraint (one per candidate input
    rank) relates the input tensor to the output.

    Returns:
        A one-element list holding the disjunction, and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], int)
    assert isinstance(n.args[2], int)
    dim0, dim1 = n.args[1], n.args[2]

    output, counter = gen_tvar(counter)
    symbols[n] = output

    source = symbols[n.args[0]]
    assert isinstance(source, TVar)

    both_dyn = Conj([
        BinConstraintT(source, Dyn, op_eq),
        BinConstraintT(output, Dyn, op_eq),
    ])
    swapped = Disj([Transpose(rank, source, dim0, dim1, output)
                    for rank in range(1, MAX_TENSOR_RANK + 1)])

    return [Disj([both_dyn, swapped])], counter
def flatten_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for ``flatten(input, start_dim=1, end_dim=-1)``.

    Either input and output are both Dyn, or, for some candidate input rank,
    the constraints built by ``generate_flatten_constraints`` hold.

    Returns:
        A one-element list holding the disjunction, and the updated counter.
    """
    assert isinstance(n.args[0], Node)

    flattened, counter = gen_tvar(counter)
    symbols[n] = flattened
    source = symbols[n.args[0]]

    # defaults when the dims are not passed positionally
    start_dim, end_dim = 1, -1
    if len(n.args) > 1:
        assert isinstance(n.args[1], int)
        start_dim = n.args[1]
    if len(n.args) > 2:
        assert isinstance(n.args[2], int)
        end_dim = n.args[2]

    both_dyn = Conj([
        BinConstraintT(source, Dyn, op_eq),
        BinConstraintT(flattened, Dyn, op_eq),
    ])

    per_rank = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        rank_constraint, counter = generate_flatten_constraints(
            start_dim, end_dim, source, flattened, rank, counter)
        per_rank.append(rank_constraint)

    return [Disj([both_dyn, *per_rank])], counter
def generate_all_broadcasting_possibilities_no_padding(d1: List[DVar], d2: List[DVar], d11: List[DVar], d12: List[DVar]):
    """
    Generate broadcasting constraints assuming no padding.

    Broadcasting may happen at any dimension, so for each index we allow
    input 1 to broadcast against input 2, input 2 against input 1, or no
    broadcasting at all.

    Args:
        d1: input1 dimensions
        d2: input2 dimensions
        d11: broadcasted input1 dimensions
        d12: broadcasted input2 dimensions

    Returns:
        A conjunction over all indices of the per-index possibilities.
    """
    return Conj([
        Disj([
            broadcast_dim(d1, d2, d11, d12, idx),
            broadcast_dim(d2, d1, d12, d11, idx),
            no_broadcast_dim_with_index(d1, d2, d11, d12, idx),
        ])
        for idx in range(len(d1))
    ])
def gen_consistency_constraints(constraint: Constraint, counter: int):
    """
    Expand a tensor consistency constraint into per-dimension constraints.

    For every rank from 1 to MAX_TENSOR_RANK, both sides are equated to
    tensors of that rank whose dimensions are pairwise consistent and natural.

    Args:
        constraint: Consistency constraint on tensors
        counter: for variable tracking

    Returns:
        The list of per-rank conjunctions, and the updated counter.
    """
    cases = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        lhs_dims, counter = gen_tensor_dims(rank, counter)
        rhs_dims, counter = gen_tensor_dims(rank, counter)
        nat = gen_nat_constraints(lhs_dims + rhs_dims)

        shape_eqs = [
            BinConstraintT(constraint.lhs, TensorType(lhs_dims), op_eq),
            BinConstraintT(constraint.rhs, TensorType(rhs_dims), op_eq),
        ]
        pairwise = [BinConstraintD(a, b, op_consistency)
                    for a, b in zip(lhs_dims, rhs_dims)]
        cases.append(Conj(shape_eqs + pairwise + nat))

    return cases, counter
def transform_get_item(constraint, counter):
    """
    Simplify a GetItem constraint for a tensor of known rank.

    The input is equated to ``TensorType([a1, ..., an])`` and the index is
    checked with ``valid_index``. Only when that check returns ``T()`` is the
    result equated to the indexed dimension; otherwise the clause is left to
    be decided by the validity constraint. The Dyn input case was handled
    before this step.

    Args:
        constraint: GetItem, which assumes we are getting an item from a tensor (not Dyn)
        counter: variable tracking

    Returns:
        Simplified constraints for GetItem, and the updated counter.
    """
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    nat = gen_nat_constraints(dims)
    index_ok = valid_index(constraint.index, dims)

    lookup = ([BinConstraintD(constraint.res, dims[constraint.index], op_eq)]
              if index_ok == T() else [])

    return Conj([
        BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
        *nat,
        index_ok,
        *lookup,
    ]), counter
def transform_get_item_tensor(constraint, counter):
    """
    Simplify a GetItemTensor constraint (indexing with a tuple) for one input rank.

    Supported tuple entries:
      - ``slice(None, None, None)`` (``:``) keeps the next input dimension;
      - ``None`` inserts a new dimension of size 1, so each ``None`` increases
        the rank by one.
    Input dimensions not covered by the tuple are kept as trailing dimensions.

    The clause is ``F()`` when the tuple has more slices than the input has
    dimensions (an invalid index for this rank), or when the result rank
    exceeds MAX_TENSOR_RANK.

    TODO: we have to check if this is the case for all HF models

    Raises:
        NotImplementedError: for any other kind of tuple entry.
    """
    index_tuple = constraint.index_tuple
    assert isinstance(index_tuple, tuple)

    # dimensions of the input tensor at this candidate rank
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    nat_constraints = gen_nat_constraints(dims)

    for entry in index_tuple:
        if not (entry is None or entry == slice(None, None, None)):
            raise NotImplementedError('Method not yet implemented')

    none_count = index_tuple.count(None)
    result_rank = none_count + len(dims)

    # Every non-None entry is a slice consuming one input dimension; a longer
    # tuple uses more slices than this rank has, so the index is invalid here
    # (and would otherwise overrun the placeholder list below).
    if len(index_tuple) > result_rank:
        return F(), counter

    # check if the resulting tensor is within bounds
    if result_rank > MAX_TENSOR_RANK:
        return F(), counter

    # put a 1 wherever a None occurs; the remaining slots take the input dims in order
    result_dims = [None] * result_rank
    for pos, entry in enumerate(index_tuple):
        if entry is None:
            result_dims[pos] = 1

    remaining_dims = iter(dims)
    for pos in range(result_rank):
        if result_dims[pos] is None:
            result_dims[pos] = next(remaining_dims)

    is_valid_index = valid_index_tensor(index_tuple, dims)

    constraints = [
        BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
        BinConstraintT(constraint.res, TensorType(result_dims), op_eq),
        *nat_constraints,
        is_valid_index,
    ]
    return Conj(constraints), counter
def generate_flatten_constraints(start_dim, end_dim, input, flattened, n, counter):
    """
    Constraints for flattening a rank-``n`` input into ``flattened``.

    ``start_dim`` and ``end_dim`` follow ``torch.flatten`` conventions.
    Negative values count from the end (``-1`` is the last dimension), and
    ``end_dim`` is inclusive. They are converted to a half-open range
    ``[start, end)`` over the generated dimensions and passed to CalcProduct.
    Out-of-range values are not rejected here; presumably the CalcProduct
    transformation's boundary check handles them.

    Returns:
        The conjunction of constraints and the updated counter.
    """
    d, counter = gen_tensor_dims(n, counter)
    c1 = BinConstraintT(input, TensorType(d), op_eq)

    # Normalise negative dims as ``dim + rank``. abs(start_dim) is only right
    # when n == 2 * |start_dim|, and start_dim == -1 must select the last
    # dimension (n - 1), not the empty range starting at n.
    start = start_dim + n if start_dim < 0 else start_dim
    # end_dim is inclusive; CalcProduct takes an exclusive end
    end = end_dim + n + 1 if end_dim < 0 else end_dim + 1

    c2 = CalcProduct(start, end, flattened, d)
    nat_constraints = gen_nat_constraints(d)
    return Conj([c1, c2, *nat_constraints]), counter
def generate_d_gub(constraint, counter):
    """
    Transform a greatest-upper-bound constraint on dimensions into equalities.

    The result is:
      - ``rhs2`` when ``rhs1`` is Dyn;
      - ``rhs1`` when ``rhs2`` is Dyn;
      - ``rhs1`` when the two are equal.

    Returns:
        The disjunction of these cases, and the unchanged counter.
    """
    first, second, res = constraint.rhs1, constraint.rhs2, constraint.res

    first_dyn = Conj([
        BinConstraintD(first, Dyn, op_eq),
        BinConstraintD(res, second, op_eq),
    ])
    second_dyn = Conj([
        BinConstraintD(second, Dyn, op_eq),
        BinConstraintD(res, first, op_eq),
    ])
    equal = Conj([
        BinConstraintD(second, first, op_eq),
        BinConstraintD(res, first, op_eq),
    ])

    return Disj([first_dyn, second_dyn, equal]), counter
def generate_conj(constraint, counter):
    """
    Transform a conjunction by transforming each conjunct in order,
    threading the variable counter through every call.

    Returns:
        The conjunction of the transformed conjuncts, and the updated counter.
    """
    transformed = []
    for conjunct in constraint.conjucts:
        result, counter = transform_constraint(conjunct, counter)
        transformed.append(result)
    return Conj(transformed), counter
def apply_padding(e1_var: TVar,
                  e11: BinConstraintT,
                  e2: BinConstraintT,
                  e12: BinConstraintT,
                  d2: List[DVar],
                  d11: List[DVar],
                  d12: List[DVar],
                  counter: int):
    """
    Handle the case where the first input has fewer dimensions than the
    second, by padding it on the left before broadcasting.

    For every rank i with 1 <= i < len(d2), the first input is constrained to
    a rank-i tensor. Its missing leading positions are simulated with None
    padding and may broadcast against d2 (``broadcast_dim(..., True)``). Its
    i real dimensions are broadcast against the trailing i dimensions of the
    second input.

    Args:
        e1_var: Variable representing the first input, the one being padded
        e11: constraint of the form e11 = TensorType[d1, ..., dn]
        e2: constraint of the form e2 = TensorType[d1, ..., dn]
        e12: constraint of the form e12 = TensorType[d1, ..., dn]
        d2: Tensor variables for the second input
        d11: Tensor variables for the broadcasted first input
        d12: Tensor variables for the broadcasted second input
        counter: variable tracking

    Returns:
        A disjunction over the possible ranks of the first input, and the
        updated counter.
    """
    res = []

    for i in range(1, len(d2)):
        d1, counter = gen_tensor_dims(i, counter)
        nat_constraints = gen_nat_constraints(d1 + d2 + d11 + d12)
        e1 = BinConstraintT(e1_var, TensorType(d1), op_eq)

        # number of leading dimensions the first input is missing
        pad = len(d2) - i

        # pad the shorter input with None so we can pass it to the broadcasting helper
        simulate_padding = [None] * pad
        assert len(simulate_padding + d1) == len(d2)

        # for every padded position, we also consider broadcasting
        broadcast_padding = [broadcast_dim(simulate_padding, d2, d11, d12, j, True)
                             for j in range(pad)]

        # the padded positions are handled above, so only the trailing
        # dimensions are considered for regular broadcasting
        all_broadcasting_possibilities = generate_all_broadcasting_possibilities_no_padding(
            d1, d2[pad:], d11[pad:], d12[pad:])

        res.append(Conj([
            e1, e11, e2, e12,
            *broadcast_padding,
            all_broadcasting_possibilities,
            *nat_constraints,
        ]))

    return Disj(res), counter
def gen_all_reshape_possibilities(list_of_dims, target):
    """
    Enumerate the number/Dyn possibilities for the input dimensions of a reshape.

    Each possibility (from ``generate_all_int_dyn_dim_possibilities``) marks
    every input dimension as equal or not equal to Dyn. The target is fixed,
    because at most one of its dimensions can be Dyn. Per possibility:
      - no dimension known to be a number: only the possibility itself;
      - some, but not all, known: the target must be divisible by their product;
      - all known: the products of input and target dimensions must match.

    Args:
        list_of_dims: The input list of dimensions
        target: The tensor we want to reshape to

    Returns:
        A disjunction of transformed reshape constraints
    """
    disjuncts = []
    for possibility in generate_all_int_dyn_dim_possibilities(list_of_dims):
        possibility = list(possibility)
        for c in possibility:
            assert isinstance(c, BinConstraintD)
        known_dims = [c.lhs for c in possibility if c.op == op_neq]

        if not known_dims:
            extra = []
        elif len(known_dims) < len(list_of_dims):
            extra = [is_target_div_by_dim(target, Prod(known_dims))]
        else:
            extra = [BinConstraintD(Prod(list_of_dims), Prod(target), op_eq)]

        disjuncts.append(Conj(possibility + extra))

    return Disj(disjuncts)
def torch_dim_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for querying the number of dimensions of a tensor.

    Either the input is Dyn and the result is Dyn, or the input is a tensor
    of some rank r (1 <= r <= MAX_TENSOR_RANK) and the result equals r.

    Returns:
        A one-element list holding the disjunction, and the updated counter.
    """
    assert isinstance(n.args[0], Node)

    rank_var, counter = gen_dvar(counter)
    symbols[n] = rank_var
    source = symbols[n.args[0]]

    dyn_case = Conj([
        BinConstraintT(source, Dyn, op_eq),
        BinConstraintD(rank_var, Dyn, op_eq),
    ])

    per_rank = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        dims, counter = gen_tensor_dims(rank, counter)
        per_rank.append(Conj([
            BinConstraintT(source, TensorType(dims), op_eq),
            BinConstraintD(rank_var, rank, op_eq),
        ]))

    return [Disj([dyn_case, Disj(per_rank)])], counter
def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter):
    """
    Constraints for a layer-norm node.

    Either input and output are both Dyn, or both are the same tensor type of
    some rank, and that shape satisfies ``add_layer_norm_constraints`` with
    respect to ``normalized_shape``.

    Returns:
        A one-element list holding the disjunction, and the updated counter.
    """
    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var
    in_var = symbols[n.args[0]]

    both_dyn = Conj([
        BinConstraintT(in_var, Dyn, op_eq),
        BinConstraintT(out_var, Dyn, op_eq),
    ])

    per_rank = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        dims, counter = gen_tensor_dims(rank, counter)
        nat = gen_nat_constraints(dims)
        same_shape = [
            BinConstraintT(in_var, TensorType(dims), op_eq),
            BinConstraintT(out_var, TensorType(dims), op_eq),
        ]
        norm_rel = add_layer_norm_constraints(dims, list(normalized_shape))
        per_rank.append(Conj(same_shape + norm_rel + nat))

    return [Disj([both_dyn, Disj(per_rank)])], counter
def generate_constraints(self, counter=0):
    """
    Iterate through every node of ``self.traced.graph`` and generate constraints.

    Per-node constraints come from ``self.generate_constraints_node``, with the
    variable counter threaded through each call.

    Args:
        counter: starting value for fresh-variable generation

    Returns:
        The conjunction of all per-node constraints, and the updated counter.
        This method does not store the result on ``self``; callers use the
        returned value.
    """
    graph = self.traced.graph
    all_constraints = []
    for n in graph.nodes:
        constraints, counter = self.generate_constraints_node(n, counter)
        all_constraints.extend(constraints)
    return Conj(all_constraints), counter
def size_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for ``size``.

    With no extra argument (e.g. ``size = input_ids.size()``) the constraint
    is simply input = output. With an integer argument, the output is a
    natural-number dimension selected via GetItem for every candidate rank,
    or Dyn when the input is Dyn.

    Raises:
        NotImplementedError: for any other argument shape.
    """
    arg_count = len(n.args)

    if arg_count == 1:
        size, counter = gen_tvar(counter)
        symbols[n] = size
        source = symbols[n.args[0]]
        return [BinConstraintT(source, size, op_eq)], counter

    if arg_count != 2 or not isinstance(n.args[1], int):
        raise NotImplementedError

    # TODO: review this rule; should input = dyn; output = dyn be included here?
    index = n.args[1]
    size_index, counter = gen_dvar(counter)
    symbols[n] = size_index
    source = symbols[n.args[0]]

    per_rank = [GetItem(rank, index, size_index, source)
                for rank in range(1, MAX_TENSOR_RANK + 1)]
    non_negative = BinConstraintD(0, size_index, op_leq)
    dyn_case = Conj([
        BinConstraintT(source, Dyn, op_eq),
        BinConstraintD(size_index, Dyn, op_eq),
    ])

    return [Disj([dyn_case, Conj([Disj(per_rank), non_negative])])], counter
def generate_calc_maxpool(constraint, counter):
    """
    Transform maxpool constraints.

    The maxpool result is a rank-4 tensor ``[d0, d1, d2, d3]``. Its first two
    dimensions equal ``constraint.matching_constraint[0]`` and ``[1]``, the
    last two come from ``calc_last_two_dims``, and every dimension is
    non-negative.

    Returns:
        The conjunction of these constraints, and the updated counter.
    """
    d, counter = gen_tensor_dims(4, counter)

    result_shape = BinConstraintT(constraint.maxpool_result,
                                  TensorType([d[0], d[1], d[2], d[3]]),
                                  op_eq)

    # the first and second output dimensions match the input's
    second_dim = BinConstraintD(constraint.matching_constraint[1], d[1], op_eq)
    first_dim = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq)

    third_dim, fourth_dim = calc_last_two_dims(constraint, d)

    non_negative = Conj([BinConstraintD(0, d[k], op_leq) for k in range(4)])

    return Conj([result_shape, second_dim, first_dim, third_dim, fourth_dim, non_negative]), counter