def eval_keys(self, ctx):
    """Return a z3 formula describing when this table's keys can match.

    Each ``(key_expr, key_type)`` pair in ``self.keys`` is resolved in
    ``ctx`` and related to a symbolic table-entry constant named
    ``<name>_table_key_<i>``. The relation depends on the match kind:

    * ``exact``    -- plain equality with the entry constant.
    * ``lpm``      -- equality under an all-ones mask shifted left by a
      symbolic amount (``<name>_table_mask_<i>``).
    * ``ternary``  -- equality under a fully symbolic mask
      (``<name>_table_mask_<i>``); boolean keys use logical ``And``.
    * ``range``    -- the key lies in ``[min, max]`` with ``min < max``.
    * ``optional`` / ``selector`` -- unconstrained (``True``).

    :param ctx: context used to resolve each key expression.
    :returns: the conjunction of all per-key constraints, or ``False`` when
        the table declares no keys.
    :raises RuntimeError: for a match kind not listed above.
    """
    if not self.keys:
        # a table without keys can never be hit
        return z3.BoolVal(False)

    constraints = []
    for idx, key_entry in enumerate(self.keys):
        key_expr, key_type = key_entry
        key_eval = ctx.resolve_expr(key_expr)
        key_sort = key_eval.sort()
        key_match = z3.Const(f"{self.name}_table_key_{idx}", key_sort)

        if key_type == "exact":
            constraint = key_eval == key_match
        elif key_type == "lpm":
            # Prefix mask = all ones shifted left by a symbolic amount.
            # Shifting past the bit width yields an all-zero mask, which
            # is acceptable here. TODO: test this.
            shift = z3.BitVec(f"{self.name}_table_mask_{idx}", key_sort)
            all_ones = z3.BitVecVal(2**key_sort.size() - 1, key_sort)
            prefix_mask = all_ones << shift
            constraint = (key_eval & prefix_mask) == (key_match & prefix_mask)
        elif key_type == "ternary":
            # Symbolic mask: any zero bit acts as a wildcard. TODO: test this.
            mask = z3.Const(f"{self.name}_table_mask_{idx}", key_sort)
            if isinstance(key_sort, z3.BoolSortRef):
                # bitwise '&' is not available on booleans; use logical And
                constraint = z3.And(key_eval, mask) == z3.And(key_match, mask)
            else:
                constraint = (key_eval & mask) == (key_match & mask)
        elif key_type == "range":
            # Arbitrary symbolic bounds that must enclose the key, with the
            # minimum strictly below the maximum. TODO: test this.
            lower = z3.Const(f"{self.name}_table_min_{idx}", key_sort)
            upper = z3.Const(f"{self.name}_table_max_{idx}", key_sort)
            enclosed = z3.And(z3.ULE(lower, key_eval), z3.UGE(upper, key_eval))
            constraint = z3.And(enclosed, z3.ULT(lower, upper))
        elif key_type in ("optional", "selector"):
            # optional: treated as a control-plane wildcard, so no constraint.
            # selector: FIXME not implemented yet; also left unconstrained.
            constraint = z3.BoolVal(True)
        else:
            # unknown match kind, possibly target-specific
            raise RuntimeError(f"Key type {key_type} not supported!")
        constraints.append(constraint)

    return z3.And(constraints)
def __init__(self):
    """Model the P4 ``packet_in`` extern and register its methods.

    After construction ``self.locals`` contains:

    * ``"extract"``   -- a list with two overloads,
      ``extract(out T hdr)`` and
      ``extract(out T variableSizeHeader, in bit<32> variableFieldSizeInBits)``.
    * ``"lookahead"`` -- a list with one method that is generic in ``T``.
    * ``"length"``    -- a symbolic 32-bit value named ``<name>_length``.
    * ``"advance"``   -- a list with one method taking ``in bit<32> sizeInBits``;
      its evaluation is currently a no-op.

    ``self.pkt_cursor`` starts at zero. None of the methods advance it yet.
    """
    super(packet_in, self).__init__("packet_in", type_params={}, methods=[])
    self.pkt_cursor = z3.BitVecVal(0, 32)
    # attach the methods
    self.locals = {}

    # EXTRACT #
    class extract_1(P4Method):
        # name of the parameter that carries the header to extract
        hdr_param_name = "hdr"

        def extract_hdr(self, ctx, merged_args):
            """Bind the target header to a symbolic value.

            The extern/type contexts are registered in ``ctx``. If the
            target is a header stack accessed via ``.next``, its bounds are
            checked. The header is then activated and bound to a z3 constant
            named ``<method name>_<hdr_param_name>``, after which the stack
            indices are advanced.

            :raises ParserException: if the stack access is provably out of
                bounds (the bound check simplifies to ``True``).
            """
            hdr = merged_args[self.hdr_param_name].p4_val
            # apply the local and parent extern type ctxs
            for type_name, p4_type in self.extern_ctx.items():
                ctx.add_type(type_name, ctx.resolve_type(p4_type))
            for type_name, p4_type in self.type_ctx.items():
                ctx.add_type(type_name, ctx.resolve_type(p4_type))

            # advance the header index if a next field has been accessed
            hdr_stack = detect_hdr_stack_next(ctx, hdr)
            if hdr_stack:
                compare = (hdr_stack.locals["nextIndex"]
                           >= hdr_stack.locals["size"])
                # Only a bound violation that simplifies to the literal True
                # is reported. z3.is_true is the documented predicate for
                # this; comparing with `== z3.BoolVal(True)` builds a symbolic
                # equality and only works through z3's __bool__ special case.
                if z3.is_true(z3.simplify(compare)):
                    raise ParserException("Index out of bounds!")

            # grab the hdr value
            hdr_expr = ctx.resolve_expr(hdr)
            hdr_expr.activate()
            bind_const = z3.Const(f"{self.name}_{self.hdr_param_name}",
                                  hdr_expr.z3_type)
            hdr_expr.bind(bind_const)

            # advance the stack, if it exists
            if hdr_stack:
                hdr_stack.locals["lastIndex"] = hdr_stack.locals["nextIndex"]
                hdr_stack.locals["nextIndex"] += 1
            self.call_counter += 1

        def __call__(self, ctx, *args, **kwargs):
            merged_args = merge_parameters(self.params, *args, **kwargs)
            # this means default expressions have been used, no input
            if not merged_args:
                return
            self.extract_hdr(ctx, merged_args)

    class extract_2(extract_1):
        # variable-size overload: the header travels under a different name
        hdr_param_name = "variableSizeHeader"

        def __call__(self, ctx, *args, **kwargs):
            merged_args = merge_parameters(self.params, *args, **kwargs)
            # this means default expressions have been used, no input
            if not merged_args:
                return
            self.extract_hdr(ctx, merged_args)
            field_size = ctx.resolve_expr(
                merged_args["variableFieldSizeInBits"].p4_val)
            # TODO: advance the extern's packet cursor by field_size. Note
            # that `self` here is the method object, not the packet_in
            # instance, so the cursor is not reachable as self.pkt_cursor.
            del field_size

    extract_1_var = extract_1(name="extract",
                              params=[
                                  P4Parameter("out", "hdr", "T", None),
                              ],
                              type_params=(None, [
                                  "T",
                              ]))
    extract_2_var = extract_2(name="extract",
                              params=[
                                  P4Parameter("out", "variableSizeHeader",
                                              "T", None),
                                  P4Parameter("in", "variableFieldSizeInBits",
                                              z3.BitVecSort(32), None),
                              ],
                              type_params=(None, [
                                  "T",
                              ]))
    # both overloads share the "extract" name
    self.locals.setdefault("extract", []).append(extract_1_var)
    self.locals.setdefault("extract", []).append(extract_2_var)

    # LOOKAHEAD #
    lookahead_var = P4Method(name="lookahead",
                             params=[],
                             type_params=("T", [
                                 "T",
                             ]))
    self.locals.setdefault("lookahead", []).append(lookahead_var)

    # LENGTH #
    self.locals["length"] = z3.BitVec(f"{self.name}_length", 32)

    # ADVANCE #
    class advance(P4Method):
        def eval_callable(self, ctx, merged_args, var_buffer):
            # TODO: advance the packet cursor by merged_args["sizeInBits"];
            # as with extract_2, `self` is the method object, not the
            # extern. Currently a deliberate no-op.
            pass

    advance_var = advance(
        name="advance",
        params=[P4Parameter("in", "sizeInBits", z3.BitVecSort(32), None)],
        type_params=(None, []))
    self.locals.setdefault("advance", []).append(advance_var)