def test_ignore_duplicates(self):
    """A state variable referenced twice appears only once in its group."""
    program_text = """state_and_packet.state_group_0_state_0 = 1; state_and_packet.state_group_0_state_0 += 1;"""
    # Group '0' holds the single member '0', despite the repeated use.
    expected = {'0': {'0'}}
    self.assertDictEqual(get_state_group_info(program_text), expected)
def test_multiple_groups(self):
    """Members of several state groups are collected per group."""
    self.assertDictEqual(
        get_state_group_info(
            """state_and_packet.state_group_0_state_0, state_and_packet.state_group_0_state_1, state_and_packet.state_group_1_state_0, state_and_packet.state_group_1_state_1"""),
        {
            '0': {'0', '1'},
            '1': {'0', '1'},
        })
def __init__(self, program_file, stateful_alu_file, stateless_alu_file,
             num_pipeline_stages, num_alus_per_stage, sketch_name,
             parallel_sketch, constant_set, synthesized_allocation=False,
             pkt_fields_to_check=None, state_groups_to_check=None,
             state_dependency=(), input_packet=None):
    """Set up a compiler for one program / ALU / pipeline configuration.

    Reads ``program_file`` to count its packet fields and state groups,
    checks that the pipeline is wide enough, builds the jinja2 template
    environment and creates the ``SketchGenerator`` used for codegen.

    Args:
        program_file: Path to the program specification file; its text is
            read here.
        stateful_alu_file: Stateful ALU file, forwarded to SketchGenerator.
        stateless_alu_file: Stateless ALU file, forwarded to SketchGenerator;
            its directory (relative to the CWD) is also a template path.
        num_pipeline_stages: Number of pipeline stages.
        num_alus_per_stage: ALUs per stage; also passed as the number of
            PHV containers.
        sketch_name: Base name for generated sketches.
        parallel_sketch: Stored on the instance.
        constant_set: Constants made available to synthesis.
        synthesized_allocation: Stored and forwarded to SketchGenerator.
        pkt_fields_to_check: Packet field indices to check. When both this
            and ``state_groups_to_check`` are empty/None, every field and
            every state group is checked.
        state_groups_to_check: State group indices to check.
        state_dependency: Forwarded to SketchGenerator as a fresh list;
            ``None`` is passed through unchanged.
        input_packet: Packet fields used as spec input; empty/None means all
            fields of the program.

    Raises:
        AssertionError: If the program's fields (or ``input_packet``, when
            given) or ``pkt_fields_to_check`` outnumber num_alus_per_stage.
    """
    self.program_file = program_file
    self.stateful_alu_file = stateful_alu_file
    self.stateless_alu_file = stateless_alu_file
    self.num_pipeline_stages = num_pipeline_stages
    self.num_alus_per_stage = num_alus_per_stage
    self.sketch_name = sketch_name
    self.parallel_sketch = parallel_sketch
    self.constant_set = constant_set
    self.synthesized_allocation = synthesized_allocation

    program_content = Path(program_file).read_text()
    self.num_fields_in_prog = get_num_pkt_fields(program_content)
    self.num_state_groups = len(get_state_group_info(program_content))

    if not input_packet:
        assert self.num_fields_in_prog <= num_alus_per_stage, (
            'Number of fields in program %d is '
            'greater than number of alus per stage %d. Try increasing '
            'number of alus per stage.' % (self.num_fields_in_prog,
                                           num_alus_per_stage))
    else:
        # The trailing space after "is" was missing, printing "isgreater".
        assert len(input_packet) <= num_alus_per_stage, (
            'Number of input fields in program %d is '
            'greater than number of alus per stage %d. Try increasing '
            'number of alus per stage.' % (len(input_packet),
                                           num_alus_per_stage))

    # Guarantee that # of pkt_fields_to_check is less than or equal
    # to the num_alus_per_stage
    if pkt_fields_to_check is not None:
        assert len(pkt_fields_to_check) <= num_alus_per_stage, (
            'Number of checked fields in program %d is '
            'greater than number of alus per stage %d. '
            'Try increasing number of alus per stage.'
            % (len(pkt_fields_to_check), num_alus_per_stage))

    # Initialize jinja2 environment for templates. path.dirname() yields ''
    # for a bare filename; the previous rfind('/') slice became [:-1] there
    # and chopped the last character off the filename instead.
    self.jinja2_env = Environment(
        loader=FileSystemLoader([
            path.join(path.dirname(__file__), './templates'),
            path.join(os.getcwd(), path.dirname(stateless_alu_file)),
            '.',
            '/',
        ]),
        undefined=StrictUndefined,
        trim_blocks=True,
        lstrip_blocks=True)

    # Nothing selected means "check everything"; otherwise normalize the
    # unselected side to an empty list.
    if not pkt_fields_to_check and not state_groups_to_check:
        pkt_fields_to_check = list(range(self.num_fields_in_prog))
        state_groups_to_check = list(range(self.num_state_groups))
    elif not pkt_fields_to_check:
        pkt_fields_to_check = []
    elif not state_groups_to_check:
        state_groups_to_check = []

    # Differentiate between using default pkt input vs. specify pkt input
    if not input_packet:
        input_packet = list(range(self.num_fields_in_prog))

    # Hand SketchGenerator its own list so the immutable default (or a
    # caller's list) is never aliased; an explicit None passes through.
    if state_dependency is not None:
        state_dependency = list(state_dependency)

    # Create an object for sketch generation
    self.sketch_generator = SketchGenerator(
        sketch_name=sketch_name,
        num_pipeline_stages=num_pipeline_stages,
        num_alus_per_stage=num_alus_per_stage,
        num_phv_containers=num_alus_per_stage,
        num_state_groups=self.num_state_groups,
        num_fields_in_prog=self.num_fields_in_prog,
        pkt_fields_to_check=pkt_fields_to_check,
        state_dependency=state_dependency,
        state_groups_to_check=state_groups_to_check,
        jinja2_env=self.jinja2_env,
        stateful_alu_file=stateful_alu_file,
        stateless_alu_file=stateless_alu_file,
        constant_set=constant_set,
        synthesized_allocation=synthesized_allocation,
        input_packet=input_packet)
def _build_arg_parser():
    """Return the argparse parser for the iterative solver command line."""
    parser = argparse.ArgumentParser(description='Iterative solver.')
    parser.add_argument(
        'program_file', help='Program specification in .sk file')
    parser.add_argument('stateful_alu_file', help='Stateful ALU file to use.')
    parser.add_argument(
        'stateless_alu_file', help='Stateless ALU file to use.')
    parser.add_argument(
        'num_pipeline_stages', type=int, help='Number of pipeline stages')
    parser.add_argument(
        'num_alus_per_stage',
        type=int,
        help='Number of stateless/stateful ALUs per stage')
    parser.add_argument(
        'constant_set',
        type=str,
        help='The content in the constant_set and the format will be like '
             '0,1,2,3 and we will calculate the number of comma to get the '
             'size of it')
    parser.add_argument(
        'max_input_bit', type=int, help='The maximum input value in bits')
    parser.add_argument(
        '--pkt-fields',
        type=int,
        nargs='+',
        help='Packet fields to check correctness')
    parser.add_argument(
        '--state-groups',
        type=int,
        nargs='+',
        help='State groups to check correctness')
    parser.add_argument(
        '--state-dependency',
        type=int,
        nargs='+',
        help='Show the dependency relation between different state groups '
             'the format will be 0 1 2 3 which means state_group0 happens '
             'before state_group1 and state_group2 happens before '
             'state_group3')
    parser.add_argument(
        '--input-packet',
        type=int,
        nargs='+',
        help='Show the actual packet fields used as the input of spec.')
    parser.add_argument(
        '-p',
        '--parallel',
        action='store_true',
        help='Whether to run multiple smaller sketches in parallel by '
             'setting salu_config variables explicitly.')
    parser.add_argument(
        '--parallel-sketch',
        action='store_true',
        help='Whether sketch process internally uses parallelism')
    parser.add_argument(
        '--hole-elimination',
        action='store_true',
        help='If set, add additional assert statements to sketch, so that '
             'we would not see the same combination of hole value '
             'assignments.')
    parser.add_argument(
        '--synthesized-allocation',
        action='store_true',
        help='If set let sketch allocate state variables otherwise use '
             'canonical allocation, i.e, first state variable assigned to '
             'first phv container.')
    return parser


def _make_sketch_name(args):
    """Join the input file stems and the pipeline dimensions with '_'."""
    def stem(file_name):
        # Text after the last '/' and before the first '.'.
        return file_name.split('/')[-1].split('.')[0]

    return '_'.join([
        stem(args.program_file),
        stem(args.stateful_alu_file),
        stem(args.stateless_alu_file),
        str(args.num_pipeline_stages),
        str(args.num_alus_per_stage),
    ])


def main(argv):
    """Alternate synthesis and verification until a solution verifies.

    Each iteration asks the compiler for hole assignments, then verifies
    them at ``max_input_bit`` bits. On verification failure the next
    synthesis round is constrained either by hole-elimination asserts
    (``--hole-elimination``) or by the returned counterexamples, whose
    values are also added to the constant set.

    Args:
        argv: Full command line, including the program name at argv[0].

    Returns:
        0 once verification succeeds, 1 if synthesis fails.
    """
    parser = _build_arg_parser()
    args = parser.parse_args(argv[1:])

    # Dependencies come in (before, after) pairs, so the list length must be
    # even. parser.error (unlike assert) survives `python -O` and reports a
    # usage error rather than a traceback.
    if args.state_dependency is not None and len(args.state_dependency) % 2:
        parser.error('dependency list len must be even')

    # Use program_content to store the program file text rather than using it
    # twice
    program_content = Path(args.program_file).read_text()
    num_fields_in_prog = get_num_pkt_fields(program_content)

    # Get the state vars information
    # TODO: add the max_input_bit into sketch_name
    state_group_info = get_state_group_info(program_content)
    # Members per state group, taken from the first group (presumably all
    # groups have the same size -- confirm). A program without state groups
    # used to crash here with IndexError; use 0 for it instead.
    group_size = (len(next(iter(state_group_info.values())))
                  if state_group_info else 0)

    sketch_name = _make_sketch_name(args)

    # Use OrderedSet here for deterministic compilation results. We can also
    # use built-in dict() for Python versions 3.6 and later, as it's inherently
    # ordered.
    constant_set = OrderedSet(args.constant_set.split(','))

    compiler = Compiler(args.program_file, args.stateful_alu_file,
                        args.stateless_alu_file, args.num_pipeline_stages,
                        args.num_alus_per_stage, sketch_name,
                        args.parallel_sketch, constant_set,
                        args.synthesized_allocation, args.pkt_fields,
                        args.state_groups, args.state_dependency,
                        args.input_packet)

    # Repeatedly run synthesis at 2 bits and verification using all valid ints
    # until either verification succeeds or synthesis fails at 2 bits. Note
    # that the verification with all ints, might not work because sketch only
    # considers positive integers.
    # Synthesis is much faster at a smaller bit width, while verification needs
    # to run at a larger bit width for soundness.
    count = 1
    hole_elimination_assert = []
    additional_testcases = ''
    sol_verify_bit = args.max_input_bit
    while True:
        print('Iteration #' + str(count))
        if args.parallel:
            synthesis_ret_code, output, hole_assignments = \
                compiler.parallel_codegen(
                    additional_constraints=hole_elimination_assert,
                    additional_testcases=additional_testcases)
        else:
            synthesis_ret_code, output, hole_assignments = \
                compiler.serial_codegen(
                    iter_cnt=count,
                    additional_constraints=hole_elimination_assert,
                    additional_testcases=additional_testcases)

        if synthesis_ret_code != 0:
            compilation_failure(sketch_name, output)
            return 1

        print('Synthesis succeeded with 2 bits, proceeding to verification.')
        pkt_fields, state_vars = compiler.verify(
            hole_assignments, sol_verify_bit, iter_cnt=count)
        if len(pkt_fields) == 0 and len(state_vars) == 0:
            compilation_success(sketch_name, hole_assignments, output)
            return 0

        print('Verification failed.')

        # NOTE(taegyunkim): There is no harm in using both hole elimination
        # asserts and counterexamples. We want to compare using only hole
        # elimination asserts and only counterexamples.
        if args.hole_elimination:
            hole_elimination_assert += generate_hole_elimination_assert(
                hole_assignments)
            print(hole_elimination_assert)
        else:
            print('Use returned counterexamples', pkt_fields, state_vars)

            # Add every counterexample value (as a string) to the ordered
            # constant set.
            for value in pkt_fields.values():
                constant_set.add(str(value))
            for value in state_vars.values():
                constant_set.add(str(value))

            # Print the updated constant_array just for debugging
            print('updated constant array', constant_set)

            # Add constant set to compiler for next synthesis.
            compiler.update_constants_for_synthesis(constant_set)

            pkt_fields, state_vars = set_default_values(
                pkt_fields, state_vars, num_fields_in_prog, state_group_info)

            additional_testcases += generate_counterexample_asserts(
                pkt_fields, state_vars, num_fields_in_prog, state_group_info,
                count, args.pkt_fields, args.state_groups, group_size)

        count += 1