def gen_embedding_rules(n: Node, symbols, embedding_dim, counter):
    """
    Generate the constraints of an embedding lookup.

    Either the input and the output are both Dyn, or the input is a tensor of
    some rank ``r`` (1 <= r < MAX_TENSOR_RANK) and the output has the same
    dimensions followed by ``embedding_dim``. Because the input rank stops one
    short of MAX_TENSOR_RANK, the output rank is at most MAX_TENSOR_RANK.

    Returns a one-element constraint list and the updated counter.
    """
    output_var, counter = gen_tvar(counter)
    symbols[n] = output_var
    input_var = symbols[n.args[0]]

    both_dyn = Conj([
        BinConstraintT(input_var, Dyn, op_eq),
        BinConstraintT(output_var, Dyn, op_eq),
    ])

    per_rank = []
    for rank in range(1, MAX_TENSOR_RANK):
        dims, counter = gen_tensor_dims(rank, counter)
        nat = gen_nat_constraints(dims)
        # output = input dimensions with embedding_dim appended at the end
        shape_constraints = [
            BinConstraintT(input_var, TensorType(dims), op_eq),
            BinConstraintT(output_var, TensorType(dims + [embedding_dim]), op_eq),
        ]
        per_rank.append(Conj(shape_constraints + nat))

    return [Disj([both_dyn, Disj(per_rank)])], counter
def transpose_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for swapping two dimensions of a tensor.

    A transpose can be considered as a sequence of two index selects. Either
    input and output are both Dyn, or for some rank in 1..MAX_TENSOR_RANK a
    ``Transpose`` constraint relates the input to the output using the two
    dimension indices ``n.args[1]`` and ``n.args[2]``.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], int)
    assert isinstance(n.args[2], int)

    transposed, counter = gen_tvar(counter)
    symbols[n] = transposed

    source = symbols[n.args[0]]
    assert isinstance(source, TVar)

    # both sides unknown
    both_dyn = Conj([
        BinConstraintT(source, Dyn, op_eq),
        BinConstraintT(transposed, Dyn, op_eq),
    ])
    # the input is a tensor of some rank and the two dimensions are swapped
    swapped = Disj([
        Transpose(rank, source, n.args[1], n.args[2], transposed)
        for rank in range(1, MAX_TENSOR_RANK + 1)
    ])
    return [Disj([both_dyn, swapped])], counter
def cumsum_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for a cumulative sum.

    Input and output shapes are equal. The reduction dimension (positional
    ``n.args[1]`` or keyword ``dim``) is checked against each candidate input
    rank with ``range_check``.
    """
    assert isinstance(n.args[0], Node)
    if len(n.args) > 1:
        dim = n.args[1]
    else:
        dim = n.kwargs["dim"]
    assert isinstance(dim, int)

    result, counter = gen_tvar(counter)
    symbols[n] = result
    source = symbols[n.args[0]]

    both_dyn = Conj([
        BinConstraintT(source, Dyn, op_eq),
        BinConstraintT(result, Dyn, op_eq),
    ])

    per_rank = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        dims, counter = gen_tensor_dims(rank, counter)
        nat = gen_nat_constraints(dims)
        per_rank.append(Conj([
            BinConstraintT(source, TensorType(dims), op_eq),
            BinConstraintT(result, TensorType(dims), op_eq),
            range_check(dim, rank),
            *nat,
        ]))

    return [Disj([both_dyn, Disj(per_rank)])], counter
def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Constraints for a linear layer.

    Input and output have the same rank and the same sizes except for the last
    dimension, which ``add_linear_constraints`` constrains using the module
    instance. If the input is Dyn, so is the output.
    """
    assert isinstance(n.args[0], Node)
    out_var, counter = gen_tvar(counter)
    symbols[n] = out_var
    in_var = symbols[n.args[0]]

    both_dyn = Conj([
        BinConstraintT(in_var, Dyn, op_eq),
        BinConstraintT(out_var, Dyn, op_eq),
    ])

    per_rank = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        in_dims, counter = gen_tensor_dims(rank, counter)
        out_dims, counter = gen_tensor_dims(rank, counter)
        nat = gen_nat_constraints(in_dims + out_dims)
        shapes = [
            BinConstraintT(in_var, TensorType(in_dims), op_eq),
            BinConstraintT(out_var, TensorType(out_dims), op_eq),
        ]
        per_rank.append(Conj(
            shapes + add_linear_constraints(in_dims, out_dims, module_instance) + nat
        ))

    return [Disj([both_dyn, Disj(per_rank)])], counter
def equality_inference_rule(n: Node, symbols, constraints, counter):
    """
    We generate the constraint: input = output

    Three argument forms are handled:
    * the first argument is a tensor node: the output equals that tensor;
    * the arguments are dimension nodes: the output is a tensor whose
      dimensions are those dimension variables;
    * the first argument is a tuple of sizes (at most MAX_TENSOR_RANK
      entries): the output is a tensor with those sizes. Node entries are
      replaced by their symbols; other entries (e.g. literal ints) are used
      as-is.

    Raises NotImplementedError for any other first argument.
    """
    output, counter = gen_tvar(counter)
    symbols[n] = output

    if isinstance(n.args[0], Node):
        input = symbols[n.args[0]]
        if isinstance(input, TVar):
            return [BinConstraintT(input, output, op_eq)], counter

        # otherwise every positional argument is a dimension variable
        for arg in n.args:
            assert isinstance(symbols[arg], DVar)
        my_size = [symbols[arg] for arg in n.args]
        return [BinConstraintT(output, TensorType(my_size), op_eq)], counter

    elif isinstance(n.args[0], tuple):
        # then the tuple is the size; use the shared rank bound rather than a
        # hard-coded literal
        assert len(n.args[0]) <= MAX_TENSOR_RANK
        # literal ints are kept as concrete sizes (as expand/view do) instead
        # of being looked up in the symbol table, which would raise KeyError
        my_size = [symbols[arg] if isinstance(arg, Node) else arg for arg in n.args[0]]
        return [BinConstraintT(output, TensorType(my_size), op_eq)], counter

    else:
        raise NotImplementedError('Method not yet implemented')
def flatten_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for flattening a range of dimensions.

    ``start_dim``/``end_dim`` come from the positional arguments and default
    to 1 and -1. Either both input and output are Dyn, or for some input rank
    in 1..MAX_TENSOR_RANK ``generate_flatten_constraints`` relates them.
    """
    assert isinstance(n.args[0], Node)

    flattened, counter = gen_tvar(counter)
    symbols[n] = flattened
    source = symbols[n.args[0]]

    start_dim, end_dim = 1, -1
    if len(n.args) > 1:
        assert isinstance(n.args[1], int)
        start_dim = n.args[1]
    if len(n.args) > 2:
        assert isinstance(n.args[2], int)
        end_dim = n.args[2]

    options = [Conj([
        BinConstraintT(source, Dyn, op_eq),
        BinConstraintT(flattened, Dyn, op_eq),
    ])]
    for rank in range(1, MAX_TENSOR_RANK + 1):
        rank_constraint, counter = generate_flatten_constraints(
            start_dim, end_dim, source, flattened, rank, counter)
        options.append(rank_constraint)

    return [Disj(options)], counter
def generate_constraints_node(self, n: Node, counter):
    """
    Generate constraints for a single FX node.

    Dispatch is on ``n.op``:
    * ``placeholder``: a fresh tensor variable ``x`` is stored for ``n``. The
      node's ``type`` is related to ``x`` via ``op_precision``, and ``x`` is
      bounded by MAX_TENSOR_RANK via ``op_leq``.
    * ``call_function`` / ``call_method``: the rule registered in
      ``_INFERENCE_RULES`` for ``n.target`` is called as
      ``rule(n, symbol_dict, constraints, counter)``.
    * ``call_module``: the rule registered for the submodule's class is called
      as ``rule(n, module_instance, symbol_dict, constraints, counter)``.
    * ``get_attr`` / ``output``: no constraints.

    Returns ``(constraints, counter)``.
    Raises RuntimeError when no rule is registered for a target/class, and
    NotImplementedError for any other op.
    """
    if n.op == 'placeholder':
        x, counter = gen_tvar(counter)
        self.symbol_dict[n] = x
        c1 = BinConstraintT(n.type, x, op_precision)
        c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq)
        return [c1, c2], counter

    elif n.op == 'call_function':
        # getattr is dispatched like every other function target: its rule
        # takes (n, symbols, constraints, counter) and must receive both the
        # symbol table and the counter.
        if n.target in _INFERENCE_RULES:
            return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter)
        raise RuntimeError(
            f'No inference rule registered for target {n.target}!')

    elif n.op == 'call_module':
        module_instance = self.traced.get_submodule(n.target)
        if type(module_instance) in _INFERENCE_RULES:
            return _INFERENCE_RULES[type(module_instance)](
                n, module_instance, self.symbol_dict, self.constraints, counter)
        raise RuntimeError(
            f'No inference rule registered for class {type(module_instance)}!'
        )

    elif n.op == 'call_method':
        if n.target in _INFERENCE_RULES:
            return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter)
        raise RuntimeError(
            f'No inference rule registered for target {n.target}!')

    # TODO: verify that no constraint should be generated here
    elif n.op == 'get_attr':
        return [], counter

    elif n.op == 'output':
        return [], counter

    else:
        raise NotImplementedError(f"Method {n.op} not yet implemented")
def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Constraints for a layer-norm module.

    Input and output shapes are equal. For each candidate rank, the input
    dimensions are checked against the module's ``normalized_shape`` by
    ``add_layer_norm_constraints``. If the input is Dyn, so is the output.
    """
    assert isinstance(n.args[0], Node)
    normed, counter = gen_tvar(counter)
    symbols[n] = normed
    source = symbols[n.args[0]]

    both_dyn = Conj([
        BinConstraintT(source, Dyn, op_eq),
        BinConstraintT(normed, Dyn, op_eq),
    ])

    per_rank = []
    for rank in range(1, MAX_TENSOR_RANK + 1):
        dims, counter = gen_tensor_dims(rank, counter)
        nat = gen_nat_constraints(dims)
        shapes = [
            BinConstraintT(source, TensorType(dims), op_eq),
            BinConstraintT(normed, TensorType(dims), op_eq),
        ]
        per_rank.append(Conj(
            shapes
            + add_layer_norm_constraints(dims, list(module_instance.normalized_shape))
            + nat
        ))

    return [Disj([both_dyn, Disj(per_rank)])], counter
def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Constraints for a 2-D convolution module.

    The input is matched (``op_matching``) against a rank-4 tensor
    ``[d1, d2, d3, d4]``. The module's ``in_channels`` must be consistent with
    ``d2``, and ``CalcConv`` relates the output to the input dims and the
    module hyper-parameters. The four dims also get the natural-number
    constraints from ``gen_nat_constraints``.
    """
    assert isinstance(n.args[0], Node)

    my_conv, counter = gen_tvar(counter)
    symbols[n] = my_conv
    input_var = symbols[n.args[0]]

    # Conv2d inputs are always rank 4, so request exactly four dimension
    # variables. Requesting MAX_TENSOR_RANK would break this four-name
    # unpacking if that global bound were ever changed.
    [d1, d2, d3, d4], counter = gen_tensor_dims(4, counter)

    c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
    c2 = BinConstraintD(module_instance.in_channels, d2, op_consistency)
    c3 = CalcConv(my_conv, input_var,
                  module_instance.out_channels,
                  module_instance.kernel_size,
                  module_instance.padding,
                  module_instance.stride,
                  module_instance.dilation, [d1, d2, d3, d4])

    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])

    return [c1, c2, c3, *nat_constraints], counter
def index_select_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for selecting entries along a dimension with an index vector.

    ``n.args[0]`` is the source tensor, ``n.args[1]`` the dimension (an int)
    and ``n.args[2]`` the index tensor. The index is constrained to be either
    a 1-D tensor or Dyn. The output equals the source with the size at the
    given dimension replaced by the index length (or by Dyn).
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], int)
    assert isinstance(n.args[2], Node)

    selected, counter = gen_tvar(counter)
    symbols[n] = selected

    index_dims, counter = gen_tensor_dims(1, counter)
    index = symbols[n.args[2]]
    dim = n.args[1]

    index_is_vector = BinConstraintT(index, TensorType(index_dims), op_eq)
    index_is_dyn = BinConstraintT(index, Dyn, op_eq)

    with_length = Conj([
        index_is_vector,
        Disj([IndexSelect(rank, symbols[n.args[0]], index_dims[0], dim, selected)
              for rank in range(1, MAX_TENSOR_RANK + 1)]),
    ])
    with_dyn = Conj([
        index_is_dyn,
        Disj([IndexSelect(rank, symbols[n.args[0]], Dyn, dim, selected)
              for rank in range(1, MAX_TENSOR_RANK + 1)]),
    ])
    return [Disj([with_length, with_dyn])], counter
def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    The output shape differs from the input shape in the last dimension.

    The module's ``embedding_dim`` is appended to the input dimensions. The
    actual constraint construction is shared with ``gen_embedding_rules``.
    This body used to be a verbatim copy of that function, so delegating keeps
    the two from drifting apart.
    """
    assert isinstance(n.args[0], Node)
    return gen_embedding_rules(n, symbols, module_instance.embedding_dim, counter)
def arange_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for ``arange(end)``; only the single-argument form is handled.

    The result is a 1-D tensor ``[length]``, with start=0 and step=1. Either
    one of end/start/step is Dyn and ``length`` is Dyn, or none of them is Dyn
    and ``length == (end - start) / step``.
    """
    if len(n.args) != 1:
        raise NotImplementedError('Not yet implemented')
    start, step = 0, 1
    end = symbols[n.args[0]]

    length, counter = gen_dvar(counter)
    span = BinConstraintD(end, start, op_sub)
    size_constraint = BinConstraintD(length, BinConstraintD(span, step, op_div), op_eq)

    arange_out, counter = gen_tvar(counter)
    symbols[n] = arange_out

    params = (end, start, step)
    # some parameter is Dyn -> the length is Dyn
    some_dyn = Conj([
        Disj([BinConstraintD(p, Dyn, op_eq) for p in params]),
        BinConstraintD(length, Dyn, op_eq),
    ])
    # every parameter is known -> the length is known and computed
    all_known = Conj([
        Conj([BinConstraintD(p, Dyn, op_neq) for p in params]),
        BinConstraintD(length, Dyn, op_neq),
        size_constraint,
    ])

    return [BinConstraintT(arange_out, TensorType([length]), op_eq),
            Disj([some_dyn, all_known])], counter
def gt_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for ``a > b``.

    * tensor > tensor: broadcasting constraints into a fresh tensor variable,
      which is recorded for ``n``;
    * dimension > dimension, or dimension > non-node value: a fresh boolean
      variable is tied to a ``>`` constraint (meant for flow analysis only).
      No symbol is recorded for ``n`` in these cases;
    * tensor > int: the earlier tensor assumption about the left operand is
      replaced by a fresh dimension variable, with a warning, and handled as
      the dimension case.

    Raises RuntimeError on a tensor/dimension mix between two nodes, and
    NotImplementedError for other combinations.
    """
    lhs = n.args[0]
    assert isinstance(lhs, (Node, int))
    rhs = n.args[1]
    assert isinstance(rhs, (Node, int))

    # We make sure this node will not be used again. We do not
    # generate a constraint about that node. Only about the operands.
    e1 = symbols[lhs] if isinstance(lhs, Node) else lhs
    e2 = symbols[rhs] if isinstance(rhs, Node) else rhs

    def flow_constraint(left, right, counter):
        # bvar == (left > right)
        gt_constraint = BinConstraintD(left, right, op_gt)
        my_gt, counter = gen_bvar(counter)
        return [BinConstraintD(my_gt, gt_constraint, op_eq)], counter

    if not isinstance(lhs, Node):
        raise NotImplementedError('Method not yet implemented')

    if isinstance(rhs, Node):
        if isinstance(e1, TVar) and isinstance(e2, TVar):
            gt_tensor, counter = gen_tvar(counter)
            symbols[n] = gt_tensor
            return gen_broadcasting_constraints(e1, e2, symbols, counter, gt_tensor)
        if isinstance(e1, DVar) and isinstance(e2, DVar):
            return flow_constraint(e1, e2, counter)
        raise RuntimeError('Sort Mismatch')

    if isinstance(e1, DVar):
        return flow_constraint(e1, e2, counter)

    if isinstance(e1, TVar) and isinstance(e2, int):
        # then we made the wrong assumption about the argument being a tensor
        # so we should fix the assumption
        warnings.warn(
            f'Made the wrong assumption for node {n}. Correctness not guaranteed.'
        )
        new_e1, counter = gen_dvar(counter)
        symbols[lhs] = new_e1
        return flow_constraint(new_e1, e2, counter)

    raise NotImplementedError('Method not yet implemented')
def bmm_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for batched matrix multiplication.

    The inputs are described by rank-3 dims ``[b1, r1, c1]`` and
    ``[b2, r2, c2]``. There are four alternatives:
    * both inputs Dyn: the output is Dyn;
    * only the first input Dyn: the output is ``[b2, Dyn, c2]``;
    * only the second input Dyn: the output is ``[b1, r1, Dyn]``;
    * both tensors: ``b1`` and ``b2`` are consistent and the output is
      ``[batch, r1, c2]``, with ``batch`` constrained by DGreatestUpperBound
      over ``b1`` and ``b2``.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], Node)

    out, counter = gen_tvar(counter)
    symbols[n] = out
    lhs = symbols[n.args[0]]
    rhs = symbols[n.args[1]]

    lhs_dims, counter = gen_tensor_dims(3, counter)
    rhs_dims, counter = gen_tensor_dims(3, counter)
    b1, r1, _ = lhs_dims
    b2, _, c2 = rhs_dims

    both_dyn = Conj([
        BinConstraintT(lhs, Dyn, op_eq),
        BinConstraintT(rhs, Dyn, op_eq),
        BinConstraintT(out, Dyn, op_eq),
    ])
    lhs_dyn = Conj([
        BinConstraintT(lhs, Dyn, op_eq),
        BinConstraintT(rhs, TensorType(rhs_dims), op_eq),
        BinConstraintT(out, TensorType([b2, Dyn, c2]), op_eq),
    ])
    rhs_dyn = Conj([
        BinConstraintT(rhs, Dyn, op_eq),
        BinConstraintT(lhs, TensorType(lhs_dims), op_eq),
        BinConstraintT(out, TensorType([b1, r1, Dyn]), op_eq),
    ])

    batch, counter = gen_dvar(counter)
    both_tensors = Conj([
        BinConstraintT(lhs, TensorType(lhs_dims), op_eq),
        BinConstraintT(rhs, TensorType(rhs_dims), op_eq),
        BinConstraintT(out, TensorType([batch, r1, c2]), op_eq),
        BinConstraintD(b1, b2, op_consistency),
        DGreatestUpperBound(batch, b1, b2),
    ])

    return [Disj([both_dyn, lhs_dyn, rhs_dyn, both_tensors])], counter
def getitem_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for indexing with ``n.args[1]``.

    * int index: the result is a fresh dimension variable. If the indexed
      value is Dyn, any index is accepted and the result is Dyn. Otherwise one
      ``GetItem`` alternative is generated per rank, and the result is
      constrained to be >= 0.
    * tuple index: the result is a fresh tensor variable. It is either Dyn
      (with a Dyn input) or given by one ``GetItemTensor`` alternative per
      rank.

    Raises RuntimeError for any other index type.
    """
    assert isinstance(n.args[0], Node)
    index = n.args[1]

    if isinstance(index, int):
        dim_out, counter = gen_dvar(counter)
        symbols[n] = dim_out
        source = symbols[n.args[0]]
        assert isinstance(source, TVar)

        dyn_case = Conj([
            BinConstraintT(source, Dyn, op_eq),
            BinConstraintD(dim_out, Dyn, op_eq),
        ])
        rank_cases = Disj([
            GetItem(rank, index, dim_out, source)
            for rank in range(1, MAX_TENSOR_RANK + 1)
        ])
        # a dimension must be a natural number
        non_negative = BinConstraintD(0, dim_out, op_leq)
        return [Disj([dyn_case, Conj([rank_cases, non_negative])])], counter

    if isinstance(index, tuple):
        tensor_out, counter = gen_tvar(counter)
        symbols[n] = tensor_out
        source = symbols[n.args[0]]
        assert isinstance(source, TVar)

        dyn_case = Conj([
            BinConstraintT(source, Dyn, op_eq),
            BinConstraintT(tensor_out, Dyn, op_eq),  # type: ignore[list-item]
        ])
        rank_cases = [
            GetItemTensor(rank, index, tensor_out, source)
            for rank in range(1, MAX_TENSOR_RANK + 1)
        ]  # type: ignore[misc]
        return [Disj([dyn_case, *rank_cases])], counter

    raise RuntimeError('Method not yet implemented')
def to_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for a ``to`` conversion: the output is constrained equal to
    the input (the constraint is input = output).
    """
    assert isinstance(n.args[0], Node)
    converted, counter = gen_tvar(counter)
    symbols[n] = converted
    source = symbols[n.args[0]]
    return [BinConstraintT(source, converted, op_eq)], counter
def full_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for a ``full``-style constructor. The result is a tensor whose
    dimensions are the symbols of the entries of the size argument
    ``n.args[0]``.
    """
    result, counter = gen_tvar(counter)
    symbols[n] = result
    assert isinstance(n.args[0], Iterable)
    sizes = [symbols[size] for size in n.args[0]]
    return [BinConstraintT(result, TensorType(sizes), op_eq)], counter  # type: ignore[arg-type]
def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Constraints for an element-wise activation module: input and output
    shapes are equal. ``module_instance`` is not used.
    """
    assert isinstance(n.args[0], Node)
    activated, counter = gen_tvar(counter)
    symbols[n] = activated
    source = symbols[n.args[0]]
    assert isinstance(source, TVar)
    return [BinConstraintT(source, activated, op_eq)], counter
def expand_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for expanding a tensor to the sizes in ``n.args[1:]``.

    The same broadcasting constraints as tensor addition are generated. A
    synthetic tensor variable built from the requested sizes serves as the
    second operand. The output rank is fixed to ``len(n.args[1:])`` so only
    that case is considered. Node sizes must be dimension variables and are
    constrained to be >= 0.
    """
    assert isinstance(n.args[0], Node)

    expanded, counter = gen_tvar(counter)
    symbols[n] = expanded

    source = symbols[n.args[0]]
    # there is only one tensor node, so stand up an operand for the sizes
    size_tensor, counter = gen_tvar(counter)

    requested = n.args[1:]
    size_dims = []
    size_nat_constraints = []
    for size in requested:
        assert isinstance(size, (Node, int))
        if isinstance(size, Node):
            assert isinstance(symbols[size], DVar)
            size_nat_constraints.append(BinConstraintD(0, symbols[size], op_leq))
            size_dims.append(symbols[size])
        else:
            size_dims.append(size)

    size_constraint = BinConstraintT(size_tensor, TensorType(size_dims), op_eq)

    result, counter = gen_broadcasting_constraints(
        source, size_tensor, symbols, counter, expanded)

    # constrain the output rank to the number of requested sizes
    out_dims, counter = gen_tensor_dims(len(requested), counter)
    result += [
        BinConstraintT(expanded, TensorType(out_dims), op_eq),
        *gen_nat_constraints(out_dims),
        size_constraint,
        *size_nat_constraints,
    ]
    return result, counter
def add_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for ``a + b``.

    * tensor + tensor: broadcasting constraints into a fresh tensor variable;
    * tensor + int: the result is constrained equal to the tensor;
    * dimension + int: a fresh dimension variable equal to the sum, constrained
      to be >= 0.

    Other combinations raise NotImplementedError.
    """
    lhs = n.args[0]
    if not isinstance(lhs, Node):
        # TODO generate add constraints for scalar addition
        raise NotImplementedError('Addition not yet implemented')
    rhs = n.args[1]

    if isinstance(rhs, Node):
        if not (isinstance(symbols[lhs], TVar) and isinstance(symbols[rhs], TVar)):
            raise NotImplementedError('Method not yet implemented')
        total, counter = gen_tvar(counter)
        symbols[n] = total
        return gen_broadcasting_constraints(symbols[lhs], symbols[rhs], symbols, counter, total)

    if not isinstance(rhs, int):
        # TODO generate add constraints for scalar addition
        raise NotImplementedError('Addition not yet implemented')

    operand = symbols[lhs]
    if isinstance(operand, TVar):
        total, counter = gen_tvar(counter)
        symbols[n] = total
        return [BinConstraintT(total, operand, op_eq)], counter

    if isinstance(operand, DVar):
        total, counter = gen_dvar(counter)
        symbols[n] = total
        # regular integer addition: propagate the runtime value
        summed = Conj([
            BinConstraintD(total, BinConstraintD(operand, rhs, op_add), op_eq),
            BinConstraintD(0, total, op_leq),
        ])
        return [summed], counter

    raise NotImplementedError('Method not yet implemented')
def mul_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for ``node * float``.

    Scalar multiplication keeps the shape, so the result is constrained equal
    to the node operand. The result variable is recorded before the case
    check. Any other operand combination raises NotImplementedError.
    """
    product, counter = gen_tvar(counter)
    symbols[n] = product

    is_scalar_mul = isinstance(n.args[0], Node) and isinstance(n.args[1], float)
    if not is_scalar_mul:
        raise NotImplementedError('Case not yet implemented')
    return [BinConstraintT(product, symbols[n.args[0]], op_eq)], counter
def reshape_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for reshaping to the shape in ``n.args[1]``.

    The output equals that shape with ``-1`` entries replaced by Dyn. The
    source must satisfy ``CanReshape`` into it.
    """
    assert isinstance(n.args[0], Node)

    reshaped, counter = gen_tvar(counter)
    symbols[n] = reshaped

    source = symbols[n.args[0]]
    target = TensorType([Dyn if size == -1 else size for size in n.args[1]])  # type: ignore[union-attr]
    return [BinConstraintT(reshaped, target, op_eq), CanReshape(source, target)], counter
def ne_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for element-wise ``!=``.

    The same broadcasting constraints as tensor addition are generated. Both
    arguments must be nodes; unlike addition, there is no scalar case.
    """
    result, counter = gen_tvar(counter)
    symbols[n] = result

    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], Node)

    lhs, rhs = symbols[n.args[0]], symbols[n.args[1]]
    return gen_broadcasting_constraints(lhs, rhs, symbols, counter, result)
def add_inference_rule(n: Node, symbols, constraints, counter):
    """
    Broadcasting constraints for adding two tensor nodes.

    Two auxiliary tensor variables are related to the operands by
    ``ApplyBroadcasting`` and must be consistent with each other. The result
    is constrained by ``TGreatestUpperBound`` over them.

    NOTE(review): another ``add_inference_rule`` is defined earlier in this
    file. Unless both are registered through decorators not visible here,
    this later definition shadows the earlier one at module level. Confirm.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], Node)

    total, counter = gen_tvar(counter)
    symbols[n] = total

    lhs = symbols[n.args[0]]
    rhs = symbols[n.args[1]]

    # auxiliary variables that don't correspond to expressions
    # (presumably the broadcast forms of the operands)
    lhs_b, counter = gen_tvar(counter)
    rhs_b, counter = gen_tvar(counter)

    return [
        TGreatestUpperBound(total, lhs_b, rhs_b),
        ApplyBroadcasting(lhs_b, rhs_b, lhs, rhs),
        BinConstraintT(lhs_b, rhs_b, op_consistency),
    ], counter
def get_attr_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for reading attribute ``n.args[1]`` of ``n.args[0]``.

    Only ``'device'`` is supported; the tensor shape is preserved, so the
    result is constrained equal to the input. Other attributes raise
    NotImplementedError.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], str)
    result, counter = gen_tvar(counter)
    symbols[n] = result

    source = symbols[n.args[0]]
    if n.args[1] != 'device':
        raise NotImplementedError('Not yet implemented')
    return [BinConstraintT(source, result, op_eq)], counter
def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Constraints for a 2-D max-pooling module.

    The input is matched (``op_matching``) against a rank-4 tensor
    ``[d1, d2, d3, d4]``. ``CalcMaxPool`` relates the output to those dims and
    the module's kernel_size/padding/stride/dilation. The four dims also get
    the natural-number constraints from ``gen_nat_constraints``.
    """
    assert isinstance(n.args[0], Node)
    maxpool, counter = gen_tvar(counter)
    symbols[n] = maxpool
    input_var = symbols[n.args[0]]

    # 2-D pooling works on rank-4 inputs, so request exactly four dimension
    # variables. Requesting MAX_TENSOR_RANK would break this four-name
    # unpacking if that global bound were ever changed.
    [d1, d2, d3, d4], counter = gen_tensor_dims(4, counter)

    c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
    c2 = CalcMaxPool(maxpool, input_var, module_instance.kernel_size, module_instance.padding,
                     module_instance.stride, module_instance.dilation, [d1, d2, d3, d4])

    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])

    return [c1, c2, *nat_constraints], counter
def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Constraints for a 2-D adaptive pooling module.

    The input is matched (``op_matching``) against a rank-4 tensor
    ``[d1, d2, d3, d4]``. The output keeps ``d1, d2`` and takes its last two
    dimensions from the module's ``output_size``.

    ``output_size`` may be a single int (square ``H x H`` output) or a pair.
    A ``None`` entry means "same as the input size" and maps to ``d3``/``d4``.
    Previously an int ``output_size`` raised TypeError on indexing, and
    ``None`` entries leaked into the output type.
    """
    assert isinstance(n.args[0], Node)
    avg_pool, counter = gen_tvar(counter)

    symbols[n] = avg_pool
    input_var = symbols[n.args[0]]

    # dim vars
    d1, counter = gen_dvar(counter)
    d2, counter = gen_dvar(counter)
    d3, counter = gen_dvar(counter)
    d4, counter = gen_dvar(counter)
    nat_constraints = gen_nat_constraints([d1, d2, d3, d4])

    output_size = module_instance.output_size
    if isinstance(output_size, int):
        # a single int means a square output
        output_size = (output_size, output_size)
    out_h = d3 if output_size[0] is None else output_size[0]
    out_w = d4 if output_size[1] is None else output_size[1]

    c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching)
    c2 = BinConstraintT(avg_pool, TensorType([d1, d2, out_h, out_w]), op_eq)

    return [c1, c2, *nat_constraints], counter
def type_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for a node whose result takes the type of its second argument.

    The two (tensor) arguments must be consistent with each other, and the
    output is constrained equal to the second argument.
    """
    assert isinstance(n.args[0], Node)
    assert isinstance(n.args[1], Node)
    result, counter = gen_tvar(counter)
    symbols[n] = result

    source = symbols[n.args[0]]
    target = symbols[n.args[1]]
    assert isinstance(source, TVar)
    assert isinstance(target, TVar)

    consistent = BinConstraintT(source, target, op_consistency)
    takes_target = BinConstraintT(result, target, op_eq)
    return [consistent, takes_target], counter
def batchnorm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
    """
    Constraints for a batch-norm module.

    The input is matched (``op_matching``) against a rank-4 tensor whose four
    dimension variables get natural-number constraints. The output is
    constrained equal to the input. ``module_instance`` is not used.
    """
    assert isinstance(n.args[0], Node)

    normed, counter = gen_tvar(counter)
    symbols[n] = normed
    source = symbols[n.args[0]]

    dims = []
    for _ in range(4):
        dim_var, counter = gen_dvar(counter)
        dims.append(dim_var)
    nat_constraints = gen_nat_constraints(list(dims))

    matches_rank4 = BinConstraintT(source, TensorType(list(dims)), op_matching)
    same_shape = BinConstraintT(source, normed, op_eq)
    return [matches_rank4, same_shape, *nat_constraints], counter
def view_inference_rule(n: Node, symbols, constraints, counter):
    """
    Constraints for viewing a tensor with the sizes in ``n.args[1:]``.

    Similar to reshape: Node sizes are replaced by their symbols and ``-1`` by
    Dyn. The output equals the resulting shape, and the source must satisfy
    ``CanReshape`` into it. The extra stride condition of ``view`` is not yet
    checked.
    """
    assert isinstance(n.args[0], Node)

    viewed, counter = gen_tvar(counter)
    symbols[n] = viewed
    source = symbols[n.args[0]]

    sizes = []
    for size in n.args[1:]:
        if isinstance(size, Node):
            size = symbols[size]
        sizes.append(Dyn if size == -1 else size)
    target = TensorType(sizes)  # type: ignore[union-attr]

    # TODO: add the extra check mentioned here:
    # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view
    return [BinConstraintT(viewed, target, op_eq), CanReshape(source, target)], counter