示例#1
0
    def __init__(self, qid, operator_id, keys, filter_keys, func, source, match_action, miss_action, p4_raw_fields):
        """Set up a Filter operator and the match table that implements it.

        :param qid: query id, forwarded to the base operator.
        :param operator_id: operator id, forwarded to the base operator.
        :param keys: operator keys, forwarded to the base operator.
        :param filter_keys: fields the table matches on; one table read is
            created per key.
        :param func: sequence describing the predicate. ``('mask', mask, v1, ...)``
            stores ``mask`` and the remaining values; ``('eq', v1, ...)`` stores
            the remaining values wrapped in a one-element list. An empty
            sequence or a leading ``'geq'`` is only logged as an error and
            leaves ``self.func`` as None.
        :param source: stored unchanged as ``self.source``.
        :param match_action: the single action listed for the table.
        :param miss_action: second argument of the table (presumably its
            default action -- confirm against ``Table``).
        :param p4_raw_fields: forwarded to the base operator.
        """
        super(P4Filter, self).__init__('Filter', qid, operator_id, keys, p4_raw_fields)

        self.filter_keys = filter_keys
        self.filter_mask = None
        self.filter_values = None
        self.func = None
        self.match_action = match_action
        self.miss_action = miss_action

        self.source = source

        if len(func) == 0 or func[0] == 'geq':
            # Unsupported predicates are reported but deliberately not raised;
            # the operator is still created with self.func left as None.
            self.logger.error('Got the following func with the Filter Operator: %s' % (str(func),))
        else:
            self.func = func[0]
            if func[0] == 'mask':
                self.filter_mask = func[1]
                self.filter_values = func[2:]
            elif func[0] == 'eq':
                self.filter_values = [func[1:]]

        # A mask predicate is matched with longest-prefix match, everything else exactly.
        reads_fields = list()
        for filter_key in self.filter_keys:
            # Diagnostics go through the operator's logger instead of stdout.
            self.logger.debug('Filter key: %s %s' % (str(filter_key), str(self.func)))
            self.logger.debug('%s' % (str(self.operator_specific_fields),))

            if self.func == 'mask':
                reads_fields.append((filter_key, 'lpm'))
            else:
                reads_fields.append((filter_key, 'exact'))

        self.logger.debug('Debug P4Filter %s %s %s %s %s' % (str(self.operator_name), str(miss_action),
                                                             str((match_action,)), str(reads_fields),
                                                             str(TABLE_SIZE)))
        self.table = Table(self.operator_name, miss_action, (match_action,), reads_fields, TABLE_SIZE)
示例#2
0
    def __init__(self, qid, operator_id, meta_init_name, keys, map_keys, func, p4_raw_fields):
        """Set up a Map operator: an action masking the mapped fields plus a table to run it.

        :param qid: query id, forwarded to the base operator.
        :param operator_id: operator id, forwarded to the base operator.
        :param meta_init_name: name of the metadata instance whose fields are masked.
        :param keys: operator keys, forwarded to the base operator.
        :param map_keys: names of the fields the map function is applied to.
        :param func: sequence describing the map function. Primitives are only
            generated when it is non-empty and its first element is ``'mask'``
            or falsy; ``func[1]`` is then divided by 4 to get the number of
            leading ``f`` hex digits of the mask.
        :param p4_raw_fields: forwarded to the base operator and used to resolve
            field names other than ``qid`` and ``count``.
        """
        super(P4Map, self).__init__('Map', qid, operator_id, keys, p4_raw_fields)

        self.meta_init_name = meta_init_name
        self.map_keys = map_keys
        self.func = func

        # Resolve each map key: qid and count are synthetic fields, the rest are raw header fields.
        map_fields = list()
        for fld in self.map_keys:
            if fld == 'qid':
                map_fields.append(P4Field(layer=None, target_name="qid", sonata_name="qid",
                                          size=QID_SIZE))
            elif fld == 'count':
                map_fields.append(P4Field(layer=None, target_name="count", sonata_name="count",
                                          size=COUNT_SIZE))
            else:
                map_fields.append(self.p4_raw_fields.get_target_field(fld))

        # create ACTION using the function
        primitives = list()
        if len(func) > 0 and (func[0] == 'mask' or not func[0]):
            for field in map_fields:
                # Floor division keeps mask_size an int on Python 3 as well;
                # true division would make the string repetition below fail.
                mask_size = func[1] // 4
                mask = '0x' + ('f' * mask_size) + ('0' * (HEADER_MASK_SIZE[field.target_name] - mask_size))
                field_name = '%s.%s' % (self.meta_init_name, field.target_name.replace(".", "_"))
                # The field is AND-ed with the mask and written back to itself.
                primitives.append(BitAnd(field_name, field_name, mask))

        self.action = Action('do_%s' % self.operator_name, primitives)

        # create dummy TABLE to execute the action
        self.table = Table(self.operator_name, self.action.get_name(), [], None, 1)
示例#3
0
class P4Map(P4Operator):
    """Map operator: masks the mapped metadata fields in place.

    The operator owns one action (``self.action``) holding the mask primitives
    and one table without match fields (``self.table``) that has this action as
    its second constructor argument.
    """

    def __init__(self, qid, operator_id, meta_init_name, keys, map_keys, func, p4_raw_fields):
        """Build the mask action and the table that executes it.

        :param qid: query id, forwarded to the base operator.
        :param operator_id: operator id, forwarded to the base operator.
        :param meta_init_name: name of the metadata instance whose fields are masked.
        :param keys: operator keys, forwarded to the base operator.
        :param map_keys: names of the fields the map function is applied to.
        :param func: sequence describing the map function. Primitives are only
            generated when it is non-empty and its first element is ``'mask'``
            or falsy; ``func[1]`` is then divided by 4 to get the number of
            leading ``f`` hex digits of the mask.
        :param p4_raw_fields: forwarded to the base operator and used to resolve
            field names other than ``qid`` and ``count``.
        """
        super(P4Map, self).__init__('Map', qid, operator_id, keys, p4_raw_fields)

        self.meta_init_name = meta_init_name
        self.map_keys = map_keys
        self.func = func

        # Resolve each map key: qid and count are synthetic fields, the rest are raw header fields.
        map_fields = list()
        for fld in self.map_keys:
            if fld == 'qid':
                map_fields.append(P4Field(layer=None, target_name="qid", sonata_name="qid",
                                          size=QID_SIZE))
            elif fld == 'count':
                map_fields.append(P4Field(layer=None, target_name="count", sonata_name="count",
                                          size=COUNT_SIZE))
            else:
                map_fields.append(self.p4_raw_fields.get_target_field(fld))

        # create ACTION using the function
        primitives = list()
        if len(func) > 0 and (func[0] == 'mask' or not func[0]):
            for field in map_fields:
                # Floor division keeps mask_size an int on Python 3 as well;
                # true division would make the string repetition below fail.
                mask_size = func[1] // 4
                mask = '0x' + ('f' * mask_size) + ('0' * (HEADER_MASK_SIZE[field.target_name] - mask_size))
                field_name = '%s.%s' % (self.meta_init_name, field.target_name.replace(".", "_"))
                # The field is AND-ed with the mask and written back to itself.
                primitives.append(BitAnd(field_name, field_name, mask))

        self.action = Action('do_%s' % self.operator_name, primitives)

        # create dummy TABLE to execute the action
        self.table = Table(self.operator_name, self.action.get_name(), [], None, 1)

    def __repr__(self):
        """Return a short textual form listing keys, map keys and the map function."""
        return '.Map(keys=' + str(self.keys) + ', map_keys=' + str(self.map_keys) + ', func=' + str(self.func) + ')'

    def get_code(self):
        """Return the P4 source of this operator: a header comment, the action, then the table."""
        out = ''
        out += '// Map %i of query %i\n' % (self.operator_id, self.query_id)
        out += self.action.get_code()
        out += self.table.get_code()
        out += '\n'
        return out

    def get_commands(self):
        """Return the runtime commands for this operator: only the table's default command."""
        commands = list()
        commands.append(self.table.get_default_command())
        return commands

    def get_control_flow(self, indent_level):
        """Return the ``apply(<table>);`` line, indented by ``indent_level`` tabs."""
        indent = '\t' * indent_level
        out = ''
        out += '%sapply(%s);\n' % (indent, self.table.get_name())
        return out

    def get_init_keys(self):
        """Return the operator keys (``self.keys``)."""
        return self.keys
    def __init__(self, qid, operator_id, keys, p4_raw_fields):
        """Create the metadata, action and table that initialise a query's fields.

        :param qid: query id; forwarded to the base operator and written into
            the ``qid`` metadata field by the generated action.
        :param operator_id: operator id, forwarded to the base operator.
        :param keys: names of the fields to keep in metadata. ``qid``, ``count``
            and ``index`` become synthetic fields; any other name is resolved
            through ``p4_raw_fields``.
        :param p4_raw_fields: forwarded to the base operator and used to resolve
            raw header fields.
        """
        super(P4MapInit, self).__init__('MapInit', qid, operator_id, keys,
                                        p4_raw_fields)

        def synthetic(name, size):
            # Field that is not part of any packet layer.
            return P4Field(layer=None, target_name=name, sonata_name=name, size=size)

        # Add map init
        init_fields = []
        for key in self.keys:
            if key == 'qid':
                resolved = synthetic("qid", QID_SIZE)
            elif key == 'count':
                resolved = synthetic("count", COUNT_SIZE)
            elif key == 'index':
                resolved = synthetic("index", INDEX_SIZE)
            else:
                resolved = self.p4_raw_fields.get_target_field(key)
            init_fields.append(resolved)

        # create METADATA object to store data for all keys; dots in target
        # names are replaced by underscores for the metadata field names
        meta_fields = [(item.target_name.replace(".", "_"), item.size)
                       for item in init_fields]

        self.metadata = MetaData(self.operator_name, meta_fields)

        # create ACTION to initialize the metadata
        primitives = []
        for item in init_fields:
            destination = '%s.%s' % (self.metadata.get_name(),
                                     item.target_name.replace(".", "_"))

            if item.sonata_name == 'qid':
                # Assign query id to this field
                initial_value = qid
            elif item.sonata_name in ('count', 'index'):
                # count and index both start at zero
                initial_value = 0
            else:
                # Read data from raw header fields and assign them to these meta fields
                initial_value = item.target_name
            primitives.append(ModifyField(destination, initial_value))

        self.action = Action('do_%s' % self.operator_name, primitives)

        # create dummy TABLE to execute the action
        self.table = Table(self.operator_name, self.action.get_name(), [],
                           None, 1)
示例#5
0
 def mark_satisfied(self):
     """Create the action and table that flag this query as satisfied.

     The action sets both the query's satisfied field and the clone field to 1.
     It is stored under ``self.actions['satisfied']``; the table, which has no
     match fields, is stored as ``self.satisfied_table``.
     """
     set_flags = [
         ModifyField(self.satisfied_meta_field, 1),
         ModifyField(self.clone_meta_field, 1),
     ]
     satisfied_action = Action('do_mark_satisfied_%i' % self.id, set_flags)
     self.actions['satisfied'] = satisfied_action
     self.satisfied_table = Table('mark_satisfied_%i' % self.id,
                                  satisfied_action.get_name(), [], None, 1)
示例#6
0
 def append_out_header(self):
     """Create the action and table that add and fill this query's out header.

     The action first adds the out header, then copies every out-header field
     from the metadata instance named ``self.meta_init_name``. Dots in target
     names are replaced by underscores on both sides of the copy.
     """
     header_name = self.out_header.get_name()
     primitives = [AddHeader(header_name)]
     for fld in self.out_header.fields:
         flat_name = fld.target_name.replace(".", "_")
         destination = '%s.%s' % (header_name, flat_name)
         origin = '%s.%s' % (self.meta_init_name, flat_name)
         primitives.append(ModifyField(destination, origin))

     add_action = Action('do_add_out_header_%i' % self.id, primitives)
     self.actions['append_out_header'] = add_action
     self.out_header_table = Table('add_out_header_%i' % self.id, add_action.get_name(), [],
                                   None, 1)
示例#7
0
    def init_application(self, app):
        """Build all application-wide P4 objects and one P4Query per query in ``app``.

        Sets the final header with its action and table, the nop action, the
        application metadata with its init action and table, the mirror session,
        the report field list and the report action and table on ``self``.

        :param app: mapping from query id to a query description providing
            ``parse_payload``, ``payload_fields``, ``read_register``,
            ``filter_payload``, ``filter_payload_str`` and ``operators``.
        :return: dict mapping each query id to its P4Query.
        """
        queries = dict()

        # define final header
        # TODO: Use new p4 layer object
        final_header = OutHeaders("final_header")
        final_header.fields = [P4Field(final_header, "delimiter", "delimiter", 32)]
        self.final_header = final_header

        final_header_name = self.final_header.get_name()
        add_final_header = [
            AddHeader(final_header_name),
            ModifyField('%s.delimiter' % final_header_name, 0),
        ]
        self.final_header_action = Action('do_add_final_header', add_final_header)

        self.final_header_table = Table('add_final_header', self.final_header_action.get_name(), [], None, 1)

        # define nop action
        self.nop_action = Action('_nop', NoOp())
        nop_name = self.nop_action.get_name()

        # app metadata: one drop and one satisfied bit per query, plus one shared clone bit
        meta_fields = []
        for query_id in app:
            meta_fields.extend([
                ('%s_%i' % (self.drop_meta_field, query_id), 1),
                ('%s_%i' % (self.satisfied_meta_field, query_id), 1),
            ])
        meta_fields.append((self.clone_meta_field, 1))
        self.metadata = MetaData('app_data', meta_fields)
        meta_name = self.metadata.get_name()

        # action and table to init app metadata (every field is set to 0)
        reset_primitives = [ModifyField('%s.%s' % (meta_name, field_name), 0)
                            for field_name, _ in meta_fields]
        self.init_action = Action('do_init_app_metadata', reset_primitives)

        self.init_action_table = Table('init_app_metadata', self.init_action.get_name(), [], None, 1)

        # transforms queries
        for query_id in app:
            self.logger.debug('create query pipeline for qid: %i' % (query_id))
            spec = app[query_id]
            # The per-query suffix of the drop/satisfied field names is appended by P4Query.
            queries[query_id] = P4Query(query_id,
                                        spec.parse_payload,
                                        spec.payload_fields,
                                        spec.read_register,
                                        spec.filter_payload,
                                        spec.filter_payload_str,
                                        spec.operators,
                                        nop_name,
                                        '%s.%s' % (meta_name, self.drop_meta_field),
                                        '%s.%s' % (meta_name, self.satisfied_meta_field),
                                        '%s.%s' % (meta_name, self.clone_meta_field), self.p4_raw_fields)

        # define mirroring session
        self.mirror_session = MirrorSession(SESSION_ID, SPAN_PORT)

        # define report action that clones the packet and sends it to the stream processor
        report_fields = [meta_name]
        report_fields.extend(query.get_metadata_name() for query in queries.values())
        self.field_list = FieldList('report_packet', report_fields)

        clone_primitive = CloneIngressPktToEgress(self.mirror_session.get_session_id(),
                                                  self.field_list.get_name())
        self.report_action = Action('do_report_packet', clone_primitive)

        self.report_action_table = Table('report_packet', self.report_action.get_name(), [], None, 1)
        return queries
示例#8
0
class P4Application(object):
    """Compiles a set of Sonata queries into P4 source code and runtime commands.

    One P4Query pipeline is created per entry of ``app``. This class adds the
    application-wide pieces around them: the parser entry point, the out-header
    parser, application metadata, the report (clone) action, the final header
    and the ingress/egress control blocks.
    """

    def __init__(self, app, sonata_fields):
        """Initialise application state and build all per-query pipelines.

        :param app: mapping from query id to a query description; iterated and
            indexed by ``init_application``.
        :param sonata_fields: raw-field registry, stored as ``self.p4_raw_fields``.
        """
        # LOGGING
        log_level = logging.DEBUG
        self.logger = get_logger('P4Application', 'INFO')
        self.logger.setLevel(log_level)
        self.logger.info('init')

        # Name of the root layer of a raw packet; selects the first parser state.
        self.root_layer = "ethernet"
        self.p4_raw_fields = sonata_fields

        # define the application metadata
        self.drop_meta_field = 'drop'
        self.satisfied_meta_field = 'satisfied'
        self.clone_meta_field = 'clone'

        # All of the following are populated by init_application.
        self.mirror_session = None
        self.field_list = None

        self.final_header = None
        self.final_header_action = None
        self.final_header_table = None

        self.init_action = None
        self.init_action_table = None

        self.report_action = None
        self.report_action_table = None
        self.nop_action = None
        self.metadata = None
        self.queries = self.init_application(app)

    # INIT THE DATASTRUCTURE
    def init_application(self, app):
        """Build all application-wide P4 objects and one P4Query per query in ``app``.

        :param app: mapping from query id to a query description providing
            ``parse_payload``, ``payload_fields``, ``read_register``,
            ``filter_payload``, ``filter_payload_str`` and ``operators``.
        :return: dict mapping each query id to its P4Query.
        """
        queries = dict()

        # define final header
        # TODO: Use new p4 layer object
        tmp = OutHeaders("final_header")
        tmp.fields = [P4Field(tmp, "delimiter", "delimiter", 32)]
        self.final_header = tmp

        primitives = list()
        primitives.append(AddHeader(self.final_header.get_name()))
        primitives.append(ModifyField('%s.delimiter' % self.final_header.get_name(), 0))
        self.final_header_action = Action('do_add_final_header', primitives)

        self.final_header_table = Table('add_final_header', self.final_header_action.get_name(), [], None, 1)

        # define nop action
        self.nop_action = Action('_nop', NoOp())
        nop_name = self.nop_action.get_name()

        # app metadata: one drop and one satisfied bit per query, plus one shared clone bit
        fields = list()
        for query_id in app:
            fields.append(('%s_%i' % (self.drop_meta_field, query_id), 1))
            fields.append(('%s_%i' % (self.satisfied_meta_field, query_id), 1))
        fields.append((self.clone_meta_field, 1))
        self.metadata = MetaData('app_data', fields)
        meta_name = self.metadata.get_name()

        # action and table to init app metadata (every field is set to 0)
        primitives = list()
        for field_name, _ in fields:
            primitives.append(ModifyField('%s.%s' % (self.metadata.get_name(), field_name), 0))
        self.init_action = Action('do_init_app_metadata', primitives)

        self.init_action_table = Table('init_app_metadata', self.init_action.get_name(), [], None, 1)

        # transforms queries
        for query_id in app:
            self.logger.debug('create query pipeline for qid: %i' % (query_id))
            spec = app[query_id]
            # The per-query suffix of the drop/satisfied field names is appended by P4Query.
            query = P4Query(query_id,
                            spec.parse_payload,
                            spec.payload_fields,
                            spec.read_register,
                            spec.filter_payload,
                            spec.filter_payload_str,
                            spec.operators,
                            nop_name,
                            '%s.%s' % (meta_name, self.drop_meta_field),
                            '%s.%s' % (meta_name, self.satisfied_meta_field),
                            '%s.%s' % (meta_name, self.clone_meta_field), self.p4_raw_fields)
            queries[query_id] = query

        # define mirroring session
        self.mirror_session = MirrorSession(SESSION_ID, SPAN_PORT)

        # define report action that clones the packet and sends it to the stream processor
        fields = [self.metadata.get_name()]
        for query in queries.values():
            fields.append(query.get_metadata_name())
        self.field_list = FieldList('report_packet', fields)

        self.report_action = Action('do_report_packet', CloneIngressPktToEgress(self.mirror_session.get_session_id(),
                                                                                self.field_list.get_name()))

        self.report_action_table = Table('report_packet', self.report_action.get_name(), [], None, 1)
        return queries

    # COMPILE THE CODE
    def get_p4_code(self):
        """Return the complete P4 program as one string.

        When ``ORIGINAL_PACKET`` is set, the forwarding snippet is included and
        the egress control block is left empty.
        """
        sections = [
            # Parser for the raw headers (layers) of the fields used in Sonata queries
            self.get_raw_parser_code(),
            # P4 INVARIANTS
            self.get_invariants(),
            # OUT HEADER PARSER
            self.get_out_header_parser(),
            # APP METADATA, ACTIONS, FIELDLISTS, TABLES
            self.get_app_code(),
            # QUERY METADATA, HEADERS, TABLES AND ACTIONS
            self.get_code(),
        ]

        # get original packet repeat code
        if ORIGINAL_PACKET:
            sections.append(self.get_original_repeat_code())

        # INGRESS PIPELINE
        sections.append(self.get_ingress_pipeline())

        # EGRESS PIPELINE
        if ORIGINAL_PACKET:
            sections.append("control egress { }")
        else:
            sections.append(self.get_egress_pipeline())

        return ''.join(sections)

    def get_invariants(self):
        """Return the ``parser start`` state.

        The first 64 bits of the packet select the next state: 0 goes to the
        out-header parser, anything else to the root layer's parser.
        """
        # Call this from respective layer classes
        return ('parser start {\n'
                '\treturn select(current(0, 64)) {\n'
                '\t\t0 : parse_out_header;\n'
                '\t\tdefault: parse_%s;\n'
                '\t}\n'
                '}\n\n') % (self.root_layer,)

    def get_raw_parser_code(self):
        """Return header specifications and parser code for every raw layer in use."""
        raw_layers = self.get_raw_layers()
        parts = []
        for layer in raw_layers:
            p4_layer = get_p4_layer(layer)
            parts.append(p4_layer.get_header_specification_code())
            parts.append(p4_layer.get_parser_code(raw_layers))

        return ''.join(parts)

    def get_raw_layers(self):
        """Return the layers needed for the union of all queries' ``all_fields``."""
        raw_fields = set()
        for qid in self.queries:
            raw_fields = raw_fields.union(self.queries[qid].all_fields)

        # TODO: get rid of this local fix. This won't be required after we fix the sonata query module
        # Start local fix: the lookup is handed a list rather than a set
        raw_fields = list(raw_fields)
        # End local fix

        raw_layers = self.p4_raw_fields.get_layers_for_fields(raw_fields)
        return raw_layers

    def get_out_header_parser(self):
        """Return the parser state that extracts every out header and the final header.

        Parsing then continues at the root layer's parser, the same state that
        ``get_invariants`` uses as its default.
        """
        # This needs to be called from the header class itself
        parts = ['parser parse_out_header {\n']
        for query in self.queries.values():
            parts.append('\textract(%s);\n' % query.out_header.get_name())
        parts.append('\t%s\n' % self.final_header.get_parser_code())
        parts.append('\treturn parse_%s;\n' % (self.root_layer,))
        parts.append('}\n\n')
        return ''.join(parts)

    def get_app_code(self):
        """Return the P4 source of all application-wide objects, in a fixed order."""
        return ''.join([
            self.init_action.get_code(),
            self.init_action_table.get_code(),
            self.metadata.get_code(),
            self.nop_action.get_code(),
            self.field_list.get_code(),
            self.report_action.get_code(),
            self.report_action_table.get_code(),
            self.final_header.get_header_specification_code(),
            self.final_header_action.get_code(),
            self.final_header_table.get_code(),
        ])

    def get_code(self):
        """Return the concatenated P4 source of all queries."""
        return ''.join([query.get_code() for query in self.queries.values()])

    def get_original_repeat_code(self):
        """Return the fixed P4 snippet with the ``_drop``/``repeat`` actions and the ``forward`` table."""
        return (
            '\n'
            'action _drop() {\n'
            '\tdrop();\n'
            '}\n'
            '\n'
            'action repeat(dport) {\n'
            '    modify_field(standard_metadata.egress_spec, dport);\n'
            '}\n'
            '\n'
            'table forward {\n'
            '    reads {\n'
            '        standard_metadata.ingress_port: exact;\n'
            '    }\n'
            '    actions {\n'
            '        repeat;\n'
            '        _drop;\n'
            '    }\n'
            '    size: 2;\n'
            '}\n'
        )

    def get_ingress_pipeline(self):
        """Return the ``control ingress`` block.

        It applies the metadata init table, then each query's control flow, then
        the report table when the clone field is 1 and, if ``ORIGINAL_PACKET``
        is set, the ``forward`` table.
        """
        parts = [
            'control ingress {\n',
            '\tapply(%s);\n' % self.init_action_table.get_name(),
        ]

        # add the control flow of one query after the other
        for query in self.queries.values():
            parts.append(query.get_ingress_control_flow(2))

        parts.append('\n')

        # after processing all queries, determine whether the packet should be sent to the emitter as it satisfied at
        # least one query
        parts.append('\tif (%s.%s == 1) {\n' % (self.metadata.get_name(), self.clone_meta_field))
        parts.append('\t\tapply(%s);\n' % self.report_action_table.get_name())
        parts.append('\t}\n')

        if ORIGINAL_PACKET:
            parts.append('\tapply(forward);\n')

        parts.append('}\n\n')
        return ''.join(parts)

    def get_egress_pipeline(self):
        """Return the ``control egress`` block.

        Packets with instance type 0 get an empty branch; packets with instance
        type 1 run every query's egress control flow and the final header table.
        """
        parts = [
            'control egress {\n',
            # normal forwarding of the original packet
            '\tif (standard_metadata.instance_type == 0) {\n',
            '\t\t// original packet, apply forwarding\n',
            '\t}\n\n',
            # adding header to the report packet which is sent to the emitter
            '\telse if (standard_metadata.instance_type == 1) {\n',
        ]
        for query in self.queries.values():
            parts.append(query.get_egress_control_flow(2))
        parts.append('\t\tapply(%s);\n' % self.final_header_table.get_name())
        parts.append('\t}\n')
        parts.append('}\n\n')
        return ''.join(parts)

    def get_commands(self):
        """Return the runtime commands of all queries plus the application-wide ones."""
        commands = list()
        for query in self.queries.values():
            commands += query.get_commands()
        # NOTE(review): no default command is emitted for init_action_table even
        # though the ingress pipeline applies it -- confirm this is intentional.
        commands.append(self.report_action_table.get_default_command())
        commands.append(self.final_header_table.get_default_command())
        commands.append(self.mirror_session.get_command())
        if ORIGINAL_PACKET:
            commands.append("table_set_default forward _drop")
            commands.append("table_add forward repeat %s => %s" % (SENDER_PORT, RECIEVE_PORT))
            commands.append("table_add forward repeat %s => %s" % (RECIEVE_PORT, SENDER_PORT))

        return commands

    def get_header_formats(self):
        """Return a dict mapping each query id to that query's header format."""
        # This needs updates as we now change the logic of packet parsing at the emitter
        header_formats = dict()
        # items() exists on both Python 2 and 3; iteritems() is Python 2 only.
        for qid, query in self.queries.items():
            header_formats[qid] = query.get_header_format()
        return header_formats

    def get_update_commands(self, filter_update):
        """Return the update commands for a set of filter changes.

        :param filter_update: dict keyed by ``(qid, filter_id)`` tuples; each
            value is passed to that query's ``get_update_commands``.
        """
        commands = list()
        for (qid, filter_id), update in filter_update.items():
            commands.extend(self.queries[qid].get_update_commands(filter_id, update))
        return commands
示例#9
0
class P4Query(object):
    all_fields = []
    out_header = None
    out_header_table = None
    query_drop_action = None
    satisfied_table = None

    def __init__(self, query_id, parse_payload, payload_fields, read_register,
                 filter_payload, filter_payload_str, generic_operators,
                 nop_name, drop_meta_field, satisfied_meta_field,
                 clone_meta_field, p4_raw_fields):
        """Build the complete P4 pipeline of one query.

        :param query_id: numeric query id; stored as ``self.id`` and appended to
            the drop and satisfied field names.
        :param parse_payload: stored unchanged.
        :param payload_fields: list of field names that are excluded when the
            query's fields are collected.
        :param read_register: stored unchanged; when truthy an ``index`` key is
            added to the map-init keys.
        :param filter_payload: stored unchanged.
        :param filter_payload_str: stored unchanged.
        :param generic_operators: operators of the query, translated by
            ``init_operators``.
        :param nop_name: name of the nop action, used as the filters' match action.
        :param drop_meta_field: prefix of the drop field; ``_<query_id>`` is appended.
        :param satisfied_meta_field: prefix of the satisfied field; ``_<query_id>``
            is appended.
        :param clone_meta_field: name of the clone field, used as given.
        :param p4_raw_fields: raw-field registry handed on to the operators.
        """
        # LOGGING
        self.logger = get_logger('P4Query - %i' % query_id, 'DEBUG')
        self.logger.setLevel(logging.ERROR)
        self.logger.info('init')

        # plain copies of the constructor arguments
        self.id = query_id
        self.parse_payload = parse_payload
        self.payload_fields = payload_fields
        self.read_register = read_register
        self.filter_payload = filter_payload
        self.filter_payload_str = filter_payload_str
        self.nop_action = nop_name
        self.clone_meta_field = clone_meta_field
        self.p4_raw_fields = p4_raw_fields

        # per-query names of the drop and satisfied fields
        self.drop_meta_field = '%s_%i' % (drop_meta_field, self.id)
        self.satisfied_meta_field = '%s_%i' % (satisfied_meta_field, self.id)

        # containers filled by the setup steps below
        self.registers_to_read = []
        self.meta_init_name = ''
        self.src_to_filter_operator = dict()
        self.actions = dict()

        # general drop action which is applied when a packet doesn't satisfy this query
        self.add_general_drop_action()

        # action and table to mark query as satisfied at end of query processing in ingress
        self.mark_satisfied()

        # initialize operators
        self.get_all_fields(generic_operators)
        self.operators = self.init_operators(generic_operators)

        # create an out header layer
        self.create_out_header()

        # action and table to populate out_header in egress
        self.append_out_header()

    def mark_satisfied(self):
        """Create the action and table that flag this query as satisfied.

        The action sets both the query's satisfied field and the clone field to
        1. It is stored under ``self.actions['satisfied']``; the table, which has
        no match fields, is stored as ``self.satisfied_table``.
        """
        set_flags = [
            ModifyField(self.satisfied_meta_field, 1),
            ModifyField(self.clone_meta_field, 1),
        ]
        satisfied_action = Action('do_mark_satisfied_%i' % self.id, set_flags)
        self.actions['satisfied'] = satisfied_action
        self.satisfied_table = Table('mark_satisfied_%i' % self.id,
                                     satisfied_action.get_name(), [], None, 1)

    def add_general_drop_action(self):
        """Create the action that sets this query's drop field to 1.

        The action is stored under ``self.actions['drop']`` and its name is kept
        in ``self.query_drop_action``.
        """
        # The single primitive is handed to Action directly, not wrapped in a list.
        set_drop = ModifyField(self.drop_meta_field, 1)
        drop_action = Action('drop_%i' % self.id, set_drop)
        self.actions['drop'] = drop_action
        self.query_drop_action = drop_action.get_name()

    def create_out_header(self):
        """Create this query's out header from the last operator's output fields.

        The header always starts with ``qid``; payload fields and ``ts`` are left
        out. ``qid``, ``count`` and ``index`` become fields of the out header
        itself, any other name is resolved through ``self.p4_raw_fields``. The
        register names of all Reduce operators are appended to
        ``self.registers_to_read``.
        """
        out_header_name = 'out_header_%i' % self.id
        self.out_header = OutHeaders(out_header_name)

        # Built as a real list: on Python 3 filter() is lazy and cannot be
        # concatenated to ['qid'].
        excluded = self.payload_fields + ['ts']
        sonata_field_list = ['qid'] + [fld for fld in self.operators[-1].get_out_headers()
                                       if fld not in excluded]

        out_header_fields = list()

        for fld in sonata_field_list:
            if fld == 'qid':
                out_header_fields.append(
                    P4Field(layer=self.out_header,
                            target_name="qid",
                            sonata_name="qid",
                            size=QID_SIZE))
            elif fld == 'count':
                out_header_fields.append(
                    P4Field(layer=self.out_header,
                            target_name="count",
                            sonata_name="count",
                            size=COUNT_SIZE))
            elif fld == 'index':
                out_header_fields.append(
                    P4Field(layer=self.out_header,
                            target_name="index",
                            sonata_name="index",
                            size=INDEX_SIZE))
            else:
                out_header_fields.append(
                    self.p4_raw_fields.get_target_field(fld))

        # Remember the register of every Reduce operator.
        for operator in self.operators:
            if operator.name == 'Reduce':
                self.registers_to_read.append(operator.register.name)

        # Add fields to this out header
        self.out_header.fields = out_header_fields

    def append_out_header(self):
        """Create the action and table that add and fill this query's out header.

        The action first adds the out header, then copies every out-header field
        from the metadata instance named ``self.meta_init_name``. Dots in target
        names are replaced by underscores on both sides of the copy.
        """
        header_name = self.out_header.get_name()
        primitives = [AddHeader(header_name)]
        for fld in self.out_header.fields:
            flat_name = fld.target_name.replace(".", "_")
            destination = '%s.%s' % (header_name, flat_name)
            origin = '%s.%s' % (self.meta_init_name, flat_name)
            primitives.append(ModifyField(destination, origin))

        add_action = Action('do_add_out_header_%i' % self.id, primitives)
        self.actions['append_out_header'] = add_action
        self.out_header_table = Table('add_out_header_%i' % self.id,
                                      add_action.get_name(), [], None, 1)

    def get_all_fields(self, generic_operators):
        # TODO: only select fields over which we perform any action
        all_fields = set()
        for operator in generic_operators:
            if operator.name in {'Filter', 'Map', 'Reduce', 'Distinct'}:
                all_fields = all_fields.union(set(operator.get_init_keys()))
        # print "get_all_fields1: ", all_fields
        # TODO remove this
        self.all_fields = filter(
            lambda x: x not in self.payload_fields + ['ts'], all_fields)

    def get_init_fields(self, generic_operators):
        # TODO: only select fields over which we perform any action
        all_fields = set()
        for operator in generic_operators:
            if operator.name in {'Map', 'Reduce', 'Distinct', 'Filter'}:
                all_fields = all_fields.union(set(operator.get_init_keys()))
        # No need to filter out count field
        # print "get_all_fields: ", all_fields
        return filter(lambda x: x not in self.payload_fields + ['ts'],
                      all_fields)

    def init_operators(self, generic_operators):
        """Translate the generic operators of the query into P4 operators.

        A P4MapInit operator with id 1 always comes first; its metadata name is
        stored in ``self.meta_init_name``. Every generic operator then gets the
        next id, including operators of an unsupported type, which are only
        logged. As a side effect ``payload`` and ``ts`` are removed from each
        generic operator's ``keys``.

        :param generic_operators: operators of the query, in pipeline order.
        :return: list of P4 operators, starting with the map-init operator.
        """
        p4_operators = list()
        operator_id = 1

        # list() keeps the concatenation valid even when get_init_fields hands
        # back a lazy iterable (filter() on Python 3).
        map_init_keys = ['qid'] + list(self.get_init_fields(generic_operators))

        if self.read_register:
            map_init_keys += ['index']

        self.logger.debug('add map_init with keys: %s' %
                          (', '.join(map_init_keys), ))
        map_init_operator = P4MapInit(self.id, operator_id, map_init_keys,
                                      self.p4_raw_fields)
        self.meta_init_name = map_init_operator.get_meta_name()
        p4_operators.append(map_init_operator)

        # add all the other operators one after the other
        for operator in generic_operators:
            self.logger.debug('add %s operator' % (operator.name, ))
            operator_id += 1

            # TODO: Confirm if this is the right way
            # Filter eagerly; a tuple stays a tuple, as Python 2's filter() did.
            keys = [key for key in operator.keys if key not in ('payload', 'ts')]
            if isinstance(operator.keys, tuple):
                keys = tuple(keys)
            operator.keys = keys
            # TODO: Confirm if this is the right way

            if operator.name == 'Filter':
                match_action = self.nop_action
                miss_action = self.query_drop_action
                filter_operator = P4Filter(self.id, operator_id, operator.keys,
                                           operator.filter_keys, operator.func,
                                           operator.src, match_action,
                                           miss_action, self.p4_raw_fields)
                # Filters with a non-zero source are indexed by that source.
                if operator.src != 0:
                    self.src_to_filter_operator[operator.src] = filter_operator
                p4_operators.append(filter_operator)

            elif operator.name == 'Map':
                # NOTE(review): operator.map_values is passed positionally here;
                # confirm that P4Map.__init__ accepts it in this revision.
                p4_operators.append(
                    P4Map(self.id, operator_id, self.meta_init_name,
                          operator.keys, operator.map_keys,
                          operator.map_values, operator.func,
                          self.p4_raw_fields))

            elif operator.name == 'Reduce':
                p4_operators.append(
                    P4Reduce(self.id, operator_id, self.meta_init_name,
                             self.query_drop_action, operator.keys,
                             operator.values, operator.threshold,
                             self.read_register, self.p4_raw_fields))

            elif operator.name == 'Distinct':
                p4_operators.append(
                    P4Distinct(self.id, operator_id, self.meta_init_name,
                               self.query_drop_action, self.nop_action,
                               operator.keys, self.p4_raw_fields))

            else:
                self.logger.error('tried to add an unsupported operator: %s' %
                                  operator.name)
        return p4_operators

    def get_ingress_control_flow(self, indent_level):
        """Build the ingress control-flow snippet for this query.

        Each operator is wrapped in its own ``if (<drop field> != 1)`` guard,
        nested one tab deeper than the previous one. Inside the innermost
        guard the satisfied table is applied, again only when the drop field
        is not 1. Finally one closing brace per operator is emitted.

        :param indent_level: number of tabs in front of the outermost lines
        :return: the generated snippet as one string
        """
        depth = indent_level
        chunks = ['%s// query %i\n' % ('\t' * depth, self.id)]
        guard = 'if (%s != 1) {\n' % (self.drop_meta_field, )

        # open one guarded scope per operator, each a level deeper
        for op in self.operators:
            chunks.append('\t' * depth + guard)
            depth += 1
            chunks.append(op.get_control_flow(depth))

        # innermost scope: mark the packet as satisfied if it was never
        # marked as dropped
        pad = '\t' * depth
        chunks.append(pad + guard)
        chunks.append('%s\tapply(%s);\n' %
                      (pad, self.satisfied_table.get_name()))
        chunks.append('%s}\n' % (pad, ))

        # unwind: close the scope opened for each operator
        while depth > indent_level:
            depth -= 1
            chunks.append('%s}\n' % ('\t' * depth, ))

        return ''.join(chunks)

    def get_egress_control_flow(self, indent_level):
        """Build the egress control-flow snippet for this query.

        The out-header table is applied only when the satisfied metadata
        field equals 1.

        :param indent_level: number of tabs in front of the ``if`` line
        :return: the generated three-line snippet as one string
        """
        pad = '\t' * indent_level
        lines = [
            '%sif (%s == 1) {\n' % (pad, self.satisfied_meta_field),
            '%s\tapply(%s);\n' % (pad, self.out_header_table.get_name()),
            '%s}\n' % (pad, ),
        ]
        return ''.join(lines)

    def get_code(self):
        """Assemble the generated code for this query.

        The pieces are concatenated in this order: a ``// query <id>``
        comment, the out-header specification, every action in
        ``self.actions``, the out-header table, the satisfied table and
        finally the code of every operator.

        :return: the concatenated code as one string
        """
        parts = ['// query %i\n' % self.id]

        # out header
        parts.append(self.out_header.get_header_specification_code())

        # query actions (drop, mark satisfied, add out header, etc)
        parts.extend(action.get_code() for action in self.actions.values())

        # query tables (add out header, mark satisfied)
        parts.append(self.out_header_table.get_code())
        parts.append(self.satisfied_table.get_code())

        # operator code
        parts.extend(op.get_code() for op in self.operators)

        return ''.join(parts)

    def get_commands(self):
        """Collect the commands of this query.

        :return: list with the commands of every operator, in operator
            order, followed by the default command of the out-header table
            and the default command of the satisfied table
        """
        commands = [
            command
            for op in self.operators
            for command in op.get_commands()
        ]
        commands += [
            self.out_header_table.get_default_command(),
            self.satisfied_table.get_default_command(),
        ]
        return commands

    def get_metadata_name(self):
        """Return the metadata name stored in ``self.meta_init_name``."""
        return self.meta_init_name

    def get_header_format(self):
        """Describe the output of this query as a dictionary.

        :return: dict with the keys ``parse_payload``, ``payload_fields``,
            ``reads_register``, ``filter_payload``, ``filter_payload_str``,
            ``registers`` and ``headers``; ``headers`` is ``self.out_header``
            when that is truthy and ``None`` otherwise
        """
        # TODO: This will now change
        return {
            'parse_payload': self.parse_payload,
            'payload_fields': self.payload_fields,
            'reads_register': self.read_register,
            'filter_payload': self.filter_payload,
            'filter_payload_str': self.filter_payload_str,
            'registers': self.registers_to_read,
            'headers': self.out_header if self.out_header else None,
        }

    def get_update_commands(self, filter_id, update):
        """Build ``table_add`` commands for a dynamically updated filter.

        :param filter_id: key into ``self.src_to_filter_operator``
        :param update: iterable of strings; newline characters are stripped
            from both ends of each entry before it is used
        :return: one ``table_add`` command per entry of ``update``, or an
            empty list when ``filter_id`` is not a known filter source
        """
        if filter_id not in self.src_to_filter_operator:
            return list()

        target = self.src_to_filter_operator[filter_id]
        mask = target.get_filter_mask()
        table_name = target.table.get_name()
        action = target.get_match_action()

        return [
            'table_add %s %s  %s/%i =>' %
            (table_name, action, entry.strip('\n'), mask)
            for entry in update
        ]
    def __init__(self, qid, operator_id, meta_init_name, drop_action,
                 nop_action, keys, p4_raw_fields):
        """Create the P4 building blocks of a Distinct operator.

        Builds a metadata object, a register, a hash over the operator keys,
        an init action (compute index, read register) and an update action
        (bit-or the value with 1, write register), plus the init, update,
        pass and drop tables.

        :param qid: query id, forwarded to the base-class constructor
        :param operator_id: operator id, forwarded to the base-class
            constructor
        :param meta_init_name: metadata name used as prefix for the hash
            input fields
        :param drop_action: action handed to the ``drop_<operator>`` table
        :param nop_action: action handed to the ``pass_<operator>`` table
        :param keys: key names, forwarded to the base-class constructor
        :param p4_raw_fields: forwarded to the base-class constructor; later
            read back as ``self.p4_raw_fields`` to resolve key names
        :raises NotImplementedError: if the sonata name of a key field
            contains a ``/``
        """
        super(P4Distinct, self).__init__('Distinct', qid, operator_id, keys,
                                         p4_raw_fields)

        self.threshold = 0
        self.comp_func = '<='  # less than or equal
        self.update_func = '&'  # bitwise and

        # METADATA with one field for the register value and one for the index
        self.metadata = MetaData(self.operator_name,
                                 [('value', REGISTER_WIDTH),
                                  ('index', REGISTER_NUM_INDEX_BITS)])

        # REGISTER accessed by the init and update actions below
        self.register = Register(self.operator_name, REGISTER_WIDTH,
                                 REGISTER_INSTANCE_COUNT)

        # resolve every key to a field object; 'qid' and 'count' are built
        # here, everything else is looked up in the raw fields
        key_fields = []
        for key in self.keys:
            if key == 'qid':
                resolved = P4Field(layer=None,
                                   target_name="qid",
                                   sonata_name="qid",
                                   size=QID_SIZE)
            elif key == 'count':
                resolved = P4Field(layer=None,
                                   target_name="count",
                                   sonata_name="count",
                                   size=COUNT_SIZE)
            else:
                resolved = self.p4_raw_fields.get_target_field(key)
            key_fields.append(resolved)

        # HASH over the key fields as found under meta_init_name
        hash_inputs = []
        for key_field in key_fields:
            if '/' in key_field.sonata_name:
                self.logger.error('found a / in the key')
                raise NotImplementedError
            hash_inputs.append(
                '%s.%s' %
                (meta_init_name, key_field.target_name.replace(".", "_")))
        self.hash = HashFields(self.operator_name, hash_inputs, 'crc16',
                               REGISTER_NUM_INDEX_BITS)

        # metadata field that stores the index into the register
        self.index_field_name = '%s.index' % self.metadata.get_name()
        # metadata field that temporarily keeps the register value
        self.value_field_name = '%s.value' % self.metadata.get_name()

        # ACTION 1: compute the hash-based index and read the register value
        init_primitives = [
            ModifyFieldWithHashBasedOffset(self.index_field_name, 0,
                                           self.hash.get_name(),
                                           REGISTER_INSTANCE_COUNT),
            RegisterRead(self.value_field_name, self.register.get_name(),
                         self.index_field_name),
        ]
        self.action1 = Action('do_init_%s' % self.operator_name,
                              init_primitives)

        # ACTION 2: bit-or the value with 1 and write it back
        update_primitives = [
            BitOr(self.value_field_name, self.value_field_name, 1),
            RegisterWrite(self.register.get_name(), self.index_field_name,
                          self.value_field_name),
        ]
        self.action2 = Action('do_update_%s' % self.operator_name,
                              update_primitives)

        # TABLEs that run the two actions
        self.init_table = Table('init_%s' % self.operator_name,
                                self.action1.get_name(), [], None, 1)
        self.update_table = Table('update_%s' % self.operator_name,
                                  self.action2.get_name(), [], None, 1)

        # TABLEs for the two outcomes: let the packet pass or drop it
        self.pass_table = Table('pass_%s' % self.operator_name, nop_action,
                                [], None, 1)
        self.drop_table = Table('drop_%s' % self.operator_name, drop_action,
                                [], None, 1)
class P4Filter(P4Operator):
    """Filter operator implemented as one match table.

    The table matches on ``filter_keys`` (``lpm`` for a ``mask`` function,
    ``exact`` otherwise) and is created with ``miss_action`` and
    ``match_action``.
    """

    def __init__(self, qid, operator_id, keys, filter_keys, func, source,
                 match_action, miss_action, p4_raw_fields):
        """Parse the filter function and create the match table.

        :param qid: query id, forwarded to the base-class constructor
        :param operator_id: operator id, forwarded to the base-class
            constructor
        :param keys: key names, forwarded to the base-class constructor
        :param filter_keys: fields the table matches on
        :param func: sequence whose first element names the function:
            ``('mask', <mask>, <value>, ...)`` or ``('eq', <value>, ...)``;
            an empty sequence or ``'geq'`` is only logged as an error
        :param source: stored as ``self.source`` and shown in ``repr``
        :param match_action: action used for the rules added to the table
        :param miss_action: action handed to the table as second argument
        :param p4_raw_fields: forwarded to the base-class constructor
        """
        super(P4Filter, self).__init__('Filter', qid, operator_id, keys,
                                       p4_raw_fields)

        self.filter_keys = filter_keys
        self.filter_mask = None
        self.filter_values = None
        self.func = None
        self.match_action = match_action
        self.miss_action = miss_action
        self.source = source

        if len(func) == 0 or func[0] == 'geq':
            # not supported: log it and carry on with self.func left as None
            self.logger.error(
                'Got the following func with the Filter Operator: %s' %
                (str(func), ))
        else:
            kind = func[0]
            self.func = kind
            if kind == 'mask':
                self.filter_mask = func[1]
                self.filter_values = func[2:]
            elif kind == 'eq':
                # all remaining elements together form a single value entry
                self.filter_values = [func[1:]]

        match_type = 'lpm' if self.func == 'mask' else 'exact'
        reads_fields = [(key, match_type) for key in self.filter_keys]

        self.table = Table(self.operator_name, miss_action, (match_action, ),
                           reads_fields, TABLE_SIZE)

    def __repr__(self):
        return '.Filter(filter_keys=%s, func=%s, src = %s)' % (
            str(self.filter_keys), str(self.func), str(self.source))

    def get_code(self):
        """Return a header comment followed by the table code."""
        header = '// Filter %i of query %i\n' % (self.operator_id,
                                                 self.query_id)
        return header + self.table.get_code() + '\n'

    def get_commands(self):
        """Return the table's default command plus one rule per value."""
        commands = [self.table.get_default_command()]
        if self.filter_values:
            commands.extend(
                self.table.get_add_rule_command(self.match_action, value,
                                                None)
                for value in self.filter_values)
        return commands

    def get_control_flow(self, indent_level):
        """Return the ``apply`` statement for the table, tab-indented."""
        return '%sapply(%s);\n' % ('\t' * indent_level,
                                   self.table.get_name())

    def get_match_action(self):
        """Return the match action passed to the constructor."""
        return self.match_action

    def get_filter_mask(self):
        """Return the mask of a ``mask`` function, ``None`` otherwise."""
        return self.filter_mask

    def get_init_keys(self):
        """Return ``self.keys``."""
        return self.keys
class P4MapInit(P4Operator):
    """Operator that initialises one metadata field per key.

    It owns a metadata object, an action that fills every metadata field and
    a table whose only purpose is to run that action.
    """

    def __init__(self, qid, operator_id, keys, p4_raw_fields):
        """Create metadata, init action and table for the given keys.

        :param qid: query id; also the value written to the ``qid`` field
        :param operator_id: operator id, forwarded to the base-class
            constructor
        :param keys: key names, forwarded to the base-class constructor
        :param p4_raw_fields: forwarded to the base-class constructor; later
            read back as ``self.p4_raw_fields`` to resolve key names
        """
        super(P4MapInit, self).__init__('MapInit', qid, operator_id, keys,
                                        p4_raw_fields)

        # resolve every key to a field object; 'qid', 'count' and 'index'
        # are built here, everything else is looked up in the raw fields
        init_fields = []
        for key in self.keys:
            if key == 'qid':
                resolved = P4Field(layer=None,
                                   target_name="qid",
                                   sonata_name="qid",
                                   size=QID_SIZE)
            elif key == 'count':
                resolved = P4Field(layer=None,
                                   target_name="count",
                                   sonata_name="count",
                                   size=COUNT_SIZE)
            elif key == 'index':
                resolved = P4Field(layer=None,
                                   target_name="index",
                                   sonata_name="index",
                                   size=INDEX_SIZE)
            else:
                resolved = self.p4_raw_fields.get_target_field(key)
            init_fields.append(resolved)

        # METADATA with one (name, size) entry per field; dots in the target
        # name are replaced by underscores
        self.metadata = MetaData(
            self.operator_name,
            [(entry.target_name.replace(".", "_"), entry.size)
             for entry in init_fields])

        # ACTION that initialises each metadata field
        primitives = []
        for entry in init_fields:
            kind = entry.sonata_name
            raw_name = entry.target_name
            destination = '%s.%s' % (self.metadata.get_name(),
                                     raw_name.replace(".", "_"))

            if kind == 'qid':
                # the query id itself
                value = qid
            elif kind in ('count', 'index'):
                value = 0
            else:
                # copy from the raw header field of the same target name
                value = raw_name
            primitives.append(ModifyField(destination, value))

        self.action = Action('do_%s' % self.operator_name, primitives)

        # dummy TABLE that only executes the action
        self.table = Table(self.operator_name, self.action.get_name(), [],
                           None, 1)

    def __repr__(self):
        return '.MapInit(keys=%s)' % (str(self.keys), )

    def get_meta_name(self):
        """Return the name of the metadata object."""
        return self.metadata.get_name()

    def get_code(self):
        """Return a header comment plus metadata, action and table code."""
        pieces = [
            '// MapInit of query %i\n' % self.query_id,
            self.metadata.get_code(),
            self.action.get_code(),
            self.table.get_code(),
            '\n',
        ]
        return ''.join(pieces)

    def get_commands(self):
        """Return a list holding only the table's default command."""
        return [self.table.get_default_command()]

    def get_control_flow(self, indent_level):
        """Return the ``apply`` statement for the table, tab-indented."""
        return '%sapply(%s);\n' % ('\t' * indent_level,
                                   self.table.get_name())

    def get_init_keys(self):
        """Return ``self.keys``."""
        return self.keys
    def __init__(self, qid, operator_id, meta_init_name, drop_action, keys,
                 values, threshold, read_register, p4_raw_fields):
        """Create the P4 building blocks of a Reduce operator.

        Builds a metadata object, a register, a hash over the operator keys,
        an init action (compute index, read register, add, write back) and
        the drop, first-pass and (unless ``read_register`` is set) pass
        tables.

        :param qid: query id, forwarded to the base-class constructor
        :param operator_id: operator id, forwarded to the base-class
            constructor
        :param meta_init_name: metadata name used as prefix for the hash
            input fields and for the ``count``/``index`` fields written by
            the actions
        :param drop_action: action handed to the ``drop_<operator>`` table
        :param keys: key names, forwarded to the base-class constructor
        :param values: sequence whose first element is either ``'count'`` or
            the name of the field that is added to the register value
        :param threshold: threshold convertible with ``int()``; the string
            ``'-1'`` selects the module-level ``THRESHOLD``
        :param read_register: when truthy the index (instead of the count)
            is exported and no pass table is created
        :param p4_raw_fields: forwarded to the base-class constructor; later
            read back as ``self.p4_raw_fields`` to resolve field names
        :raises NotImplementedError: if the sonata name of a key field
            contains a ``/``
        """
        super(P4Reduce, self).__init__('Reduce', qid, operator_id, keys,
                                       p4_raw_fields)

        # the string '-1' means "use the module-level default THRESHOLD"
        if threshold == '-1':
            self.threshold = int(THRESHOLD)
        else:
            self.threshold = int(threshold)

        self.read_register = read_register

        # self.out_headers is not assigned in this method; presumably the
        # base-class constructor creates it -- TODO confirm
        if self.read_register:
            self.out_headers += ['index']
        else:
            self.out_headers += ['count']

        # create METADATA to store index and value
        fields = [('value', REGISTER_WIDTH),
                  ('index', REGISTER_NUM_INDEX_BITS)]
        self.metadata = MetaData(self.operator_name, fields)

        # create REGISTER to keep track of counts
        self.register = Register(self.operator_name, REGISTER_WIDTH,
                                 REGISTER_INSTANCE_COUNT)

        self.values = values

        # resolve every key to a field object; 'qid', 'count' and 'index' are
        # built here, everything else is looked up in the raw fields
        hash_init_fields = list()
        for fld in self.keys:
            if fld == 'qid':
                hash_init_fields.append(
                    P4Field(layer=None,
                            target_name="qid",
                            sonata_name="qid",
                            size=QID_SIZE))
            elif fld == 'count':
                hash_init_fields.append(
                    P4Field(layer=None,
                            target_name="count",
                            sonata_name="count",
                            size=COUNT_SIZE))
            elif fld == 'index':
                hash_init_fields.append(
                    P4Field(layer=None,
                            target_name="index",
                            sonata_name="index",
                            size=INDEX_SIZE))
            else:
                hash_init_fields.append(
                    self.p4_raw_fields.get_target_field(fld))
        # create HASH for access to register; the inputs are the key fields
        # as found under meta_init_name (dots replaced by underscores)
        hash_fields = list()
        for field in hash_init_fields:
            if '/' in field.sonata_name:
                self.logger.error('found a / in the key')
                raise NotImplementedError
            else:
                hash_fields.append(
                    '%s.%s' %
                    (meta_init_name, field.target_name.replace(".", "_")))
        self.hash = HashFields(self.operator_name, hash_fields, 'crc16',
                               REGISTER_NUM_INDEX_BITS)

        # name of metadata field where the index of the count within the register is stored
        self.index_field_name = '%s.index' % self.metadata.get_name()
        # name of metadata field where the count is kept temporarily
        self.value_field_name = '%s.value' % self.metadata.get_name()

        # create ACTION and TABLE to compute hash and get value
        primitives = list()
        primitives.append(
            ModifyFieldWithHashBasedOffset(self.index_field_name, 0,
                                           self.hash.get_name(),
                                           REGISTER_INSTANCE_COUNT))
        primitives.append(
            RegisterRead(self.value_field_name, self.register.get_name(),
                         self.index_field_name))

        if self.values[0] == 'count':
            # NOTE(review): from here on self.threshold may be a str
            # ('1') instead of an int
            if self.threshold <= 1:
                self.threshold = '1'

            # counting: add the constant 1 to the value
            primitives.append(
                ModifyField(self.value_field_name,
                            '%s + %i' % (self.value_field_name, 1)))
        else:
            target_fld = self.p4_raw_fields.get_target_field(self.values[0])
            # NOTE(review): from here on self.threshold may be a str (the
            # name of the metadata field that is added) instead of an int
            if self.threshold <= 1:
                self.threshold = '%s.%s' % (
                    meta_init_name, target_fld.target_name.replace(".", "_"))

            # summing: add the metadata copy of the value field to the value
            primitives.append(
                ModifyField(
                    self.value_field_name,
                    '%s + %s' % (self.value_field_name, '%s.%s' %
                                 (meta_init_name,
                                  target_fld.target_name.replace(".", "_")))))

        # write the updated value back to the register
        primitives.append(
            RegisterWrite(self.register.get_name(), self.index_field_name,
                          self.value_field_name))
        self.init_action = Action('do_init_%s' % self.operator_name,
                                  primitives)
        table_name = 'init_%s' % self.operator_name
        self.init_table = Table(table_name, self.init_action.get_name(), [],
                                None, 1)

        # create three TABLEs that implement reduce operation
        # if count <= THRESHOLD, update count and drop,
        table_name = 'drop_%s' % self.operator_name
        self.drop_table = Table(table_name, drop_action, [], None, 1)

        # if count == THRESHOLD, pass through with current count
        # (or with the register index when read_register is set)
        field_to_modified = None
        if not self.read_register:
            field_to_modified = ModifyField('%s.count' % meta_init_name,
                                            self.value_field_name)
        else:
            field_to_modified = ModifyField('%s.index' % meta_init_name,
                                            self.index_field_name)

        self.set_count_action = Action('set_count_%s' % self.operator_name,
                                       field_to_modified)
        table_name = 'first_pass_%s' % self.operator_name
        self.first_pass_table = Table(table_name,
                                      self.set_count_action.get_name(), [],
                                      None, 1)

        # reset_count_action and pass_table only exist when read_register
        # is falsy
        if not self.read_register:
            # if count > THRESHOLD, let it pass through with count set to 1
            self.reset_count_action = Action(
                'reset_count_%s' % self.operator_name,
                ModifyField('%s.count' % meta_init_name, 1))
            table_name = 'pass_%s' % self.operator_name
            self.pass_table = Table(table_name,
                                    self.reset_count_action.get_name(), [],
                                    None, 1)