示例#1
0
    def test_ignore_duplicates(self):
        """A state variable referenced more than once is recorded only once."""
        program = """state_and_packet.state_group_0_state_0 = 1;
            state_and_packet.state_group_0_state_0 += 1;"""
        # Group '0' should contain state '0' exactly once.
        expected = {'0': {'0'}}

        self.assertDictEqual(get_state_group_info(program), expected)
示例#2
0
    def test_multiple_groups(self):
        """Each state group index maps to the set of its state indices."""
        program = """state_and_packet.state_group_0_state_0,
            state_and_packet.state_group_0_state_1,
            state_and_packet.state_group_1_state_0,
            state_and_packet.state_group_1_state_1"""
        # Both groups '0' and '1' reference states '0' and '1'.
        both_states = {'0', '1'}
        expected_info = {group: set(both_states) for group in ('0', '1')}

        self.assertDictEqual(get_state_group_info(program), expected_info)
示例#3
0
    def __init__(self,
                 program_file,
                 stateful_alu_file,
                 stateless_alu_file,
                 num_pipeline_stages,
                 num_alus_per_stage,
                 sketch_name,
                 parallel_sketch,
                 constant_set,
                 synthesized_allocation=False,
                 pkt_fields_to_check=None,
                 state_groups_to_check=None,
                 state_dependency=None,
                 input_packet=None):
        """Set up a compiler for one program / ALU / pipeline configuration.

        Reads the program text to count its packet fields and state groups,
        checks that the packet fields involved fit into one pipeline stage,
        builds the jinja2 template environment and creates the
        SketchGenerator used for synthesis.

        Args:
            program_file: Path to the program specification (.sk) file.
            stateful_alu_file: Path to the stateful ALU file.
            stateless_alu_file: Path to the stateless ALU file; its directory
                is added to the jinja2 template search path.
            num_pipeline_stages: Number of pipeline stages.
            num_alus_per_stage: Number of ALUs per stage; also passed to the
                SketchGenerator as the number of PHV containers.
            sketch_name: Name forwarded to the SketchGenerator.
            parallel_sketch: Stored on the instance; not used in this method.
            constant_set: Constants forwarded to the SketchGenerator.
            synthesized_allocation: Forwarded to the SketchGenerator.
            pkt_fields_to_check: Packet field indices to check. If both this
                and state_groups_to_check are empty/None, every packet field
                and every state group of the program is checked; if only one
                of them is given, the other becomes an empty list.
            state_groups_to_check: State group indices to check.
            state_dependency: Forwarded unchanged to the SketchGenerator
                (None means no dependency list was given, matching what the
                command-line entry point passes when the flag is omitted).
            input_packet: Packet fields used as the program input. Defaults
                to all packet fields of the program.

        Raises:
            AssertionError: If the number of program fields, input fields or
                checked fields exceeds num_alus_per_stage.
        """
        self.program_file = program_file
        self.stateful_alu_file = stateful_alu_file
        self.stateless_alu_file = stateless_alu_file
        self.num_pipeline_stages = num_pipeline_stages
        self.num_alus_per_stage = num_alus_per_stage
        self.sketch_name = sketch_name
        self.parallel_sketch = parallel_sketch
        self.constant_set = constant_set
        self.synthesized_allocation = synthesized_allocation

        program_content = Path(program_file).read_text()
        self.num_fields_in_prog = get_num_pkt_fields(program_content)
        self.num_state_groups = len(get_state_group_info(program_content))

        # Every packet field involved must fit into a single stage.
        if not input_packet:
            assert self.num_fields_in_prog <= num_alus_per_stage, (
                'Number of fields in program %d is '
                'greater than number of alus per stage %d. Try increasing '
                'number of alus per stage.' %
                (self.num_fields_in_prog, num_alus_per_stage))
        else:
            assert len(input_packet) <= num_alus_per_stage, (
                'Number of input fields in program %d is '
                'greater than number of alus per stage %d. Try increasing '
                'number of alus per stage.' %
                (len(input_packet), num_alus_per_stage))
            # Guarantee that # of pkt_fields_to_check is less than or equal
            # to the num_alus_per_stage
            if pkt_fields_to_check:
                assert len(pkt_fields_to_check) <= num_alus_per_stage, (
                    'Number of checked fields in program %d is '
                    'greater than number of alus per stage %d. '
                    'Try increasing number of alus per stage.' %
                    (len(pkt_fields_to_check), num_alus_per_stage))

        # Initialize jinja2 environment for templates. The stateless ALU
        # file's directory is searched too (presumably so the ALU files can
        # be loaded as templates by name). path.dirname yields '' for a bare
        # file name, which makes the entry the current working directory.
        template_dirs = [
            path.join(path.dirname(__file__), './templates'),
            path.join(os.getcwd(), path.dirname(stateless_alu_file)),
            '.',
            '/',
        ]
        self.jinja2_env = Environment(loader=FileSystemLoader(template_dirs),
                                      undefined=StrictUndefined,
                                      trim_blocks=True,
                                      lstrip_blocks=True)

        # Nothing requested: check everything. Otherwise make sure the
        # unspecified selection is an empty list rather than None.
        if not pkt_fields_to_check and not state_groups_to_check:
            pkt_fields_to_check = list(range(self.num_fields_in_prog))
            state_groups_to_check = list(range(self.num_state_groups))
        elif not pkt_fields_to_check:
            pkt_fields_to_check = []
        elif not state_groups_to_check:
            state_groups_to_check = []

        # Differentiate between using default pkt input vs. specify pkt input
        if not input_packet:
            input_packet = list(range(self.num_fields_in_prog))

        # Create an object for sketch generation
        self.sketch_generator = SketchGenerator(
            sketch_name=sketch_name,
            num_pipeline_stages=num_pipeline_stages,
            num_alus_per_stage=num_alus_per_stage,
            num_phv_containers=num_alus_per_stage,
            num_state_groups=self.num_state_groups,
            num_fields_in_prog=self.num_fields_in_prog,
            pkt_fields_to_check=pkt_fields_to_check,
            state_dependency=state_dependency,
            state_groups_to_check=state_groups_to_check,
            jinja2_env=self.jinja2_env,
            stateful_alu_file=stateful_alu_file,
            stateless_alu_file=stateless_alu_file,
            constant_set=constant_set,
            synthesized_allocation=synthesized_allocation,
            input_packet=input_packet)
示例#4
0
def _build_arg_parser():
    """Return the argparse parser for the iterative solver command line."""
    parser = argparse.ArgumentParser(description='Iterative solver.')
    parser.add_argument(
        'program_file', help='Program specification in .sk file')
    parser.add_argument('stateful_alu_file', help='Stateful ALU file to use.')
    parser.add_argument(
        'stateless_alu_file', help='Stateless ALU file to use.')
    parser.add_argument(
        'num_pipeline_stages', type=int, help='Number of pipeline stages')
    parser.add_argument(
        'num_alus_per_stage',
        type=int,
        help='Number of stateless/stateful ALUs per stage')
    parser.add_argument(
        'constant_set',
        type=str,
        help='The content in the constant_set and the format will be like '
             '0,1,2,3 and we will calculate the number of comma to get the '
             'size of it')
    parser.add_argument(
        'max_input_bit',
        type=int,
        help='The maximum input value in bits')
    parser.add_argument(
        '--pkt-fields',
        type=int,
        nargs='+',
        help='Packet fields to check correctness')
    parser.add_argument(
        '--state-groups',
        type=int,
        nargs='+',
        help='State groups to check correctness')
    parser.add_argument(
        '--state-dependency',
        type=int,
        nargs='+',
        help='Show the dependency relation between different state groups '
             'the format will be 0 1 2 3 which means state_group0 happens '
             'before state_group1 and state_group2 happens before '
             'state_group3')
    parser.add_argument(
        '--input-packet',
        type=int,
        nargs='+',
        help='Show the actual packet fields used as the input of spec.')
    parser.add_argument(
        '-p',
        '--parallel',
        action='store_true',
        help='Whether to run multiple smaller sketches in parallel by '
             'setting salu_config variables explicitly.')
    parser.add_argument(
        '--parallel-sketch',
        action='store_true',
        help='Whether sketch process internally uses parallelism')
    parser.add_argument(
        '--hole-elimination',
        action='store_true',
        help='If set, add additional assert statements to sketch, so that '
             'we would not see the same combination of hole value '
             'assignments.')
    parser.add_argument(
        '--synthesized-allocation',
        action='store_true',
        help='If set let sketch allocate state variables otherwise '
             'use canonical allocation, i.e, first state variable assigned '
             'to first phv container.')
    return parser


def _make_sketch_name(args):
    """Join the input files' base names and the pipeline shape with '_'.

    Each file contributes the part of its base name before the first '.',
    e.g. 'dir/prog.sk' -> 'prog'.
    """
    def base(file_name):
        return file_name.split('/')[-1].split('.')[0]

    return '_'.join([
        base(args.program_file),
        base(args.stateful_alu_file),
        base(args.stateless_alu_file),
        str(args.num_pipeline_stages),
        str(args.num_alus_per_stage),
    ])


def main(argv):
    """Run the iterative synthesis/verification loop from the command line.

    Args:
        argv: Full argument vector; argv[0] (the program name) is skipped.

    Returns:
        0 when a synthesized solution passes verification, 1 when synthesis
        fails.

    Raises:
        SystemExit: From argparse on invalid arguments, including an
            odd-length --state-dependency list.
    """
    parser = _build_arg_parser()
    args = parser.parse_args(argv[1:])
    if (args.state_dependency is not None
            and len(args.state_dependency) % 2 != 0):
        parser.error('dependency list len must be even')

    # Read the program file once and reuse its text.
    program_content = Path(args.program_file).read_text()
    num_fields_in_prog = get_num_pkt_fields(program_content)

    # Get the state vars information
    # TODO: add the max_input_bit into sketch_name
    state_group_info = get_state_group_info(program_content)

    # The first group's size is used as the size of every state group
    # (presumably all groups have the same number of members). A program
    # without state has no groups; use 0 instead of raising IndexError.
    group_size = (len(next(iter(state_group_info.values())))
                  if state_group_info else 0)

    sketch_name = _make_sketch_name(args)

    # Use OrderedSet here for deterministic compilation results. We can also
    # use built-in dict() for Python versions 3.6 and later, as it's inherently
    # ordered.
    constant_set = OrderedSet(args.constant_set.split(','))

    compiler = Compiler(args.program_file, args.stateful_alu_file,
                        args.stateless_alu_file,
                        args.num_pipeline_stages, args.num_alus_per_stage,
                        sketch_name, args.parallel_sketch,
                        constant_set,
                        args.synthesized_allocation, args.pkt_fields,
                        args.state_groups, args.state_dependency,
                        args.input_packet)
    # Repeatedly run synthesis at 2 bits and verification at max_input_bit
    # until either verification succeeds or synthesis fails. Note that the
    # verification might not cover all ints because sketch only considers
    # positive integers.
    # Synthesis is much faster at a smaller bit width, while verification needs
    # to run at a larger bit width for soundness.
    count = 1
    hole_elimination_assert = []
    additional_testcases = ''
    sol_verify_bit = args.max_input_bit
    while True:
        print('Iteration #' + str(count))
        if args.parallel:
            synthesis_ret_code, output, hole_assignments = \
                compiler.parallel_codegen(
                    additional_constraints=hole_elimination_assert,
                    additional_testcases=additional_testcases)
        else:
            synthesis_ret_code, output, hole_assignments = \
                compiler.serial_codegen(
                    iter_cnt=count,
                    additional_constraints=hole_elimination_assert,
                    additional_testcases=additional_testcases)

        if synthesis_ret_code != 0:
            compilation_failure(sketch_name, output)
            return 1

        print('Synthesis succeeded with 2 bits, proceeding to verification.')
        pkt_fields, state_vars = compiler.verify(
            hole_assignments, sol_verify_bit, iter_cnt=count
        )

        # Empty counterexamples mean verification passed.
        if len(pkt_fields) == 0 and len(state_vars) == 0:
            compilation_success(sketch_name, hole_assignments, output)
            return 0

        print('Verification failed.')

        # NOTE(taegyunkim): There is no harm in using both hole elimination
        # asserts and counterexamples. We want to compare using only hole
        # elimination asserts and only counterexamples.
        if args.hole_elimination:
            hole_elimination_assert += generate_hole_elimination_assert(
                hole_assignments)
            print(hole_elimination_assert)
        else:
            print('Use returned counterexamples', pkt_fields, state_vars)

            # Add every counterexample value (as a string) to the ordered
            # constant set so the next synthesis round can use it.
            for values in (pkt_fields.values(), state_vars.values()):
                for value in values:
                    constant_set.add(str(value))

            # Print the updated constant_array just for debugging
            print('updated constant array', constant_set)

            # Add constant set to compiler for next synthesis.
            compiler.update_constants_for_synthesis(constant_set)

            pkt_fields, state_vars = set_default_values(
                pkt_fields, state_vars, num_fields_in_prog, state_group_info
            )

            additional_testcases += generate_counterexample_asserts(
                pkt_fields, state_vars, num_fields_in_prog, state_group_info,
                count, args.pkt_fields, args.state_groups, group_size)

        count += 1