def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Generate constraints for a linear module call.

    Input and output sizes should be the same except for the last dimension.
    If the input is Dyn, then so should the output.

    This delegates to ``linear_constraints`` so that every linear rule calls
    ``add_linear_constraints`` with the same
    ``(dims1, dims2, in_features, out_features)`` arguments.

    Args:
        n: the node being typed; ``n.args[0]`` is the input node
        module_instance: the linear module; its ``in_features`` and
            ``out_features`` attributes are read
        symbols: map from nodes to their constraint variables; ``n`` is added
        constraints: not used by this rule
        counter: variable tracking

    Returns:
        A one-element list holding the Dyn-or-tensor disjunction, and the
        updated counter.
    """
    assert isinstance(n.args[0], Node)
    return linear_constraints(n, module_instance.in_features,
                              module_instance.out_features, symbols, counter)
def gen_consistency_constraints(constraint: Constraint, counter: int):
    """
    Expand a tensor-level consistency constraint into dimension-level ones.

    For each rank from 1 to MAX_TENSOR_RANK, the two sides of ``constraint``
    are bound to fresh tensors of that rank. Each pair of corresponding
    dimensions must be consistent, and all dimensions must be natural numbers.

    Args:
        constraint: Consistency constraint on tensors (reads ``lhs`` and ``rhs``)
        counter: for variable tracking

    Returns:
        A list with one conjunction per rank, and the updated counter.
    """
    per_rank = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        lhs_dims, counter = gen_tensor_dims(rank, counter)
        rhs_dims, counter = gen_tensor_dims(rank, counter)
        naturals = gen_nat_constraints(lhs_dims + rhs_dims)

        shape_eqs = [
            BinConstraintT(constraint.lhs, TensorType(lhs_dims), op_eq),
            BinConstraintT(constraint.rhs, TensorType(rhs_dims), op_eq),
        ]
        pairwise = [BinConstraintD(left, right, op_consistency)
                    for left, right in zip(lhs_dims, rhs_dims)]

        per_rank.append(Conj(shape_eqs + pairwise + naturals))
    return per_rank, counter
def bmm_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for batch matrix multiplication.

    Both inputs are matched against rank-3 tensors and the output dimensions
    are picked from them. Four alternatives are generated:
      * both inputs Dyn -> output Dyn
      * first input Dyn -> output [b2, Dyn, p2], taken from the second input
      * second input Dyn -> output [b1, n1, Dyn], taken from the first input
      * both tensors -> output [batch, n1, p2], where the two batch sizes must
        be consistent and ``batch`` is their greatest upper bound

    Args:
        n: node whose first two args are the input nodes
        symbols: map from nodes to their constraint variables; ``n`` is added
        constraints: not used by this rule
        counter: variable tracking

    Returns:
        A one-element list holding the disjunction, and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], Node)

    output, counter = gen_tvar(counter)
    symbols[n] = output
    lhs = symbols[n.args[0]]
    rhs = symbols[n.args[1]]

    lhs_dims, counter = gen_tensor_dims(3, counter)
    rhs_dims, counter = gen_tensor_dims(3, counter)

    both_dyn = Conj([
        BinConstraintT(lhs, Dyn, op_eq),
        BinConstraintT(rhs, Dyn, op_eq),
        BinConstraintT(output, Dyn, op_eq),
    ])

    lhs_dyn = Conj([
        BinConstraintT(lhs, Dyn, op_eq),
        BinConstraintT(rhs, TensorType(rhs_dims), op_eq),
        BinConstraintT(output, TensorType([rhs_dims[0], Dyn, rhs_dims[2]]), op_eq),
    ])

    rhs_dyn = Conj([
        BinConstraintT(rhs, Dyn, op_eq),
        BinConstraintT(lhs, TensorType(lhs_dims), op_eq),
        BinConstraintT(output, TensorType([lhs_dims[0], lhs_dims[1], Dyn]), op_eq),
    ])

    batch_size, counter = gen_dvar(counter)

    # NOTE(review): only the batch dimensions are checked for consistency;
    # the inner dimensions (lhs_dims[2], rhs_dims[1]) are left unconstrained
    # here — confirm this is intended.
    both_tensors = Conj([
        BinConstraintT(lhs, TensorType(lhs_dims), op_eq),
        BinConstraintT(rhs, TensorType(rhs_dims), op_eq),
        BinConstraintT(output, TensorType([batch_size, lhs_dims[1], rhs_dims[2]]), op_eq),
        BinConstraintD(lhs_dims[0], rhs_dims[0], op_consistency),
        DGreatestUpperBound(batch_size, lhs_dims[0], rhs_dims[0]),
    ])

    return [Disj([both_dyn, lhs_dyn, rhs_dyn, both_tensors])], counter
def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Generate constraints for a layer-norm module call.

    Input and output shapes should be equal.
    Input should be consistent with the normalized_shape.

    The constraint construction is identical to ``gen_layer_norm_constraints``,
    so this rule delegates to it rather than keeping a second copy.

    Args:
        n: the node being typed; ``n.args[0]`` is the input node
        module_instance: the module; its ``normalized_shape`` is read
        symbols: map from nodes to their constraint variables; ``n`` is added
        constraints: not used by this rule
        counter: variable tracking

    Returns:
        A one-element list holding the Dyn-or-tensor disjunction, and the
        updated counter.
    """
    assert isinstance(n.args[0], Node)
    return gen_layer_norm_constraints(n, module_instance.normalized_shape, symbols, counter)
def cumsum_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for a cumulative-sum call.

    Input and output shapes are equal (or both Dyn), and for each candidate
    rank the reduction dimension must pass ``range_check`` for that rank.

    The dimension is read from ``n.args[1]`` when present, otherwise from
    ``n.kwargs["dim"]``, and must be an int.

    Args:
        n: the node being typed
        symbols: map from nodes to their constraint variables; ``n`` is added
        constraints: not used by this rule
        counter: variable tracking

    Returns:
        A one-element list holding the disjunction, and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    if len(n.args) > 1:
        dim = n.args[1]
    else:
        dim = n.kwargs["dim"]
    assert isinstance(dim, int)

    output, counter = gen_tvar(counter)
    symbols[n] = output
    source = symbols[n.args[0]]

    both_dyn = Conj([BinConstraintT(source, Dyn, op_eq),
                     BinConstraintT(output, Dyn, op_eq)])

    same_shape = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        dims, counter = gen_tensor_dims(rank, counter)
        same_shape.append(Conj([
            BinConstraintT(source, TensorType(dims), op_eq),
            BinConstraintT(output, TensorType(dims), op_eq),
            range_check(dim, rank),
            *gen_nat_constraints(dims),
        ]))

    return [Disj([both_dyn, Disj(same_shape)])], counter
def generate_calc_conv(constraint, counter):
    """
    Transform a convolution-calculation constraint into dimension constraints.

    The result is a fresh rank-4 tensor. Its second dimension equals
    ``constraint.c_out`` and is not Dyn, and its first dimension equals the
    first matched input dimension. The last two dimensions come from
    ``calc_last_two_dims``, and all four are non-negative.

    Args:
        constraint: reads ``conv_result``, ``c_out`` and ``matching_constraint``
        counter: variable tracking

    Returns:
        The conjunction of the above, and the updated counter.
    """
    d, counter = gen_tensor_dims(4, counter)
    first, second, third, fourth = d

    # the convolution result is a tensor of size 4
    result_shape = BinConstraintT(constraint.conv_result,
                                  TensorType([first, second, third, fourth]), op_eq)

    # the second dimension of the output is equal to the output channels
    out_channels = Conj([
        BinConstraintD(second, constraint.c_out, op_eq),
        BinConstraintD(second, Dyn, op_neq),
    ])

    # the input corresponds to the output in the first dimension
    first_dim_eq = BinConstraintD(constraint.matching_constraint[0], first, op_eq)

    last_dims_a, last_dims_b = calc_last_two_dims(constraint, d)

    non_negative = Conj([BinConstraintD(0, dim, op_leq) for dim in d])

    return Conj([result_shape, out_channels, first_dim_eq,
                 last_dims_a, last_dims_b, non_negative]), counter
def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Generate constraints for a 2D convolution module call.

    The input is matched against a rank-4 tensor [d1, d2, d3, d4]. ``d2`` must
    be consistent with the module's ``in_channels``. A ``CalcConv`` constraint
    relates input and output through the module's ``out_channels``,
    ``kernel_size``, ``padding``, ``stride`` and ``dilation``.

    Args:
        n: the node being typed; ``n.args[0]`` is the input node
        module_instance: the convolution module whose attributes are read
        symbols: map from nodes to their constraint variables; ``n`` is added
        constraints: not used by this rule
        counter: variable tracking

    Returns:
        The list of generated constraints, and the updated counter.
    """
    assert isinstance(n.args[0], Node)

    my_conv, counter = gen_tvar(counter)
    symbols[n] = my_conv
    input_var = symbols[n.args[0]]

    # the input is matched against exactly four dims; request 4 explicitly so
    # the unpacking does not silently depend on MAX_TENSOR_RANK being 4
    [d1, d2, d3, d4], counter = gen_tensor_dims(4, counter)

    c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
    c2 = BinConstraintD(module_instance.in_channels, d2, op_consistency)
    c3 = CalcConv(my_conv, input_var,
                  module_instance.out_channels,
                  module_instance.kernel_size,
                  module_instance.padding,
                  module_instance.stride,
                  module_instance.dilation, [d1, d2, d3, d4])

    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
    return [c1, c2, c3, *nat_constraints], counter
def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Generate constraints for an embedding module call.

    The output shape differs from the input shape in the last dimension: the
    module's ``embedding_dim`` is appended to the input dims.

    The construction is identical to ``gen_embedding_rules``, so this rule
    delegates to it instead of keeping a second copy.

    Args:
        n: the node being typed; ``n.args[0]`` is the input node
        module_instance: the module; its ``embedding_dim`` is read
        symbols: map from nodes to their constraint variables; ``n`` is added
        constraints: not used by this rule
        counter: variable tracking

    Returns:
        A one-element list holding the Dyn-or-tensor disjunction, and the
        updated counter.
    """
    assert isinstance(n.args[0], Node)
    return gen_embedding_rules(n, symbols, module_instance.embedding_dim, counter)
def index_select_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for an index-select call with args (input, dim, index).

    The index argument (``n.args[2]``) is constrained to be a rank-1 tensor or
    Dyn. For every input rank up to MAX_TENSOR_RANK, an ``IndexSelect``
    constraint relates the input (``n.args[0]``) to the output. The position
    is given by the int ``n.args[1]``, and the replacement dimension is the
    index vector's length (or Dyn).

    Args:
        n: the node being typed
        symbols: map from nodes to their constraint variables; ``n`` is added
        constraints: not used by this rule
        counter: variable tracking

    Returns:
        A one-element list holding the disjunction, and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], int)
    assert isinstance(n.args[2], Node)

    output, counter = gen_tvar(counter)
    symbols[n] = output

    index_dims, counter = gen_tensor_dims(1, counter)

    source = symbols[n.args[0]]
    index_var = symbols[n.args[2]]
    position = n.args[1]

    def select_for_every_rank(replacement):
        return Disj([IndexSelect(rank, source, replacement, position, output)
                     for rank in range(1, MAX_TENSOR_RANK + 1)])

    index_is_vector = Conj([
        BinConstraintT(index_var, TensorType(index_dims), op_eq),
        select_for_every_rank(index_dims[0]),
    ])
    index_is_dyn = Conj([
        BinConstraintT(index_var, Dyn, op_eq),
        select_for_every_rank(Dyn),
    ])

    return [Disj([index_is_vector, index_is_dyn])], counter
def expand_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for an expand call.

    We generate the same broadcasting constraints as for tensor addition,
    between the input and an auxiliary tensor variable holding the requested
    sizes (``n.args[1:]``). The output rank is then pinned to
    ``len(n.args[1:])``, so only those cases are considered for the output.

    Each requested size must be an int or a node whose symbol is a DVar. Node
    sizes are additionally constrained to be non-negative.

    Args:
        n: the node being typed
        symbols: map from nodes to their constraint variables; ``n`` is added
        constraints: ignored; replaced by the broadcasting constraints
        counter: variable tracking

    Returns:
        The list of generated constraints, and the updated counter.
    """
    assert isinstance(n.args[0], Node)

    # define the output for expand
    expand, counter = gen_tvar(counter)
    symbols[n] = expand

    # there is no second node, so an auxiliary variable stands for the sizes
    e1 = symbols[n.args[0]]
    e2, counter = gen_tvar(counter)

    sizes = n.args[1:]
    size_nat_constraints = []
    for size in sizes:
        assert isinstance(size, (Node, int))
        if isinstance(size, Node):
            assert isinstance(symbols[size], DVar)
            size_nat_constraints.append(BinConstraintD(0, symbols[size], op_leq))

    target_dims = [size if isinstance(size, int) else symbols[size] for size in sizes]
    e2_constraint = BinConstraintT(e2, TensorType(target_dims), op_eq)

    result, counter = gen_broadcasting_constraints(e1, e2, symbols, counter, expand)

    # constrain the output size
    out_dims, counter = gen_tensor_dims(len(sizes), counter)
    result += [
        BinConstraintT(expand, TensorType(out_dims), op_eq),
        *gen_nat_constraints(out_dims),
        e2_constraint,
        *size_nat_constraints,
    ]
    return result, counter
def transform_get_item(constraint, counter):
    """
    Simplify a GetItem constraint for a given tensor rank.

    Generates ``input_var = [a1, ..., an]`` (with ``n = tensor_size``) and a
    validity check for ``constraint.index``. Only when that check is ``T()``
    is ``res`` equated with the indexed dimension. Otherwise the clause is
    left holding the failing validity check. The Dyn input case is handled
    before this step.

    Args:
        constraint: GetItem on a tensor (not Dyn)
        counter: variable tracking

    Returns:
        Simplified constraints for GetItem, and the updated counter.
    """
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    index_ok = valid_index(constraint.index, dims)

    clauses = [
        BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
        *gen_nat_constraints(dims),
        index_ok,
    ]

    if index_ok == T():
        clauses.append(BinConstraintD(constraint.res, dims[constraint.index], op_eq))

    return Conj(clauses), counter
def gen_embedding_rules(n: Node, symbols, embedding_dim, counter):
    """
    Shared constraint generator for embedding lookups.

    Either input and output are both Dyn, or the output equals the input dims
    with ``embedding_dim`` appended. Input ranks run from 1 to
    MAX_TENSOR_RANK - 1, so the output rank (input rank + 1) never exceeds
    MAX_TENSOR_RANK.

    Args:
        n: the node being typed; ``n.args[0]`` is the input node
        symbols: map from nodes to their constraint variables; ``n`` is added
        embedding_dim: size (or dimension variable) appended to the output
        counter: variable tracking

    Returns:
        A one-element list holding the disjunction, and the updated counter.
    """
    output, counter = gen_tvar(counter)
    symbols[n] = output
    source = symbols[n.args[0]]

    both_dyn = Conj([BinConstraintT(source, Dyn, op_eq),
                     BinConstraintT(output, Dyn, op_eq)])

    per_rank = []
    for rank in range(1, MAX_TENSOR_RANK):
        dims, counter = gen_tensor_dims(rank, counter)
        per_rank.append(Conj([
            BinConstraintT(source, TensorType(dims), op_eq),
            BinConstraintT(output, TensorType(dims + [embedding_dim]), op_eq),
            *gen_nat_constraints(dims),
        ]))

    return [Disj([both_dyn, Disj(per_rank)])], counter
def transform_get_item_tensor(constraint, counter):
    """
    Simplify a tuple-indexed get-item, whose result is a tensor.

    TODO: we have to check if this is the case for all HF models

    The supported index entries are:
     - a slice with default arguments (``:``), which keeps one input dimension
     - None, which inserts a dimension of size 1 (so each None adds 1 to rank)

    Args:
        constraint: reads ``index_tuple``, ``tensor_size``, ``input_var``
            and ``res``
        counter: variable tracking

    Returns:
        ``F()`` if the tuple has more slices than the input has dimensions, or
        if the result rank exceeds MAX_TENSOR_RANK. Otherwise, a conjunction
        relating the input and result shapes. In both cases the updated
        counter is also returned.

    Raises:
        NotImplementedError: for index entries other than None or ``:``.
    """
    assert isinstance(constraint.index_tuple, tuple)
    index_tuple = constraint.index_tuple

    # generate an input tensor of the expected size
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    nat_constraints = gen_nat_constraints(dims)

    for entry in index_tuple:
        if entry is None or entry == slice(None, None, None):
            continue
        raise NotImplementedError('Method not yet implemented')

    none_count = index_tuple.count(None)
    slice_count = len(index_tuple) - none_count

    # more slices than input dimensions is an invalid index; bail out before
    # the placement below would index past the end of the result
    if slice_count > len(dims):
        return F(), counter

    # check if the resulting tensor is within bounds
    result_rank = none_count + len(dims)
    if result_rank > MAX_TENSOR_RANK:
        return F(), counter

    # put a 1 at every None position, then fill the remaining slots with the
    # input dims in order (each default slice keeps its dimension)
    resulting_tensor_dims = [None] * result_rank
    for position, entry in enumerate(index_tuple):
        if entry is None:
            resulting_tensor_dims[position] = 1
    remaining_dims = iter(dims)
    resulting_tensor_dims = [next(remaining_dims) if slot is None else slot
                             for slot in resulting_tensor_dims]

    # check if the index is valid
    is_valid_index = valid_index_tensor(index_tuple, dims)

    constraints = [
        BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
        BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq),
        *nat_constraints,
        is_valid_index,
    ]
    return Conj(constraints), counter
def generate_flatten_constraints(start_dim, end_dim, input, flattened, n, counter):
    """
    Constraints for flattening a rank-``n`` input into ``flattened``.

    ``start_dim`` and ``end_dim`` are inclusive and may be negative (counted
    from the end, so -k means dimension ``n - k``). They are passed to
    ``CalcProduct`` as a half-open range ``[start, end)``. That is why
    ``end`` gets ``+ 1`` while ``start`` stays inclusive.

    Args:
        start_dim: first dimension to flatten (inclusive, may be negative)
        end_dim: last dimension to flatten (inclusive, may be negative)
        input: variable for the tensor being flattened
        flattened: variable for the result
        n: rank assumed for the input in this clause
        counter: variable tracking

    Returns:
        The conjunction of shape, product and naturality constraints, and the
        updated counter.
    """
    d, counter = gen_tensor_dims(n, counter)
    c1 = BinConstraintT(input, TensorType(d), op_eq)

    # normalise negative indices the same way for both ends; abs() was only
    # right by coincidence for some ranks, and -1 must map to n - 1, not n
    start = n + start_dim if start_dim < 0 else start_dim
    end = n + end_dim + 1 if end_dim < 0 else end_dim + 1

    c2 = CalcProduct(start, end, flattened, d)
    nat_constraints = gen_nat_constraints(d)
    return Conj([c1, c2, *nat_constraints]), counter
def torch_linear_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for a functional linear call ``linear(input, weight, ...)``.

    The weight (``n.args[1]``) is constrained to a rank-2 tensor. A functional
    linear weight is laid out as ``(out_features, in_features)``, so
    ``weight_dims[1]`` feeds ``in_features`` and ``weight_dims[0]`` feeds
    ``out_features`` of the shared linear constraints.

    Args:
        n: the node being typed; ``n.args[0]`` is the input, ``n.args[1]``
            the weight
        symbols: map from nodes to their constraint variables; ``n`` is added
        constraints: ignored; replaced by the generated constraints
        counter: variable tracking

    Returns:
        The weight equality constraint followed by the linear constraints,
        and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], Node)

    weight_dims, counter = gen_tensor_dims(2, counter)
    # constrain the weight's variable, not the Node object itself
    equality_constraint = BinConstraintT(symbols[n.args[1]], TensorType(weight_dims), op_eq)
    constraints, counter = linear_constraints(n, weight_dims[1], weight_dims[0], symbols, counter)
    return [equality_constraint] + constraints, counter
def apply_padding(e1_var: TVar,
                  e11: BinConstraintT,
                  e2: BinConstraintT,
                  e12: BinConstraintT,
                  d2: List[DVar],
                  d11: List[DVar],
                  d12: List[DVar],
                  counter: int):
    """
    Broadcast when the first input has fewer dimensions than the second.

    For every rank of the first input strictly below ``len(d2)``, it is
    conceptually left-padded with None to the length of ``d2``. Each padded
    position is broadcast with ``broadcast_dim(..., padding=True)``, and the
    real dimensions are broadcast against the trailing part of ``d2``.

    Args:
        e1_var: variable for the first input, the one that gets padded
        e11: constraint of the form e11 = TensorType[d1, ..., dn]
        e2: constraint of the form e2 = TensorType[d1, ..., dn]
        e12: constraint of the form e12 = TensorType[d1, ..., dn]
        d2: tensor variables for the second input
        d11: tensor variables for the broadcasted first input
        d12: tensor variables for the broadcasted second input
        counter: variable tracking

    Returns:
        A disjunction with one clause per candidate rank of the first input,
        and the updated counter.
    """
    options = []
    for rank in range(1, len(d2)):
        d1, counter = gen_tensor_dims(rank, counter)
        pad = len(d2) - rank

        nat_constraints = gen_nat_constraints(d1 + d2 + d11 + d12)
        e1 = BinConstraintT(e1_var, TensorType(d1), op_eq)

        # None placeholders prepended to d1 so it lines up with d2
        simulate_padding = [None] * pad
        assert len(simulate_padding + d1) == len(d2)

        # every padded position is broadcast against the second input
        broadcast_padding = [broadcast_dim(simulate_padding, d2, d11, d12, j, True)
                             for j in range(pad)]

        # the real dimensions of d1 only meet the trailing part of d2
        all_broadcasting_possibilities = generate_all_broadcasting_possibilities_no_padding(
            d1, d2[pad:], d11[pad:], d12[pad:])

        options.append(Conj([
            e1, e11, e2, e12,
            *broadcast_padding,
            all_broadcasting_possibilities,
            *nat_constraints,
        ]))

    return Disj(options), counter
def embedding_inference_rule_functional(n: Node, symbols, constraints, counter):
    """
    Constraints for a functional embedding call ``(input, weight, ...)``.

    The weight (``n.args[1]``) is treated as a static rank-2 shape, so an
    equality rather than a matching constraint is used. Its second dimension
    is the embedding size passed to ``gen_embedding_rules``.

    Args:
        n: the node being typed
        symbols: map from nodes to their constraint variables; ``n`` is added
        constraints: ignored; replaced by the generated constraints
        counter: variable tracking

    Returns:
        The weight equality followed by the embedding rules, and the updated
        counter.
    """
    assert isinstance(n.args[0], Node)
    weights = symbols[n.args[1]]

    weight_dims, counter = gen_tensor_dims(2, counter)
    weights_are_rank2 = BinConstraintT(weights, TensorType(weight_dims), op_eq)

    rules, counter = gen_embedding_rules(n, symbols, weight_dims[1], counter)
    return [weights_are_rank2, *rules], counter
def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int):
    """
    Expand a tensor greatest-upper-bound constraint per rank.

    For each rank from 1 to MAX_TENSOR_RANK, ``rhs1``, ``rhs2`` and ``res``
    are bound to fresh tensors of that rank with natural dimensions. Each
    result dimension is the ``DGreatestUpperBound`` of the matching input
    dimensions.

    Args:
        constraint: Greatest upper bound on tensors
        counter: variable tracking

    Returns:
        A list with one conjunction per rank, and the updated counter.
    """
    all_constraints = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        dims1, counter = gen_tensor_dims(rank, counter)
        first = TensorType(dims1)
        dims2, counter = gen_tensor_dims(rank, counter)
        second = TensorType(dims2)
        dims3, counter = gen_tensor_dims(rank, counter)
        result = TensorType(dims3)

        clause = [
            BinConstraintT(constraint.rhs1, first, op_eq),
            BinConstraintT(constraint.rhs2, second, op_eq),
            BinConstraintT(constraint.res, result, op_eq),
        ] + gen_nat_constraints(dims1 + dims2 + dims3)

        assert len(result.__args__) == len(first.__args__) == len(second.__args__)
        clause.extend(DGreatestUpperBound(res_dim, lhs_dim, rhs_dim)
                      for res_dim, lhs_dim, rhs_dim
                      in zip(result.__args__, first.__args__, second.__args__))

        all_constraints.append(Conj(clause))
    return all_constraints, counter
def linear_constraints(n: Node, in_features, out_features, symbols, counter):
    """
    Shared constraint generator for linear layers.

    Either input and output are both Dyn, or, for some rank in
    1..MAX_TENSOR_RANK, both are tensors of that rank. In the tensor case the
    dimensions are related by ``add_linear_constraints`` using
    ``in_features`` and ``out_features``, and all dimensions are natural.

    Args:
        n: the node being typed; ``n.args[0]`` is the input node
        in_features: input feature size (int or dimension variable)
        out_features: output feature size (int or dimension variable)
        symbols: map from nodes to their constraint variables; ``n`` is added
        counter: variable tracking

    Returns:
        A one-element list holding the disjunction, and the updated counter.
    """
    output, counter = gen_tvar(counter)
    symbols[n] = output
    source = symbols[n.args[0]]

    both_dyn = Conj([BinConstraintT(source, Dyn, op_eq),
                     BinConstraintT(output, Dyn, op_eq)])

    per_rank = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        in_dims, counter = gen_tensor_dims(rank, counter)
        out_dims, counter = gen_tensor_dims(rank, counter)
        naturals = gen_nat_constraints(in_dims + out_dims)

        shape_eqs = [
            BinConstraintT(source, TensorType(in_dims), op_eq),
            BinConstraintT(output, TensorType(out_dims), op_eq),
        ]
        feature_rules = add_linear_constraints(in_dims, out_dims, in_features, out_features)
        per_rank.append(Conj(shape_eqs + feature_rules + naturals))

    return [Disj([both_dyn, Disj(per_rank)])], counter
def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Generate constraints for a 2D max-pool module call.

    The input is matched against a rank-4 tensor [d1, d2, d3, d4]. A
    ``CalcMaxPool`` constraint relates input and output through the module's
    ``kernel_size``, ``padding``, ``stride`` and ``dilation``.

    Args:
        n: the node being typed; ``n.args[0]`` is the input node
        module_instance: the pooling module whose attributes are read
        symbols: map from nodes to their constraint variables; ``n`` is added
        constraints: not used by this rule
        counter: variable tracking

    Returns:
        The list of generated constraints, and the updated counter.
    """
    assert isinstance(n.args[0], Node)

    maxpool, counter = gen_tvar(counter)
    symbols[n] = maxpool
    input_var = symbols[n.args[0]]

    # exactly four dims are unpacked; request 4 explicitly so this does not
    # silently depend on MAX_TENSOR_RANK being 4
    [d1, d2, d3, d4], counter = gen_tensor_dims(4, counter)

    c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
    c2 = CalcMaxPool(maxpool, input_var,
                     module_instance.kernel_size,
                     module_instance.padding,
                     module_instance.stride,
                     module_instance.dilation, [d1, d2, d3, d4])

    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])
    return [c1, c2, *nat_constraints], counter
def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int):
    """
    Generate several fresh lists of dimension variables.

    Args:
        num_tensors: how many lists to produce
        dim_size: number of dimension variables in each list
        counter: variable tracking, threaded through every generation

    Returns:
        A list of ``num_tensors`` dimension lists, and the updated counter.
    """
    dim_lists = []
    for _ in range(num_tensors):
        next_dims, counter = gen_tensor_dims(dim_size, counter)
        dim_lists.append(next_dims)
    return dim_lists, counter
def transform_transpose(constraint, counter):
    """
    Simplify a transpose constraint for a given tensor rank.

    Works like two index-selects. The input is bound to fresh dims, both
    indices are validated, and, when both are valid (``T()``), the output is
    the input with the two dimensions swapped. Otherwise the output dims are
    an unswapped copy, and the clause carries the failing validity checks.

    Args:
        constraint: reads ``tensor_size``, ``index1``, ``index2``,
            ``input_var`` and ``output``
        counter: variable tracking

    Returns:
        The conjunction described above, and the updated counter.
    """
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    first_ok = valid_index(constraint.index1, dims)
    second_ok = valid_index(constraint.index2, dims)
    swapped = copy.deepcopy(dims)
    nat_constraints = gen_nat_constraints(dims)

    if first_ok == T() and second_ok == T():
        i, j = constraint.index1, constraint.index2
        swapped[i], swapped[j] = dims[j], dims[i]

    return Conj([
        BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
        *nat_constraints,
        first_ok,
        second_ok,
        BinConstraintT(constraint.output, TensorType(swapped), op_eq),
    ]), counter
def torch_dim_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for a rank query (presumably ``Tensor.dim()`` — verify at the
    registration site).

    The result is a dimension variable. It is Dyn when the input is Dyn, and
    otherwise equals ``i`` when the input is a rank-``i`` tensor, for ``i`` in
    1..MAX_TENSOR_RANK.

    Args:
        n: the node being typed; ``n.args[0]`` is the input node
        symbols: map from nodes to their constraint variables; ``n`` is added
        constraints: not used by this rule
        counter: variable tracking

    Returns:
        A one-element list holding the disjunction, and the updated counter.
    """
    assert isinstance(n.args[0], Node)
    rank_var, counter = gen_dvar(counter)
    symbols[n] = rank_var
    source = symbols[n.args[0]]

    dyn_case = Conj([BinConstraintT(source, Dyn, op_eq),
                     BinConstraintD(rank_var, Dyn, op_eq)])

    rank_cases = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        dims, counter = gen_tensor_dims(rank, counter)
        rank_cases.append(Conj([
            BinConstraintT(source, TensorType(dims), op_eq),
            BinConstraintD(rank_var, rank, op_eq),
        ]))

    return [Disj([dyn_case, Disj(rank_cases)])], counter
def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter):
    """
    Shared constraint generator for layer normalisation.

    Input and output are both Dyn, or share the same fresh tensor dims for
    some rank in 1..MAX_TENSOR_RANK. In the tensor case the dims must also
    satisfy ``add_layer_norm_constraints`` against ``normalized_shape`` and be
    natural.

    Args:
        n: the node being typed; ``n.args[0]`` is the input node
        normalized_shape: iterable of sizes, converted to a list
        symbols: map from nodes to their constraint variables; ``n`` is added
        counter: variable tracking

    Returns:
        A one-element list holding the disjunction, and the updated counter.
    """
    output, counter = gen_tvar(counter)
    symbols[n] = output
    source = symbols[n.args[0]]

    both_dyn = Conj([BinConstraintT(source, Dyn, op_eq),
                     BinConstraintT(output, Dyn, op_eq)])

    per_rank = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        dims, counter = gen_tensor_dims(rank, counter)
        naturals = gen_nat_constraints(dims)
        shape_eqs = [
            BinConstraintT(source, TensorType(dims), op_eq),
            BinConstraintT(output, TensorType(dims), op_eq),
        ]
        norm_rules = add_layer_norm_constraints(dims, list(normalized_shape))
        per_rank.append(Conj(shape_eqs + norm_rules + naturals))

    return [Disj([both_dyn, Disj(per_rank)])], counter
def generate_calc_maxpool(constraint, counter):
    """
    Transform maxpool constraints into dimension constraints.

    The result is a fresh rank-4 tensor whose first two dimensions equal the
    first two matched input dimensions. The last two come from
    ``calc_last_two_dims``, and all four are non-negative.

    Args:
        constraint: reads ``maxpool_result`` and ``matching_constraint``
        counter: variable tracking

    Returns:
        The conjunction of the above, and the updated counter.
    """
    d, counter = gen_tensor_dims(4, counter)
    first, second, third, fourth = d

    # the maxpool result is a tensor of size 4
    result_shape = BinConstraintT(constraint.maxpool_result,
                                  TensorType([first, second, third, fourth]), op_eq)

    # the input corresponds to the output in the first and second dimension
    second_dim_eq = BinConstraintD(constraint.matching_constraint[1], second, op_eq)
    first_dim_eq = BinConstraintD(constraint.matching_constraint[0], first, op_eq)

    last_dims_a, last_dims_b = calc_last_two_dims(constraint, d)

    non_negative = Conj([BinConstraintD(0, dim, op_leq) for dim in d])

    return Conj([result_shape, second_dim_eq, first_dim_eq,
                 last_dims_a, last_dims_b, non_negative]), counter
def transform_index_select(constraint, counter):
    """
    Simplify an IndexSelect constraint for a given tensor rank.

    The input is bound to fresh dims and ``constraint.index`` is validated. If
    the index is valid, the output is the input with that dimension replaced
    by ``constraint.dim_replace``. If not, the dimension is not replaced and
    the clause contains the failing validity check.

    Args:
        constraint: reads ``tensor_size``, ``index``, ``dim_replace``,
            ``input_var`` and ``output``
        counter: variable tracking

    Returns:
        The conjunction described above, and the updated counter.
    """
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    is_valid_index = valid_index(constraint.index, dims)
    nat_constraints = gen_nat_constraints(dims)

    # built unconditionally, so an invalid index yields a clause containing
    # the failed check instead of an unbound-local error
    new_dims = copy.deepcopy(dims)
    if is_valid_index == T():
        new_dims[constraint.index] = constraint.dim_replace

    transformed_constraint = Conj([
        BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
        *nat_constraints,
        is_valid_index,
        BinConstraintT(constraint.output, TensorType(new_dims), op_eq),
    ])
    return transformed_constraint, counter
def generate_reshape(constraint, counter):
    """
    Transform reshape constraints.

    The source (``constraint.src``) is either Dyn or a tensor of rank 1-4, and
    all four fresh dimension variables are natural.

    * Fully static target: a rank-1 source needs ``d1 == Prod(target)``
      unless ``d1`` is Dyn. Ranks 2-4 delegate to
      ``gen_all_reshape_possibilities``.
    * Target containing Dyn (assumed to be exactly one occurrence): either
      some source dim is Dyn, or all are static and ``is_dim_div_by_target``
      must hold for their product against the static part of the target.

    Args:
        constraint: reads ``src`` and ``target`` (a TensorType)
        counter: variable tracking

    Returns:
        The resulting conjunction, and the updated counter.
    """
    d, counter = gen_tensor_dims(4, counter)
    d1, d2, d3, d4 = d

    target = constraint.target.__args__
    is_fully_static = all(dim != Dyn for dim in target)

    # dynamic tensor
    c1_dyn = BinConstraintT(constraint.src, Dyn, op_eq)
    c2_tensor1 = BinConstraintT(constraint.src, TensorType([d1]), op_eq)
    c2_tensor2 = BinConstraintT(constraint.src, TensorType([d1, d2]), op_eq)
    c2_tensor3 = BinConstraintT(constraint.src, TensorType([d1, d2, d3]), op_eq)
    c2_tensor4 = BinConstraintT(constraint.src, TensorType([d1, d2, d3, d4]), op_eq)

    d1_eq_dyn = BinConstraintD(d1, Dyn, op_eq)
    d1_neq_dyn = BinConstraintD(d1, Dyn, op_neq)

    d2_eq_dyn = BinConstraintD(d2, Dyn, op_eq)
    d2_neq_dyn = BinConstraintD(d2, Dyn, op_neq)

    d3_eq_dyn = BinConstraintD(d3, Dyn, op_eq)
    d3_neq_dyn = BinConstraintD(d3, Dyn, op_neq)

    # these must refer to d4, otherwise the fourth dimension is never tested
    d4_eq_dyn = BinConstraintD(d4, Dyn, op_eq)
    d4_neq_dyn = BinConstraintD(d4, Dyn, op_neq)

    nat_d1 = BinConstraintD(0, d1, op_leq)
    nat_d2 = BinConstraintD(0, d2, op_leq)
    nat_d3 = BinConstraintD(0, d3, op_leq)
    nat_d4 = BinConstraintD(0, d4, op_leq)

    if is_fully_static:
        # size 1 tensor
        c3_tensor1 = Disj([
            d1_eq_dyn,
            Conj([d1_neq_dyn, BinConstraintD(d1, Prod(target), op_eq)]),
        ])
        all_tensor_1 = Conj([c2_tensor1, c3_tensor1])

        # size 2 tensor
        all_tensor_2 = Conj([c2_tensor2, gen_all_reshape_possibilities([d1, d2], target)])

        # size 3 tensor
        all_tensor_3 = Conj([c2_tensor3, gen_all_reshape_possibilities([d1, d2, d3], target)])

        # size 4 tensor
        all_tensor_4 = Conj([c2_tensor4, gen_all_reshape_possibilities([d1, d2, d3, d4], target)])

    else:
        # the target is assumed to contain exactly one Dyn
        new_target = [dim for dim in target if dim != Dyn]

        # tensor 1
        c3_tensor1 = Disj([
            d1_eq_dyn,
            Conj([d1_neq_dyn, is_dim_div_by_target(new_target, d1)]),
        ])
        all_tensor_1 = Conj([c2_tensor1, c3_tensor1])

        # tensor 2
        c21 = Disj([d1_eq_dyn, d2_eq_dyn])
        c22 = Conj([d1_neq_dyn, d2_neq_dyn,
                    is_dim_div_by_target(new_target, Prod([d1, d2]))])
        all_tensor_2 = Conj([c2_tensor2, Disj([c21, c22])])

        # tensor 3
        c31 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn])
        c32 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn,
                    is_dim_div_by_target(new_target, Prod([d1, d2, d3]))])
        all_tensor_3 = Conj([c2_tensor3, Disj([c31, c32])])

        # tensor 4
        c41 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn, d4_eq_dyn])
        c42 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, d4_neq_dyn,
                    is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4]))])
        all_tensor_4 = Conj([c2_tensor4, Disj([c41, c42])])

    return Conj([
        Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]),
        nat_d1, nat_d2, nat_d3, nat_d4,
    ]), counter
def neq_inference_rule(n: Node, symbols, constraints, counter):
    """
    Translates ``!=`` to inconsistency in gradual types.

    To prove inequality, we should prove that the tensors are either
    different sizes or disagree on at least one dimension. The input
    (``n.args[0]``) is bound to a tensor with the same rank as the
    compared tuple (``n.args[1]``). The comparison holds if some
    dimension pair is non-Dyn on both sides and unequal. The result
    is a boolean variable equated to that condition.

    This is a WIP (works when the condition is false. We are working on
    making this operation work when the condition is true as well).
    Only tuples of length 3 and 4 are supported.

    Args:
        n: the node being typed
        symbols: map from nodes to their constraint variables (read only)
        constraints: not used by this rule
        counter: variable tracking

    Returns:
        A one-element list with the boolean equality constraint, and the
        updated counter.

    Raises:
        NotImplementedError: when the tuple length is not 3 or 4.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], tuple)

    target = n.args[1]

    # implementing for size 3 and 4
    if len(target) not in (3, 4):
        raise NotImplementedError('Method not yet implemented')

    for entry in target:
        assert isinstance(entry, (Node, int))

    lhs = symbols[n.args[0]]

    # one fresh dimension per compared entry; the same code now serves both
    # ranks, so every dimension is paired with its own variable
    b, counter = gen_tensor_dims(len(target), counter)
    input_has_rank = BinConstraintT(lhs, TensorType(b), op_eq)

    rhs_dims = [entry if isinstance(entry, int) else symbols[entry] for entry in target]

    # a pair is inconsistent when both sides are static and unequal
    dims_inconsistent = Disj([
        Conj([BinConstraintD(rhs_dim, Dyn, op_neq),
              BinConstraintD(lhs_dim, Dyn, op_neq),
              BinConstraintD(rhs_dim, lhs_dim, op_neq)])
        for rhs_dim, lhs_dim in zip(rhs_dims, b)
    ])

    ne_constraint = Conj([input_has_rank, dims_inconsistent])

    my_ne, counter = gen_bvar(counter)
    equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq)

    return [equality_constraint], counter