def generate_calc_product(constraint, counter):
    """
    Transform flatten constraints.

    The dimensions in ``constraint.dims_to_flatten`` are split into
    ``dims[:start]``, ``dims[start:end]`` and ``dims[end:]``, and the middle
    part is collapsed into a single dimension. One alternative is produced for
    every possibility returned by ``generate_all_int_dyn_dim_possibilities``
    for the middle part:

    * if the possibility contains a constraint whose op is not ``op_neq``
      (presumably one pinning a middle dimension to Dyn), the collapsed
      dimension is ``Dyn``;
    * otherwise a fresh dimension variable is generated and constrained to
      equal ``Prod(mid)`` and to differ from ``Dyn``.

    Alternatives whose resulting rank exceeds 4 are replaced by ``F()``.

    Args:
        constraint: flatten constraint (reads start, end, dims_to_flatten,
            flattened)
        counter: counter used to generate fresh variables

    Returns:
        (Conj([Disj(alternatives), boundary_constraint]), updated counter)
    """
    start, end = constraint.start, constraint.end
    dims = constraint.dims_to_flatten
    flattened = constraint.flattened

    # The bounds are plain Python ints, so this is decided here rather than
    # being left to the solver.
    in_bounds = 0 <= start < end <= len(dims)
    c_boundary = T() if in_bounds else F()

    lhs, mid, rhs = dims[0:start], dims[start:end], dims[end:]

    alternatives = []
    for possibility in generate_all_int_dyn_dim_possibilities(mid):
        possibility = list(possibility)
        has_dyn = not all(c.op == op_neq for c in possibility)

        if has_dyn:
            collapsed = Dyn
            product_constraints = []
        else:
            collapsed, counter = gen_dvar(counter)
            product_constraints = [Conj([
                BinConstraintD(collapsed, Prod(mid), op_eq),
                BinConstraintD(collapsed, Dyn, op_neq),
            ])]

        new_dims = lhs + [collapsed] + rhs
        if len(new_dims) > 4:
            # rank above 4 is not supported
            alternatives.append(F())
            continue

        alternatives.append(Conj(
            [BinConstraintT(flattened, TensorType(new_dims), op_eq)]
            + product_constraints
            + possibility
        ))

    return Conj([Disj(alternatives), c_boundary]), counter
def range_check(i, n):
    """
    Checks if an index i is within range of a size n list.

    Python indexing rules are used, matching what ``dims[i]`` would accept:
    a non-negative index is valid when ``i < n``; a negative index counts
    from the end and is valid when ``i >= -n``.

    Args:
        i: index
        n: list size

    Returns: T() if the index is in range, F() otherwise
    """
    if i >= 0:
        return T() if i < n else F()
    # Negative index: -1 is the last element and -n the first. The previous
    # check (i >= n) could never hold for a negative i and a non-negative n.
    return T() if i >= -n else F()
def transform_get_item_tensor(constraint, counter):
    """
    When the index is a tuple, then the output will be a tensor.

    TODO: we have to check if this is the case for all HF models

    The cases we are covering here are a tuple with one of:
      - slice with default arguments (``:``): consumes one input dimension
        and does not change the rank
      - None: appends a dimension of size 1 at that position, so each
        occurrence of None increases the rank by 1

    Args:
        constraint: get-item constraint (reads index_tuple, tensor_size,
            input_var and res)
        counter: counter used to generate fresh variables

    Returns:
        (constraint, updated counter). ``F()`` is returned when the index has
        more slices than the input has dimensions (a type error), or when the
        result would have rank greater than 4.

    Raises:
        NotImplementedError: if the index tuple contains anything other than
            None or a default slice.
    """
    assert isinstance(constraint.index_tuple, tuple)
    index_tuple = constraint.index_tuple

    # Reject unsupported entries before doing any positional bookkeeping.
    for entry in index_tuple:
        if entry is None or entry == slice(None, None, None):
            continue
        raise NotImplementedError('Method not yet implemented')

    # generate a result tensor of the expected size
    dims, counter = gen_tensor_dims(constraint.tensor_size, counter)
    nat_constraints = gen_nat_constraints(dims)

    none_c = index_tuple.count(None)
    rank = none_c + len(dims)

    # More slices than input dimensions is a type error. Without this guard,
    # a None placed after the last valid output position would index past
    # the end of the placeholder list below and raise IndexError.
    if len(index_tuple) > rank:
        return F(), counter

    # check if the index is valid
    is_valid_index = valid_index_tensor(index_tuple, dims)

    # Placeholder of the output rank: 1 at every position where the index
    # has None, then the remaining slots are filled left to right with the
    # input dimensions.
    resulting_tensor_dims = rank * [None]
    for i, entry in enumerate(index_tuple):
        if entry is None:
            resulting_tensor_dims[i] = 1

    remaining_dims = iter(dims)
    for i, slot in enumerate(resulting_tensor_dims):
        if slot is None:
            resulting_tensor_dims[i] = next(remaining_dims)

    # check if the resulting tensor is within bounds
    if len(resulting_tensor_dims) > 4:
        return F(), counter

    constraints = [
        BinConstraintT(constraint.input_var, TensorType(dims), op_eq),
        BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq),
        *nat_constraints,
        is_valid_index,
    ]
    return Conj(constraints), counter
def valid_index(index, dims):
    """
    Report whether ``index`` can be used to subscript the list ``dims``.

    The check attempts ``dims[index]`` directly, so Python's own indexing
    rules decide (negative indices included).

    Returns: T() if the subscript succeeds, F() if it raises IndexError
    """
    try:
        dims[index]
        verdict = T()
    except IndexError:
        verdict = F()
    return verdict
def valid_index_tensor(index, dims):
    """
    Check that a tuple index does not use more slices than there are dimensions.

    Each ``slice`` entry in ``index`` consumes one dimension of ``dims``.
    Using more slices than dimensions is a type error.

    Returns: F() if the number of slices exceeds len(dims), T() otherwise
    """
    n_slices = sum(1 for entry in index if isinstance(entry, slice))
    return F() if n_slices > len(dims) else T()
def add_layer_norm_constraints(input_dim, normalized_dim):
    """
    Constrain the trailing dimensions of a layer-norm input to match
    ``normalized_dim``.

    The input type must have the form ``[*, 1024, 1024]`` when
    ``normalized_dim`` is ``[1024, 1024]``. Dimensions are paired from the
    right, and each pair gets an ``op_consistency`` constraint.

    Args:
        input_dim: Input shape of layer norm
        normalized_dim: normalized_dim parameter of the module instance

    Returns:
        ``[F()]`` if normalized_dim has more entries than input_dim (pattern
        mismatch). Otherwise, a list of BinConstraintD, one per paired
        dimension, starting from the last dimension.
    """
    if len(normalized_dim) > len(input_dim):
        # pattern mismatch: cannot be satisfied
        return [F()]
    return [
        BinConstraintD(in_d, norm_d, op_consistency)
        for in_d, norm_d in zip(reversed(input_dim), reversed(normalized_dim))
    ]