Beispiel #1
0
class NCacheController(object):
    """Controller for one P4 switch: CPU-port mirroring plus static L2 forwarding.

    The switch is looked up by name in ``./topology.db``; its thrift port is
    used to open a ``SimpleSwitchAPI`` connection.
    """

    def __init__(self, sw_name):
        """Connect to switch ``sw_name`` and run :meth:`setup`.

        Args:
            sw_name: name of the switch as known by the topology database.
        """
        self.topo = Topology(db="./topology.db")
        self.sw_name = sw_name
        self.thrift_port = self.topo.get_thrift_port(self.sw_name)
        self.cpu_port = self.topo.get_cpu_port_index(self.sw_name)
        self.controller = SimpleSwitchAPI(self.thrift_port)

        # one sketch register is assumed per custom crc calculation reported
        # by the switch
        self.custom_calcs = self.controller.get_custom_crc_calcs()
        self.sketch_register_num = len(self.custom_calcs)

        self.setup()

    def setup(self):
        """Mirror ``CONTROLLER_MIRROR_SESSION`` to the CPU port, if the switch has one."""
        if self.cpu_port:
            self.controller.mirroring_add(CONTROLLER_MIRROR_SESSION,
                                          self.cpu_port)

    def set_forwarding_table(self):
        """Install a static l2 forwarding scheme.

        For every host connected to this switch, the host's mac address is
        associated with the switch port that connects to that host.
        """
        for host in self.topo.get_hosts_connected_to(self.sw_name):
            port = self.topo.node_to_node_port_num(self.sw_name, host)
            host_mac = self.topo.get_host_mac(host)
            # single-argument print(...) behaves the same on python 2 and 3
            print(str(host_mac) + str(port))
            self.controller.table_add("l2_forward", "set_egress_port",
                                      [str(host_mac)], [str(port)])

    def main(self):
        """Entry point: populate the forwarding table."""
        self.set_forwarding_table()
class LearningSwitchControllerApp(object):
    """L2-learning controller for a single P4 switch.

    On construction the switch state is reset, multicast group 1 is created
    with every port except ``lo`` and the CPU port, and
    ``MIRROR_SESSION_ID`` is mirrored to the CPU port.  Learning packets
    sniffed on the CPU interface are turned into table entries by
    :meth:`learn`.
    """

    def __init__(self, switchName):
        """Connect to ``switchName`` (looked up in ``topology.db``) and initialise it."""
        self.topo = Topology(db="topology.db")
        self.switchName = switchName
        self.thrift_port = self.topo.get_thrift_port(switchName)
        self.cpu_port = self.topo.get_cpu_port_index(self.switchName)
        self.controller = SimpleSwitchAPI(self.thrift_port)

        self.init()

    def init(self):
        """Reset the switch, then create the multicast group and the mirror session."""
        self.controller.reset_state()
        self.add_mcast_grp()
        self.add_mirror()

    def add_mirror(self):
        """Mirror ``MIRROR_SESSION_ID`` to the CPU port, if the switch has one."""
        if self.cpu_port:
            self.controller.mirroring_add(MIRROR_SESSION_ID, self.cpu_port)

    def add_mcast_grp(self):
        """Create multicast group 1 with all ports except ``lo`` and the CPU port."""
        # work on a copy so the topology's own mapping is not modified
        interfaces_to_port = self.topo[self.switchName]["interfaces_to_port"].copy()
        # filter lo and cpu port
        interfaces_to_port.pop('lo', None)
        interfaces_to_port.pop(self.topo.get_cpu_port_intf(self.switchName), None)

        mc_grp_id = 1
        rid = 0
        # add multicast group
        self.controller.mc_mgrp_create(mc_grp_id)
        # list(...) instead of values()[:]: dict views cannot be sliced on python 3
        port_list = list(interfaces_to_port.values())
        # add multicast node group
        handle = self.controller.mc_node_create(rid, port_list)
        # associate with mc grp
        self.controller.mc_node_associate(mc_grp_id, handle)

    def learn(self, learningData):
        """Install table entries for learned addresses.

        Args:
            learningData: iterable of ``(macAddr, ingressPort)`` pairs, where
                ``macAddr`` is an integer (it is printed with ``%012X``).
        """
        for macAddr, ingressPort in learningData:
            # the format string must be applied with %, not handed to print
            # as separate arguments
            print("macAddr: %012X ingressPort: %s " % (macAddr, ingressPort))
            self.controller.table_add("srcMacAddr", "NoAction", [str(macAddr)])
            self.controller.table_add("dstMacAddr", "forward",
                                      [str(macAddr)], [str(ingressPort)])

    def recv_msg_cpu(self, pkt):
        """Sniffer callback: learn from packets carrying ``L2_LEARN_ETHER_TYPE``."""
        # bytes(pkt) is the raw frame on both python 2 and 3, whereas
        # str(pkt) is not the raw frame on python 3
        packet = Ether(bytes(pkt))
        if packet.type == L2_LEARN_ETHER_TYPE:
            cpu_header = CpuHeader(bytes(packet.payload))
            self.learn([(cpu_header.macAddr, cpu_header.ingressPort)])

    def run_cpu_port_loop(self):
        """Sniff on the CPU interface and feed every packet to :meth:`recv_msg_cpu`."""
        # the interface name stored in the topology ends in eth0; the end to
        # sniff on is the one named eth1
        cpu_port_intf = str(self.topo.get_cpu_port_intf(
            self.switchName).replace("eth0", "eth1"))
        sniff(iface=cpu_port_intf, prn=self.recv_msg_cpu)
class FloodingController(object):
    """Controller that fills a static dmac table and per-ingress-port broadcast groups."""

    def __init__(self, sw_name):
        """Connect to switch ``sw_name`` (looked up in ``topology.db``) and initialise it."""
        self.topo = Topology(db="topology.db")
        self.sw_name = sw_name
        self.thrift_port = self.topo.get_thrift_port(sw_name)
        self.cpu_port = self.topo.get_cpu_port_index(self.sw_name)
        self.controller = SimpleSwitchAPI(self.thrift_port)
        self.init()

    def init(self):
        """Reset the switch, then install the dmac entries and the broadcast groups."""
        self.controller.reset_state()
        self.fill_dmac_table()
        self.add_boadcast_groups()

    def fill_dmac_table(self):
        """Install four hard-coded mac -> port entries and make ``broadcast`` the default."""
        self.controller.table_add("dmac", "forward", ['00:00:0a:00:00:01'], ['1'])
        self.controller.table_add("dmac", "forward", ['00:00:0a:00:00:02'], ['2'])
        self.controller.table_add("dmac", "forward", ['00:00:0a:00:00:03'], ['3'])
        self.controller.table_add("dmac", "forward", ['00:00:0a:00:00:04'], ['4'])
        self.controller.table_set_default("dmac", "broadcast", [])

    def add_boadcast_groups(self):
        """Create one multicast group per ingress port.

        The group for a port contains every other port (``lo`` and the CPU
        port excluded); ``select_mcast_grp`` maps the ingress port to its
        group.  Group ids start at 1, replication ids at 0.
        """
        # work on a copy so the topology's own mapping is not modified
        interfaces_to_port = self.topo[self.sw_name]["interfaces_to_port"].copy()
        # filter lo and cpu port
        interfaces_to_port.pop('lo', None)
        interfaces_to_port.pop(self.topo.get_cpu_port_intf(self.sw_name), None)

        # list(...) instead of values()[:]: dict views cannot be sliced on python 3
        all_ports = list(interfaces_to_port.values())

        for rid, ingress_port in enumerate(all_ports):
            mc_grp_id = rid + 1

            # every port but the one the packet came in on
            port_list = list(all_ports)
            port_list.remove(ingress_port)

            # add multicast group
            self.controller.mc_mgrp_create(mc_grp_id)

            # add multicast node group
            handle = self.controller.mc_node_create(rid, port_list)

            # associate with mc grp
            self.controller.mc_node_associate(mc_grp_id, handle)

            # fill broadcast table
            self.controller.table_add("select_mcast_grp", "set_mcast_grp",
                                      [str(ingress_port)], [str(mc_grp_id)])
Beispiel #4
0
class L2Controller(object):
    """L2-learning controller for a single P4 switch.

    Learned ``(mac, ingress port)`` pairs reach the controller either as
    digest messages (:meth:`run_digest_loop`) or as packets mirrored to the
    CPU port (:meth:`run_cpu_port_loop`); both paths end in :meth:`learn`.
    """

    def __init__(self, sw_name):
        """Connect to switch ``sw_name`` (looked up in ``topology.db``) and initialise it."""
        self.topo = Topology(db="topology.db")
        self.sw_name = sw_name
        self.thrift_port = self.topo.get_thrift_port(sw_name)
        self.cpu_port = self.topo.get_cpu_port_index(self.sw_name)
        self.controller = SimpleSwitchAPI(self.thrift_port)

        self.init()

    def init(self):
        """Reset the switch, then create the broadcast groups and the mirror session."""
        self.controller.reset_state()
        self.add_boadcast_groups()
        self.add_mirror()
        #self.fill_table_test()

    def add_mirror(self):
        """Mirror session 100 to the CPU port, if the switch has one."""
        if self.cpu_port:
            self.controller.mirroring_add(100, self.cpu_port)

    def add_boadcast_groups(self):
        """Create one multicast group per ingress port.

        The group for a port contains every other port (``lo`` and the CPU
        port excluded); the ``broadcast`` table maps the ingress port to its
        group.  Group ids start at 1, replication ids at 0.
        """
        # work on a copy so the topology's own mapping is not modified
        interfaces_to_port = self.topo[self.sw_name]["interfaces_to_port"].copy()
        # filter lo and cpu port
        interfaces_to_port.pop('lo', None)
        interfaces_to_port.pop(self.topo.get_cpu_port_intf(self.sw_name), None)

        # list(...) instead of values()[:]: dict views cannot be sliced on python 3
        all_ports = list(interfaces_to_port.values())

        for rid, ingress_port in enumerate(all_ports):
            mc_grp_id = rid + 1

            # every port but the one the packet came in on
            port_list = list(all_ports)
            port_list.remove(ingress_port)

            # add multicast group
            self.controller.mc_mgrp_create(mc_grp_id)

            # add multicast node group
            handle = self.controller.mc_node_create(rid, port_list)

            # associate with mc grp
            self.controller.mc_node_associate(mc_grp_id, handle)

            # fill broadcast table
            self.controller.table_add("broadcast", "set_mcast_grp",
                                      [str(ingress_port)], [str(mc_grp_id)])

    def fill_table_test(self):
        """Install four hard-coded mac -> port entries (test helper, not called by init)."""
        self.controller.table_add("dmac", "forward", ['00:00:0a:00:00:01'], ['1'])
        self.controller.table_add("dmac", "forward", ['00:00:0a:00:00:02'], ['2'])
        self.controller.table_add("dmac", "forward", ['00:00:0a:00:00:03'], ['3'])
        self.controller.table_add("dmac", "forward", ['00:00:0a:00:00:04'], ['4'])

    def learn(self, learning_data):
        """Install ``smac``/``dmac`` entries for learned addresses.

        Args:
            learning_data: iterable of ``(mac_addr, ingress_port)`` pairs,
                where ``mac_addr`` is an integer (it is printed with ``%012X``).
        """
        for mac_addr, ingress_port in learning_data:
            # single-argument print(...) behaves the same on python 2 and 3
            print("mac: %012X ingress_port: %s " % (mac_addr, ingress_port))
            self.controller.table_add("smac", "NoAction", [str(mac_addr)])
            self.controller.table_add("dmac", "forward",
                                      [str(mac_addr)], [str(ingress_port)])

    def unpack_digest(self, msg, num_samples):
        """Decode the samples that follow the 32-byte header of a digest message.

        Each sample is 8 big-endian bytes: 4 + 2 bytes of mac address
        followed by a 2-byte ingress port.

        Args:
            msg: raw digest message (bytes).
            num_samples: number of samples contained in ``msg``.

        Returns:
            list of ``(mac_addr, ingress_port)`` integer pairs.

        Raises:
            struct.error: if ``msg`` is shorter than ``num_samples`` requires.
        """
        digest = []
        print("%s %s" % (len(msg), num_samples))
        for offset in range(32, 32 + 8 * num_samples, 8):
            mac0, mac1, ingress_port = struct.unpack_from(">LHH", msg, offset)
            # the mac address is split into its upper 32 and lower 16 bits
            mac_addr = (mac0 << 16) + mac1
            digest.append((mac_addr, ingress_port))

        return digest

    def recv_msg_digest(self, msg):
        """Handle one digest message: learn its samples, then acknowledge it."""
        # 32-byte little-endian header; only the ids needed for the
        # acknowledgement and the sample count are used
        topic, device_id, ctx_id, list_id, buffer_id, num = struct.unpack(
            "<iQiiQi", msg[:32])
        digest = self.unpack_digest(msg, num)
        self.learn(digest)

        # Acknowledge digest
        self.controller.client.bm_learning_ack_buffer(ctx_id, list_id, buffer_id)

    def run_digest_loop(self):
        """Subscribe to the switch's notification socket and process digests forever."""
        sub = nnpy.Socket(nnpy.AF_SP, nnpy.SUB)
        notifications_socket = self.controller.client.bm_mgmt_get_info().notifications_socket
        sub.connect(notifications_socket)
        # empty prefix: subscribe to every topic
        sub.setsockopt(nnpy.SUB, nnpy.SUB_SUBSCRIBE, '')

        while True:
            msg = sub.recv()
            self.recv_msg_digest(msg)

    def recv_msg_cpu(self, pkt):
        """Sniffer callback: learn from packets whose ether type is 0x1234."""
        # bytes(...) yields the raw frame/payload on both python 2 and 3,
        # whereas str(pkt) is not the raw frame on python 3
        packet = Ether(bytes(pkt))

        if packet.type == 0x1234:
            cpu_header = CpuHeader(bytes(packet.payload))
            self.learn([(cpu_header.macAddr, cpu_header.ingress_port)])

    def run_cpu_port_loop(self):
        """Sniff on the CPU interface and feed every packet to :meth:`recv_msg_cpu`."""
        # the interface name stored in the topology ends in eth0; the end to
        # sniff on is the one named eth1
        cpu_port_intf = str(self.topo.get_cpu_port_intf(self.sw_name).replace("eth0", "eth1"))
        sniff(iface=cpu_port_intf, prn=self.recv_msg_cpu)
Beispiel #5
0
class P4CLI(CLI):
    """Mininet CLI extended with commands to stop, (re)compile and restart P4 switches.

    The optional ``conf_file`` keyword argument names the P4 configuration
    file.  Without it the configuration-dependent commands report an error
    until a file is set with ``set_p4conf`` or ``load_topo_conf``.
    """

    # column layout shared by every row printed by do_printNetInfo
    _ROW_FORMAT = "{:>4} {:>15} {:>8} {:>20} {:>16} {:>8} {:>8} {:>8} {:>8} {:>8}"
    _START_USAGE = 'usage: p4switch_start <p4switch name> [--p4src <path>] [--cmds path]\n'

    def __init__(self, *args, **kwargs):
        """Store the P4 configuration, then start the regular mininet CLI.

        Keyword Args:
            conf_file: path of the P4 configuration file (optional).
        """
        # class CLI from mininet.cli does not have a conf_file parameter, so
        # always remove it from kwargs (even when it was passed as None/"")
        self.conf_file = kwargs.pop("conf_file", None)
        self.import_last_modifications = {}

        self.last_compilation_state = False

        # defined in every case so that commands can test it
        self.config = None

        if not self.conf_file:
            log.warn("No configuration given to the CLI. P4 functionalities are disabled.")
        else:
            self.config = load_conf(self.conf_file)
        CLI.__init__(self, *args, **kwargs)

    def failed_status(self):
        """Record that the last compilation/start failed and return ``FAILED_STATUS``."""
        self.last_compilation_state = False
        return FAILED_STATUS

    def _reload_config(self):
        """Re-read ``self.conf_file`` into ``self.config``.

        Returns:
            True on success, False (after logging an error) when no
            configuration file is set.
        """
        if not self.conf_file:
            error('No configuration file set, use: set_p4conf <conf file>\n')
            return False
        self.config = load_conf(self.conf_file)
        return True

    @staticmethod
    def _option_value(args, flag):
        """Return the token following ``flag`` in ``args``, or None if there is none."""
        position = args.index(flag) + 1
        return args[position] if position < len(args) else None

    def do_load_topo_conf(self, line=""):
        """Reload the topology configuration.

        Usage: load_topo_conf [<conf file>]
        When a file is given it becomes the current configuration file.
        """
        args = line.split()
        if args:
            self.conf_file = args[0]

        # re-load conf file
        self._reload_config()

    def do_set_p4conf(self, line=""):
        """Updates configuration file location, and reloads it."""
        args = line.split()
        if not args:
            error('usage: set_p4conf <conf file>\n')
            return
        conf = args[0]
        if not os.path.exists(conf):
            warn('Configuration file %s does not exist\n' % conf)
            return
        self.conf_file = conf
        self.config = load_conf(conf)

    def do_test_p4(self, line=""):
        """Tests start stop functionalities."""
        self.do_p4switch_stop("s1")
        self.do_p4switch_start("s1")
        self.do_p4switch_reboot("s1")
        self.do_p4switches_reboot()

    def do_p4switch_stop(self, line=""):
        """Stop simple switch from switch namespace."""
        args = line.split()
        if len(args) != 1:
            error('usage: p4switch_stop <p4switch name>\n')
            return

        switch_name = args[0]
        if switch_name not in self.mn:
            error("p4switch %s not in the network\n" % switch_name)
            return

        self.mn[switch_name].stop_p4switch()

    def do_p4switch_start(self, line=""):
        """Start again simple switch from namespace.

        Usage: p4switch_start <p4switch name> [--p4src <path>] [--cmds path]
        The program is recompiled when it or its imports changed, or when
        the previous compilation failed.  Returns ``SUCCESS_STATUS`` or
        ``FAILED_STATUS``.
        """
        args = line.split()

        # check args validity (an empty line has no switch name at all)
        if not args or len(args) > 5:
            error(self._START_USAGE)
            return self.failed_status()

        switch_name = args[0]

        if switch_name not in self.mn:
            error("p4switch %s not in the network\n" % switch_name)
            return self.failed_status()

        p4switch = self.mn[switch_name]

        # check if switch is running
        if p4switch.check_switch_started():
            error('P4 Switch already running, stop it first: p4switch_stop %s \n' % switch_name)
            return self.failed_status()

        if not self.config:
            error('No P4 configuration loaded, use: set_p4conf <conf file>\n')
            return self.failed_status()

        # validate the optional flags before anything is compiled or started
        p4source_path = None
        if "--p4src" in args:
            p4source_path = self._option_value(args, "--p4src")
            if p4source_path is None:
                error(self._START_USAGE)
                return self.failed_status()
        commands_override = None
        if "--cmds" in args:
            commands_override = self._option_value(args, "--cmds")
            if commands_override is None:
                error(self._START_USAGE)
                return self.failed_status()

        # load default configuration, then merge the switch specific one over it
        switch_conf = {
            "program": self.config.get("program", None),
            "options": self.config.get("options", None),
            "compiler": self.config.get("compiler", DEFAULT_COMPILER),
        }
        switch_conf.update(self.config['topology']['switches'][switch_name])

        if p4source_path is not None:
            switch_conf['program'] = p4source_path
            # check if file exists
            if not os.path.exists(p4source_path):
                warn('File Error: p4source %s does not exist\n' % p4source_path)
                return self.failed_status()
            # check if its not a file
            if not os.path.isfile(p4source_path):
                warn('File Error: p4source %s is not a file\n' % p4source_path)
                return self.failed_status()

        p4source_path_source = switch_conf['program']

        # generate output file name
        output_file = p4source_path_source.replace(".p4", "") + ".json"

        program_flag = last_modified(p4source_path_source, output_file)
        includes_flag = check_imports_last_modified(p4source_path_source,
                                                    self.import_last_modifications)

        log.debug("%s %s %s %s\n" % (p4source_path_source, output_file, program_flag, includes_flag))

        if program_flag or includes_flag or (not self.last_compilation_state):
            # compile program
            try:
                compile_p4_to_bmv2(switch_conf)
                self.last_compilation_state = True
            except CompilationError:
                log.error('Compilation failed\n')
                return self.failed_status()

            # update output program
            p4switch.json_path = output_file

        # start switch
        p4switch.start()

        # load command entries
        if "--cmds" in args:
            commands_path = commands_override
        else:
            commands_path = switch_conf.get('cli_input', None)

        if commands_path:
            if not os.path.exists(commands_path):
                error('File Error: commands file %s does not exist\n' % commands_path)
                return self.failed_status()
            entries = read_entries(commands_path)
            add_entries(p4switch.thrift_port, entries)

        return SUCCESS_STATUS

    def do_printSwitches(self, line=""):
        """Print names of all switches."""
        for sw in self.mn.p4switches:
            print(sw.name)

    def do_p4switches_reboot(self, line=""):
        """Reboot all P4 switches with new program.

        Note:
            If you provide a P4 source code or cmd, all switches will have the same.
        """
        if not self._reload_config():
            return

        for sw in self.mn.p4switches:
            switch_name = sw.name
            self.do_p4switch_stop(line=switch_name)
            self.do_p4switch_start(line=switch_name + " " + line)

        # run scripts flagged to be re-executed after a reboot
        scripts = self.config.get('exec_scripts', None)
        if isinstance(scripts, list):
            for script in scripts:
                if script.get("reboot_run"):
                    info("Exec Script: {}\n".format(script["cmd"]))
                    run_command(script["cmd"])

    def do_p4switch_reboot(self, line=""):
        """Reboot a P4 switch with a new program."""
        if not self._reload_config():
            return
        args = line.split()
        if not args or len(args) > 5:
            error('usage: p4switch_reboot <p4switch name> [--p4src <path>] [--cmds path]\n')
            return
        self.do_p4switch_stop(line=args[0])
        self.do_p4switch_start(line=line)

    def do_pingset(self, line=""):
        """Ping between the hosts named on the command line (timeout 1)."""
        wanted = set(line.split())
        hosts = [host for host in self.mn.hosts if host.name in wanted]
        self.mn.ping(hosts=hosts, timeout=1)

    def _link_row(self, node, intf, first_column):
        """Format one table row describing the link behind ``intf`` of ``node``."""
        other_node = self.topo.get_interfaces_to_node(node)[intf]
        link = self.topo[node][other_node]
        return self._ROW_FORMAT.format(
            first_column, intf, other_node, link['mac'],
            link['ip'].split("/")[0], link['bw'], link['weight'],
            link['delay'], link['loss'], link['queue_length'])

    def do_printNetInfo(self, line=""):
        """Prints Topology Info"""

        self.topo = Topology(db="topology.db")

        print("\n*********************")
        print("Network Information:")
        print("*********************\n")

        columns = ("intf", "node", "mac", "ip", "bw", "weight", "delay", "loss", "queue")
        by_port_number = lambda item: item[1]

        switches = self.topo.get_switches()

        for sw in sorted(switches.keys()):

            # skip linux bridge
            if sw == "sw-cpu":
                continue

            thrift_port = self.topo.get_thrift_port(sw)
            cpu_index = self.topo.get_cpu_port_index(sw, quiet=True)
            header = "{}(thrift->{}, cpu_port->{})".format(sw, thrift_port, cpu_index)
            header2 = self._ROW_FORMAT.format("port", *columns)
            rule = len(header2) * "-"

            print(header)
            print(rule)
            print(header2)

            interfaces = self.topo.get_interfaces_to_port(sw).items()
            for intf, port_number in sorted(interfaces, key=by_port_number):
                if intf == "lo":
                    continue
                print(self._link_row(sw, intf, port_number))

            print(rule)
            print("")

        # HOST INFO
        print("Hosts Info")

        header = self._ROW_FORMAT.format("name", *columns)
        rule = len(header) * "-"

        print(rule)
        print(header)

        for host in sorted(self.topo.get_hosts()):
            interfaces = self.topo.get_interfaces_to_port(host).items()
            for intf, _port_number in sorted(interfaces, key=by_port_number):
                print(self._link_row(host, intf, host))

        print(rule)
        print("")

#def describe(self, sw_addr=None, sw_mac=None):
#    print "**********"
#    print "Network configuration for: %s" % self.name
#    print "Default interface: %s\t%s\t%s" %(
#        self.defaultIntf().name,
#        self.defaultIntf().IP(),
#        self.defaultIntf().MAC()
#    )
#    if sw_addr is not None or sw_mac is not None:
#        print "Default route to switch: %s (%s)" % (sw_addr, sw_mac)
#    print "**********"
#    
#def describe(self):
#    print "%s -> Thrift port: %d" % (self.name, self.thrift_port)
Beispiel #6
0
class NCacheController(object):
    def __init__(self, sw_name, vtables_num=8):
        """Connect to switch ``sw_name`` and initialise the cache bookkeeping.

        Args:
            sw_name: name of the switch as known by ``../p4/topology.db``.
            vtables_num: number of value tables on the switch.
        """
        self.topo = Topology(db="../p4/topology.db")
        self.sw_name = sw_name
        self.thrift_port = self.topo.get_thrift_port(self.sw_name)
        self.cpu_port = self.topo.get_cpu_port_index(self.sw_name)
        self.controller = SimpleSwitchAPI(self.thrift_port)

        self.custom_calcs = self.controller.get_custom_crc_calcs()
        self.sketch_register_num = len(self.custom_calcs)

        self.vtables = []
        self.vtables_num = vtables_num

        # create a pool of ids (as much as the total amount of keys)
        # this pool will be used to assign index to keys which will be
        # used to index the cached key counter and the validity register
        # (must be a real list: ids are taken from it with .pop(), which a
        # python 3 range object does not support)
        self.ids_pool = list(range(0, VTABLE_ENTRIES * VTABLE_SLOT_SIZE))

        # array of bitmap, which marks available slots per cache line
        # as 0 bits and occupied slots as 1 bits
        self.mem_pool = [0] * VTABLE_ENTRIES

        # number of memory slots used (useful for lfu eviction policy)
        self.used_mem_slots = 0

        # dictionary storing the value table index, bitmap and counter/validity
        # register index in the P4 switch that corresponds to each key
        self.key_map = {}

        self.setup()

        #self.out_of_band_test()

    def inform_server(self):
        """Tell the server over the ``UNIX_CHANNEL`` socket that a cache insertion completed.

        Best effort: when the server cannot be reached nothing is sent and
        no error is reported.  The socket is closed in every case.
        """
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            try:
                sock.connect(UNIX_CHANNEL)
            except socket.error:
                # server not reachable: silently skip the notification
                return

            sock.sendall(CACHE_INSERT_COMPLETE)
        finally:
            # try/finally rather than `with`: python 2 sockets are not
            # context managers
            sock.close()

    # reports the value of counters for each cached key
    # (used only for debugging purposes)
    def report_counters(self):
        """Print the packet count of every cached key whose counter read is not 0."""
        for cached_key, entry in self.key_map.items():
            _vt_idx, _bitmap, counter_idx = entry
            reading = self.controller.counter_read(CACHED_KEYS_COUNTER,
                                                   counter_idx)
            if reading != 0:
                line = "[COUNTER] key = " + cached_key
                line = line + " [ " + str(reading.packets) + " ]"
                print(line)

    # periodically reset registers pertaining to query statistics module of the
    # P4 switch (count-min sketch registers, bloom filters and counters)
    def periodic_registers_reset(self):
        """Reset the query-statistics state of the switch and re-arm itself.

        A daemon timer calling this method again after
        ``STATISTICS_REFRESH_INTERVAL`` is started first; afterwards the
        cache is checked for eviction and the bloom filter registers, the
        count-min sketch registers and the cached-keys counter are reset.
        """
        rearm = threading.Timer(STATISTICS_REFRESH_INTERVAL,
                                self.periodic_registers_reset)
        rearm.daemon = True
        rearm.start()

        # evict with the lfu policy before the counters it relies on are
        # wiped (only acts when the cache is used above 80%)
        self.cache_lfu_eviction(threshold=0.8, sampling=0.2, to_remove=0.5)

        # bloom filter registers are numbered from 1
        for reg_no in range(1, BLOOMF_REGISTERS_NUM + 1):
            self.controller.register_reset(BLOOMF_REG_PREFIX + str(reg_no))

        # count min sketch registers are numbered from 1 as well
        for reg_no in range(1, SKETCH_REGISTERS_NUM + 1):
            self.controller.register_reset(SKETCH_REG_PREFIX + str(reg_no))

        # counter storing the query frequency of each cached item
        self.controller.counter_reset(CACHED_KEYS_COUNTER)

        print("[INFO]: Reset query statistics registers.")

    # the controller periodically checks if the memory used has exceeded a given threshold
    # (e.g 80 %) and if that is the case then it evicts keys according to an approximated
    # LFU policy inspired by REDIS (https://redis.io/topics/lru-cache))
    def cache_lfu_eviction(self, threshold=0.8, sampling=0.2, to_remove=0.5):
        """Evict the least frequently used keys among a random sample.

        Nothing happens unless the used memory slots exceed ``threshold``
        times ``len(self.mem_pool) * VTABLE_SLOT_SIZE``.

        Args:
            threshold: fraction of memory use above which eviction starts.
            sampling: fraction of the cached keys to sample.
            to_remove: fraction of the sampled keys to evict (those with
                the smallest counters).
        """
        import operator

        # if the threshold has not been surpassed then nothing to do
        if self.used_mem_slots <= (threshold * len(self.mem_pool) *
                                   VTABLE_SLOT_SIZE):
            return

        # random.sample needs a sequence; a dict view is rejected by
        # python >= 3.11
        cached_keys = list(self.key_map)
        n_samples = int(sampling * len(cached_keys))
        sampled_keys = random.sample(cached_keys, n_samples)

        # read the counter of each sampled key
        evict_list = []
        for key in sampled_keys:
            cnt_idx = self.key_map[key][2]
            counter = self.controller.counter_read(CACHED_KEYS_COUNTER,
                                                   cnt_idx).packets
            evict_list.append((key, counter))

        # sort by counter and evict the keys with the smallest counters
        # (this could be achieved more optimally by using quickselect)
        evict_list.sort(key=operator.itemgetter(1))

        # slicing never runs past the sample, even when to_remove > 1
        for key, _counter in evict_list[:int(to_remove * n_samples)]:
            self.evict(key)

    def setup(self):
        """Prepare the switch: mirror session, custom hashes and the periodic reset."""
        cpu_port = self.cpu_port
        if cpu_port:
            self.controller.mirroring_add(CONTROLLER_MIRROR_SESSION, cpu_port)

        # custom hash functions for the count min sketch and bloom filters:
        # first on the switch, then the matching ones on the controller
        self.set_crc_custom_hashes()
        self.create_hashes()

        # starts the self re-arming timer that periodically resets registers
        self.periodic_registers_reset()

        # disabled: thread serving incoming udp connections
        # (i.e hot reports from the switch)
        #udp_t = threading.Thread(target=self.hot_reports_loop)
        #udp_t.start()

    def set_crc_custom_hashes(self):
        """Give every custom crc32 calculation of the switch its own polynomial.

        The calculations are visited in sorted order and the i-th one gets
        ``crc32_polinomials[i]``.
        """
        ordered_calcs = sorted(self.custom_calcs.items())
        for position, (calc_name, _width) in enumerate(ordered_calcs):
            polynomial = crc32_polinomials[position]
            self.controller.set_crc32_parameters(calc_name, polynomial,
                                                 0xffffffff, 0xffffffff,
                                                 True, True)

    def create_hashes(self):
        """Build ``self.hashes``: one ``Crc`` object per sketch register."""
        self.hashes = []
        self.hashes.extend(
            Crc(32, crc32_polinomials[reg], True, 0xffffffff, True, 0xffffffff)
            for reg in range(self.sketch_register_num))

    # set a static allocation scheme for l2 forwarding where the mac address of
    # each host is associated with the port connecting this host to the switch
    def set_forwarding_table(self):
        for host in self.topo.get_hosts_connected_to(self.sw_name):
            port = self.topo.node_to_node_port_num(self.sw_name, host)
            host_mac = self.topo.get_host_mac(host)
            self.controller.table_add("l2_forward", "set_egress_port",
                                      [str(host_mac)], [str(port)])

    def set_value_tables(self):
        for i in range(self.vtables_num):
            self.controller.table_add("vtable_" + str(i),
                                      "process_array_" + str(i), ['1'], [])

    # this function manages the mapping between slots in register arrays
    # and the cached items by implementing the First Fit algorithm described in
    # the Memory Management section of 4.4.2 (netcache paper)
    def first_fit(self, key, value_size):
        """Reserve slots for ``key`` in the first cache line with enough free ones.

        Args:
            key: key to cache; must not be cached already.
            value_size: size of the value in bytes.

        Returns:
            ``(index, bitmap)`` where ``index`` is the cache line and
            ``bitmap`` has a 1 bit for every slot reserved for the key, or
            None when ``value_size`` is not positive, the key is already
            cached or no cache line has enough free slots.
        """
        if value_size <= 0 or key in self.key_map:
            return None

        # ceiling division: every started VTABLE_SLOT_SIZE bytes need a slot
        # (integer arithmetic, so the result is an int on python 2 and 3)
        n_slots = (value_size + VTABLE_SLOT_SIZE - 1) // VTABLE_SLOT_SIZE

        for idx, old_bitmap in enumerate(self.mem_pool):
            # each cache line is tracked by an 8 bit bitmap
            n_free = 8 - bin(old_bitmap).count("1")
            if n_free < n_slots:
                continue

            # pick the free slots starting from the most significant bit
            bitmap = 0
            reserved = 0
            for bit in reversed(range(8)):
                if reserved >= n_slots:
                    break

                if not self.bit_is_set(old_bitmap, bit):
                    bitmap |= 1 << bit
                    reserved += 1

            # the picked slots now belong to the new key: mark them occupied
            self.mem_pool[idx] = old_bitmap | bitmap
            self.used_mem_slots += reserved

            return (idx, bitmap)

        return None

    # converts a list of 1s and 0s represented as strings
    # to a bitmap using bitwise operations (this intermediate representation
    # of a list of 1s and 0s is used to avoid low level bitwise logic inside
    # core implementation logic)
    def convert_to_bitmap(self, strlist, bitmap_len):
        """Fold a sequence of '0'/'1' strings into an integer, first item most significant.

        Returns 0 when ``bitmap_len`` is not a multiple of 8 (only such
        bitmap sizes are supported).
        """
        if bitmap_len % 8:
            return 0

        folded = 0
        for digit in strlist:
            folded = (folded << 1) | int(digit)
        return folded

    # this function checks whether the k-th bit of a given number is set
    def bit_is_set(self, n, k):
        """Return True when bit ``k`` (0 = least significant) of ``n`` is 1, else False."""
        return bool(n & (1 << k))

    # given a key and its associated value, we update the lookup table on
    # the switch and we also update the value registers with the value
    # given as argument (stored in multiple slots)
    def insert(self, key, value, cont=True):
        """Cache ``key`` -> ``value`` on the switch.

        Does nothing when the key is already cached or no space is left.

        Args:
            key: key to cache.
            value: value to store, split over the value tables.
            cont: when true, the server is informed once the key is cached.
        """
        placement = self.first_fit(key, len(value))
        if placement is None:
            # key already cached or no space available
            return

        vt_index, bitmap = placement

        # value table i holds a part of the value when the i-th bit of the
        # bitmap, counted from the most significant one, is set
        written = 0
        highest_bit = self.vtables_num - 1
        for table_no in range(self.vtables_num):
            if not self.bit_is_set(bitmap, highest_bit - table_no):
                continue

            chunk = value[written:written + VTABLE_SLOT_SIZE]
            self.controller.register_write(VTABLE_NAME_PREFIX + str(table_no),
                                           vt_index, self.str_to_int(chunk))
            written += VTABLE_SLOT_SIZE

        # id indexing the counter and validity register of this key (taken
        # from the end of the pool, where removing from a list is cheap)
        key_index = self.ids_pool.pop()

        # make the key known to the cache lookup table of the p4 switch
        match_keys = [str(self.str_to_int(key))]
        action_params = [str(bitmap), str(vt_index), str(key_index)]
        self.controller.table_add(NETCACHE_LOOKUP_TABLE, "set_lookup_metadata",
                                  match_keys, action_params)

        # the cache entry of this key is valid from now on
        self.controller.register_write("cache_status", key_index, 1)

        self.key_map[key] = (vt_index, bitmap, key_index)

        # let the server know the insertion succeeded
        if cont:
            self.inform_server()

        print("Inserted key-value pair to cache: (" + key + "," + value + ")")

    # converts a string to a bytes representation and afterwards returns
    # its integer representation of width specified by argument int_width
    # (seems hacky due to restriction to use python2.7)
    def str_to_int(self, x, int_width=VTABLE_SLOT_SIZE):
        """Interpret the bytes of ``x`` as one big-endian unsigned integer.

        Inputs shorter than ``int_width`` behave as if left-padded with
        0x00 bytes, which does not change the resulting number.

        Args:
            x: text (encoded as utf-8) or bytes to convert.
            int_width: maximum number of bytes the input may occupy.

        Returns:
            the integer value of the bytes of ``x``.

        Raises:
            ValueError: if ``x`` needs more than ``int_width`` bytes.
        """
        raw = x if isinstance(x, bytes) else x.encode('utf-8')
        if len(raw) > int_width:
            raise ValueError("Error: Overflow while converting string to int")

        # iterating a bytearray yields ints on both python 2 and 3; folding
        # the bytes works for any int_width, not only for 8 bytes
        number = 0
        for byte in bytearray(raw):
            number = (number << 8) | byte
        return number

    # given an arbitrary sized integer, the max width (in bits) of the integer
    # it returns the string representation of the number (also stripping it of
    # any '0x00' characters) (network byte order is assumed)
    def int_to_packed(self, int_val, max_width=128, word_size=32):
        num_words = max_width / word_size
        words = self.int_to_words(int_val, num_words, word_size)

        fmt = '>%dI' % (num_words)
        return struct.pack(fmt, *words).strip('\x00')

    # split up an arbitrary sized integer to words (needed to hack
    # around struct.pack limitation to convert to byte any integer
    # greater than 8 bytes)
    def int_to_words(self, int_val, num_words, word_size):
        """Split int_val into num_words words of word_size bits each.

        Returns a list of ints with the most significant word first. Bits
        of int_val above num_words * word_size are silently dropped.
        """
        word_mask = 2**word_size - 1
        words = []
        for _ in range(num_words):
            # peel off the least significant word, then shift it away
            words.append(int(int_val & word_mask))
            int_val >>= word_size
        # words were collected least significant first
        words.reverse()
        return words

    # update the value of the given key with the new value given as argument
    # (by allowing updates also to be done by the controller, the client is
    # also able to update keys with values bigger than the previous one)
    # in netcache paper this restriction is not resolved
    def update(self, key, value):
        """Replace the cached value of key with value.

        Keys that are not currently cached are ignored.
        """
        if key in self.key_map:
            # an update is an eviction of the old pair followed by a fresh
            # insertion of the new one
            self.evict(key)
            self.insert(key, value)

    # evict given key from the cache by deleting its associated entries in
    # action tables of the switch, by deallocating its memory space and by
    # marking the cache entry as valid once the deletion is completed
    def evict(self, key):
        """Remove key from the cache and release everything it occupied.

        Deletes the key's lookup-table entry on the switch (when one is
        found), drops it from key_map, returns its slots to mem_pool and
        its index to ids_pool, and finally writes 1 to the key's
        cache_status register. Unknown keys are ignored.
        """
        if key not in self.key_map:
            return

        # delete entry from the lookup_table
        match_fields = [str(self.str_to_int(key))]
        handle = self.controller.get_handle_from_match(NETCACHE_LOOKUP_TABLE,
                                                       match_fields)
        if handle is not None:
            self.controller.table_delete(NETCACHE_LOOKUP_TABLE, handle)

        # delete mapping of key from controller's dictionary
        vt_idx, bitmap, key_idx = self.key_map.pop(key)

        # deallocate space from memory pool: clear the slots recorded in
        # bitmap and lower the usage count by the number of set bits
        self.mem_pool[vt_idx] ^= bitmap
        self.used_mem_slots -= bin(bitmap).count("1")

        # hand the validity/counter register index back to the id pool
        self.ids_pool.append(key_idx)

        # mark cache entry as valid again (should be the last thing to do)
        self.controller.register_write("cache_status", key_idx, 1)

    # used for testing purposes and static population of cache
    def dummy_populate_vtables(self):
        """Statically insert a fixed set of test key-value pairs.

        Only the first 11 of the 12 pairs below are inserted, and each
        insert is issued with its third argument set to False.
        """
        test_values_l = [
            "alpha", "beta", "gamma", "delta", "epsilon", "zeta", "hita",
            "theta", "yiota", "kappa", "lambda", "meta"
        ]
        test_keys_l = [
            "one", "two", "three", "four", "five", "six", "seven", "eight",
            "nine", "ten", "eleven", "twelve"
        ]
        # NOTE(review): the last pair ("twelve", "meta") has never been
        # inserted - confirm whether 11 is intentional
        num_entries = 11
        for key, value in zip(test_keys_l[:num_entries],
                              test_values_l[:num_entries]):
            self.insert(key, value, False)

    # handling reports from the switch corresponding to hot keys, updates to
    # key-value pairs or deletions - this function receives a packet, extracts
    # its netcache header and manipulates cache based on the operation field
    # of the netcache header (callback function)
    def recv_switch_updates(self, pkt):
        """Sniffer callback: apply one switch report to the cache.

        Parses the netcache header from the UDP or TCP payload of pkt and,
        depending on its op field, inserts, evicts or updates the reported
        key. Packets carrying neither UDP nor TCP are ignored.
        """
        print("Received message from switch")

        # extract netcache header information
        if pkt.haslayer(UDP):
            ncache_header = NetcacheHeader(pkt[UDP].payload)
        elif pkt.haslayer(TCP):
            ncache_header = NetcacheHeader(pkt[TCP].payload)
        else:
            # without a transport layer there is no netcache header to
            # read; bail out instead of failing on an unbound variable
            print("Error: packet has neither UDP nor TCP layer")
            return

        key = self.int_to_packed(ncache_header.key, max_width=128)
        value = self.int_to_packed(ncache_header.value, max_width=1024)

        op = ncache_header.op

        if op == NETCACHE_HOT_READ_QUERY:
            print("Received hot report for key = " + key)
            # if the netcache header has null value or if the "hot key"
            # reported doesn't exist then do not update cache
            # NOTE(review): op was just matched against
            # NETCACHE_HOT_READ_QUERY, so this check can only fire if both
            # constants are equal - confirm which header field actually
            # signals "key not found"
            if ncache_header.op == NETCACHE_KEY_NOT_FOUND:
                return

            self.insert(key, value)

        elif op == NETCACHE_DELETE_COMPLETE:
            print("Received query to delete key = " + key)
            self.evict(key)

        elif op == NETCACHE_UPDATE_COMPLETE:
            print("Received query to update key = " + key)
            self.update(key, value)

        else:
            print("Error: unrecognized operation field of netcache header")

    # sniff infinitely the interface connected to the P4 switch and when a valid netcache
    # packet is captured, handle the packet via a callback to recv_switch_updates function
    def hot_reports_loop(self):
        """Sniff the switch's cpu-port interface, feeding every captured
        packet on port 50000 to recv_switch_updates."""
        interface = self.topo.get_cpu_port_intf(self.sw_name)
        sniff(filter="port 50000",
              prn=self.recv_switch_updates,
              iface=str(interface))

    def main(self):
        """Run the controller: install the forwarding table, set up the
        value tables, statically populate the cache with test entries and
        then enter the report-sniffing loop."""
        self.set_forwarding_table()
        self.set_value_tables()
        self.dummy_populate_vtables()
        # sniffs the cpu port via hot_reports_loop; nothing follows it
        self.hot_reports_loop()
class BaseController(object):
    """A base P4 switch controller that your controllers probably want to inherit from.

    Implements the CPU loop. You must override the
    `recv_packet(self, packet)` method to use it.

    Instances are meant to be created through `get_initialised`, which also
    resets the switch state and runs every `init` method in the class
    hierarchy.
    """
    def __init__(self, sw_name, topology_db_file="./topology.db"):
        """Load the topology and open the control connection for sw_name."""
        self.topo = Topology(db=topology_db_file)
        # print(self.topo)
        self.sw_name = sw_name
        self.thrift_port = self.topo.get_thrift_port(sw_name)
        self.cpu_port = self.topo.get_cpu_port_index(self.sw_name)
        self.controller = SimpleSwitchAPIAsyncWrapper(self.thrift_port)

    @classmethod
    @defer.inlineCallbacks
    def get_initialised(cls, sw_name, *args, **kwargs):
        """Build an instance and run its asynchronous initialisation.

        Extra positional and keyword arguments are forwarded to the
        constructor. The resulting Deferred fires with the new instance.
        """
        obj = cls(sw_name, *args, **kwargs)
        yield obj._before_init()

        # TODO this actually wasn't a great idea and I shouldn't be
        # doing it, because it breaks expectations. But it works :D
        # Walk the MRO from the most basic class to the most derived one
        # and call every `init` a class defines itself, so each level of
        # the hierarchy gets initialised exactly once, base classes first.
        for mcls in reversed(cls.__mro__):
            if 'init' in mcls.__dict__:
                yield defer.maybeDeferred(mcls.init, obj)

        defer.returnValue(obj)

    def _before_init(self):
        """Reset the switch state; called before any `init` method runs."""
        return self.controller.reset_state()

    def recv_packet(self, msg):
        """Handle one message taken from the packet queue; must be overridden."""
        raise NotImplementedError(
            "Packet from switch received, but recv_packet has not been implemented"
        )

    @defer.inlineCallbacks
    def _consume_from_packet_queue(self):
        """Wait for one queued message, handle it, then reschedule itself."""
        msg = yield self.packet_queue.get()
        self.recv_packet(msg)
        # NOTE: not reached when recv_packet raises, so this consumer stops
        # rescheduling itself after an exception
        reactor.callLater(0, self._consume_from_packet_queue)

    @print_method_call
    def start_sniffer_thread(self):
        """Start the sniffer thread and the consumers of its packet queue."""
        self.packet_queue = defer.DeferredQueue()
        # the interface name reported by the topology has "eth0" swapped
        # for "eth1" before it is handed to the sniffer
        cpu_port_intf = str(
            self.topo.get_cpu_port_intf(self.sw_name).replace("eth0", "eth1"))
        self.sniffer_thread = SnifferThread(reactor, self.packet_queue,
                                            cpu_port_intf)
        self.sniffer_thread.daemon = True  # die when the main thread dies
        self.sniffer_thread.start()

        # start several self-rescheduling consumers on the same queue
        workers = 4
        for i in range(workers):
            self._consume_from_packet_queue()

    @defer.inlineCallbacks
    def init(self):
        """Add the CPU-port mirroring session, if the switch has a CPU port.

        Reminder: init() is Special. `get_initialised` calls the `init` of
        every class in the hierarchy on its own, so an override that also
        calls its parent's `init` would run the parent's version twice.
        """
        if self.cpu_port:
            yield self.controller.mirroring_add(
                p4settings['CPU_PORT_MIRROR_ID'], self.cpu_port)

    @classmethod
    def run(cls, sw_name):
        """Deprecated.

        Hands `get_initialised(sw_name)` to `task.react`.
        """
        task.react((lambda reactor, sw_name: cls.get_initialised(sw_name)),
                   [sw_name])
# Example #8
# 0
class BlinkController:

    def __init__(self, topo_db, sw_name, ip_controller, port_controller, log_dir, \
        monitoring=True, routing_file=None):
        """Connect to the switch and to the remote controller.

        topo_db         -- path of the topology database
        sw_name         -- name of the switch this instance manages
        ip_controller   -- address of the remote controller
        port_controller -- TCP port of the remote controller
        log_dir         -- directory receiving the log files
        monitoring      -- start the periodic register dump (scheduling)
        routing_file    -- optional JSON file loaded into topo_routing
        """
        self.topo = Topology(db=topo_db)
        self.sw_name = sw_name
        self.thrift_port = self.topo.get_thrift_port(sw_name)
        self.cpu_port = self.topo.get_cpu_port_index(self.sw_name)
        self.controller = SimpleSwitchAPI(self.thrift_port)
        self.controller.reset_state()
        self.log_dir = log_dir

        print('connecting to  ' + str(ip_controller) + ' ' +
              str(port_controller))
        # Socket used to communicate with the controller
        self.sock_controller = socket.socket(socket.AF_INET,
                                             socket.SOCK_STREAM)
        server_address = (ip_controller, port_controller)
        self.sock_controller.connect(server_address)
        print('Connected!')

        # Send the switch name to the controller
        self.sock_controller.sendall(str(sw_name))

        self.make_logging()

        # Give every host and switch a numeric id (its position in the list)
        nodes = list(self.topo.get_hosts()) + list(self.topo.get_p4switches())
        self.mapping_dic = {k: v for v, k in enumerate(nodes)}
        self.log.info(str(self.mapping_dic))

        self.routing_file = routing_file
        print('routing_file  ' + str(routing_file))
        if self.routing_file is not None:
            # the with-block closes the file even if the JSON is malformed
            with open(self.routing_file) as json_data:
                self.topo_routing = json.load(json_data)

        # Monitoring scheduler. It is started last because scheduling()
        # reads mapping_dic, routing_file and topo_routing, which therefore
        # have to exist before the timer can fire.
        self.t_sched = None
        if monitoring:
            self.t_sched = sched_timer.RepeatingTimer(10, 0.5, self.scheduling)
            self.t_sched.start()

    def make_logging(self):
        """Create the four file loggers of this switch and bind them to
        self.log, self.log_sw, self.log_rerouting and self.log_fs."""
        log_base = self.log_dir + '/p4_to_controller_' + str(self.sw_name)

        # (attribute, logger name, log file suffix)
        channels = (
            # Logger for the pipeline
            ('log', 'p4_to_controller', '.log'),
            # Logger for the sliding window
            ('log_sw', 'p4_to_controller_sw', '_sw.log'),
            # Logger for the rerouting
            ('log_rerouting', 'p4_to_controller_rerouting', '_rerouting.log'),
            # Logger for the Flow Selector
            ('log_fs', 'p4_to_controller_fs', '_fs.log'),
        )
        for attribute, logger_name, suffix in channels:
            logger.setup_logger(logger_name, log_base + suffix,
                                level=logging.INFO)
            setattr(self, attribute, logging.getLogger(logger_name))

    def scheduling(self):
        """Dump the monitored switch registers into the log files.

        Three passes are made over every host, each covering the host's two
        prefix ids (2 * id and 2 * id + 1): sliding window state goes to
        log_sw, rerouting state to log_rerouting and flow selector state to
        log_fs.
        """
        # Print log about the sliding window
        for host, prefix, id_prefix in self._monitored_prefixes():
            self._log_sliding_window(host, prefix, id_prefix)

        # Print log about rerouting
        for host, prefix, id_prefix in self._monitored_prefixes():
            self._log_rerouting(host, prefix, id_prefix)

        # Print log about the flow selector
        for host, prefix, id_prefix in self._monitored_prefixes():
            self._log_flow_selector(host, prefix, id_prefix)

    def _monitored_prefixes(self):
        """Yield (host, prefix, id_prefix) for both prefix ids of every host."""
        for host in list(self.topo.get_hosts()):
            prefix = self.topo.get_host_ip(host) + '/24'
            host_id = self.mapping_dic[host]
            for id_prefix in (host_id * 2, host_id * 2 + 1):
                yield host, prefix, id_prefix

    def _log_register_window(self, log, label, register, host, prefix,
                             id_prefix, width):
        """Log the width consecutive cells of register owned by id_prefix
        as one comma separated line."""
        first_cell = id_prefix * width
        values = []
        for offset in range(width):
            with HiddenPrints():
                values.append(
                    int(self.controller.register_read(register,
                                                      first_cell + offset)))
        log.info(label + ' ' + host + ' ' + prefix + ' ' + str(id_prefix) +
                 '\t' + ','.join(str(value) for value in values))

    def _log_sliding_window(self, host, prefix, id_prefix):
        """Log the sliding window registers of one prefix id."""
        with HiddenPrints():
            sw_time = float(
                self.controller.register_read('sw_time',
                                              index=id_prefix)) / 1000.
            sw_index = self.controller.register_read('sw_index',
                                                     index=id_prefix)
            sw_sum = self.controller.register_read('sw_sum', index=id_prefix)

        tag = host + '\t' + prefix + '\t' + str(id_prefix) + '\t'
        self.log_sw.info('sw_time\t' + tag + str(sw_time))
        self.log_sw.info('sw_index\t' + tag + str(sw_index))

        sum_line = 'sw_sum\t' + tag + str(sw_sum)
        # a sum of 32 or more is flagged in the log
        if sw_sum >= 32:
            sum_line += '\tREROUTING'
        self.log_sw.info(sum_line)

        self._log_register_window(self.log_sw, 'sw', 'sw', host, prefix,
                                  id_prefix, 10)

    def _log_rerouting(self, host, prefix, id_prefix):
        """Log the rerouting registers and next hops of one prefix id."""
        with HiddenPrints():
            nh_avaibility_1 = self.controller.register_read(
                'nh_avaibility_1', index=id_prefix)
            nh_avaibility_2 = self.controller.register_read(
                'nh_avaibility_2', index=id_prefix)
            nh_avaibility_3 = self.controller.register_read(
                'nh_avaibility_3', index=id_prefix)
            nbflows_progressing_2 = self.controller.register_read(
                'nbflows_progressing_2', index=id_prefix)
            nbflows_progressing_3 = self.controller.register_read(
                'nbflows_progressing_3', index=id_prefix)
            rerouting_ts = self.controller.register_read('rerouting_ts',
                                                         index=id_prefix)
            threshold = self.controller.register_read('threshold_registers',
                                                      index=id_prefix)

        tag = host + '\t' + prefix + '\t' + str(id_prefix) + '\t'
        self.log_rerouting.info('nh_avaibility\t' + tag +
                                str(nh_avaibility_1) + '\t' +
                                str(nh_avaibility_2) + '\t' +
                                str(nh_avaibility_3))
        # 'nblows' (sic): the label is kept verbatim so existing log
        # consumers keep working
        self.log_rerouting.info('nblows_progressing\t' + tag +
                                str(nbflows_progressing_2) + '\t' +
                                str(nbflows_progressing_3))
        self.log_rerouting.info('rerouting_ts\t' + tag + str(rerouting_ts))
        self.log_rerouting.info('threshold\t' + tag + str(threshold))

        availability = [nh_avaibility_1, nh_avaibility_2, nh_avaibility_3]
        nexthop_str = self._describe_nexthops(host, id_prefix, availability)
        self.log_rerouting.info('nexthop\t' + tag + str(nexthop_str))

    def _describe_nexthops(self, host, id_prefix, availability):
        """Return the tab separated next hops of a prefix id, each followed
        by (y) when its availability register reads 0 and (n) otherwise."""
        if self.routing_file is None:
            return ''

        bgp_type = 'customer' if id_prefix % 2 == 0 else 'customer_provider_peer'
        routes = self.topo_routing['switches'][self.sw_name]['prefixes'][host]
        if bgp_type not in routes:
            return 'NoPathAvailable'

        # work on a copy: padding the list must not alter the loaded
        # routing configuration
        nexthops = list(routes[bgp_type])
        if len(nexthops) == 2:
            # with two next hops, the last one is listed once more
            nexthops.append(nexthops[-1])

        # zip() stops at the three availability registers, so a longer
        # next hop list no longer raises an IndexError
        return '\t'.join(
            str(nexthop) + '(' + ('y' if state == 0 else 'n') + ')'
            for nexthop, state in zip(nexthops, availability))

    def _log_flow_selector(self, host, prefix, id_prefix):
        """Log the 64 cells of every flow selector register of one prefix id."""
        # (log label, register name), in the order they are written
        registers = (
            ('fs_key', 'flowselector_key'),
            ('fs', 'flowselector_ts'),
            ('fs_last_ret', 'flowselector_last_ret'),
            ('fs_last_ret_bin', 'flowselector_last_ret_bin'),
            ('fs_fwloops', 'flowselector_fwloops'),
            ('fs_correctness', 'flowselector_correctness'),
        )
        for label, register in registers:
            self._log_register_window(self.log_fs, label, register, host,
                                      prefix, id_prefix, 64)

    def forwarding(self):
        """Install one 'send' table entry per neighbor of this switch.

        Each entry maps the neighbor's numeric id (mapping_dic) to the
        output port, the switch-side MAC and the neighbor-side MAC of the
        link; every entry is also written to the pipeline log.
        """
        this_switch = self.topo.get_p4switches()[self.sw_name]

        for neighbor in this_switch['interfaces_to_node'].values():
            # the neighbor is either a host or, failing that, a switch
            try:
                dst_mac = self.topo.get_hosts()[neighbor][self.sw_name]['mac']
            except KeyError:
                dst_mac = self.topo.get_p4switches()[neighbor][
                    self.sw_name]['mac']

            link = this_switch[neighbor]
            src_mac = link['mac']
            outport = this_switch['interfaces_to_port'][link['intf']]

            node_id = str(self.mapping_dic[neighbor])
            action_params = [str(outport), str(src_mac), str(dst_mac)]

            self.log.info('table add send set_nh ' + node_id + ' => ' +
                          ' '.join(action_params))
            self.controller.table_add('send', 'set_nh', [node_id],
                                      action_params)

    def run(self):
        """Serve commands received from the remote controller.

        Reads newline terminated commands from sock_controller and applies
        them to the switch (see _handle_controller_line). Returns once the
        controller closes the connection or the socket fails, instead of
        spinning on a dead socket.
        """
        sock = self.sock_controller
        # bytes received so far that do not yet form a complete line
        pending = ''

        while True:
            inready, _, _ = select.select([sock], [], [])
            if sock not in inready:
                continue

            try:
                # lines are reassembled below, so a moderate read size is
                # enough and avoids a huge buffer for every read
                chunk = sock.recv(65536)
            except socket.error as e:
                err = e.args[0]
                if err == errno.EAGAIN or err == errno.EWOULDBLOCK:
                    continue
                print('p4_to_controller:  ' + str(e))
                sock.close()
                return

            if not chunk:
                # an empty read means the controller closed the connection;
                # select() would report the socket readable forever
                self.log.info('CONTROLLER_CONNECTION_CLOSED')
                sock.close()
                return

            pending += chunk

            # everything up to the last newline is complete; the rest is
            # kept until the remainder of its line arrives
            complete, newline, pending = pending.rpartition('\n')
            if not newline:
                continue

            for line in complete.split('\n'):
                self._handle_controller_line(line)

    def _handle_controller_line(self, line):
        """Execute one command line received from the controller.

        Recognised prefixes are 'table add ', 'do_register_write' and
        'reset_states'; any other line is ignored.
        """
        if line.startswith('table add '):
            line = line.replace('table add ', '')

            fields = line.split(' ')
            fwtable_name = fields[0]
            action_name = fields[1]

            # "<table> <action> <match...> => <action params...>"
            sides = line.split(' => ')
            match_list = sides[0].split(' ')[2:]
            action_list = sides[1].split(' ')

            print(line)
            print(' '.join(
                str(item) for item in (fwtable_name, action_name, match_list,
                                       action_list)))

            self.log.info(line)
            self.controller.table_add(fwtable_name, action_name, match_list,
                                      action_list)

        if line.startswith('do_register_write'):
            linetab = line.split(' ')

            register_name = linetab[1]
            index = int(linetab[2])
            value = int(linetab[3])

            self.log.info(line)
            self.controller.register_write(register_name, index, value)

        if line.startswith('reset_states'):
            self._reset_states()

    def _reset_states(self):
        """Reset every state register of the switch, pausing the monitoring
        scheduler (when there is one) while doing so."""
        self.log.info('RESETTING_STATES')

        # there is no scheduler when monitoring was disabled
        scheduler = getattr(self, 't_sched', None)

        # First stop the scheduler to avoid concurrent use
        # of the Thrift server
        if scheduler is not None:
            scheduler.cancel()
            while scheduler.running:  # Wait the end of the log printing
                time.sleep(0.5)

        time.sleep(1)

        # Reset the state of the switch
        registers = (
            'nh_avaibility_1',
            'nh_avaibility_2',
            'nh_avaibility_3',
            'nbflows_progressing_2',
            'nbflows_progressing_3',
            'rerouting_ts',
            'timestamp_reference',
            'sw_time',
            'sw_index',
            'sw_sum',
            'sw',
            'flowselector_key',
            'flowselector_nep',
            'flowselector_ts',
            'flowselector_last_ret',
            'flowselector_last_ret_bin',
            'flowselector_correctness',
            'flowselector_fwloops',
        )
        for register in registers:
            self.controller.register_reset(register)

        print(str(self.sw_name) + '  RESET.')

        # Restart the scheduler
        time.sleep(1)
        if scheduler is not None:
            scheduler.start()
class packetReceicer(threading.Thread):
    def __init__(self, sw_name, program):
        threading.Thread.__init__(self)
        if program == "f":
            self.topo = Topology(
                db="../p4src_flowsize/topology.db")  #set the topology
        elif program == "i":
            self.topo = Topology(
                db="../p4src_interval/topology.db")  #set the topology
        self.sw_name = sw_name
        self.thrift_port = self.topo.get_thrift_port(sw_name)
        self.cpu_port = self.topo.get_cpu_port_index(self.sw_name)
        self.controller = SimpleSwitchAPI(self.thrift_port)
        self.flow = {}
        self.flag = True
        self.init()

    def init(self):
        self.add_mirror()
        self.counter = 1
        self.logs = open("../switch_log/" + self.sw_name + ".log", "w")
        self.logs_info = open("../switch_log/" + self.sw_name + "_info.log",
                              "w")
        self.logs_info.write("SWITCH[" + self.sw_name + "]\n")
        self.logs.close()
        self.logs_info.close()

    def add_mirror(self):
        if self.cpu_port:
            self.controller.mirroring_add(
                100, self.cpu_port)  # correspond to the 100 in p4 code
            #is there any probability to increase the mirro port to add cpu port?

    def recv_msg_cpu(self, pkt):
        """Sniffer callback: record one packet received on the cpu port.

        The first packet rewrites the per-packet log with a header naming
        the kind of information collected. Every packet is appended to the
        per-packet log and added to the flow statistics; the statistics
        are dumped whenever the packet counter reaches a multiple of 1000.
        """
        self.counter += 1
        cpu = CPU(str(pkt))

        # bit 2 of the flags selects the record kind; masked the same way
        # as in gen_per_packet_log so both agree on the kind
        record_type = (cpu.flags >> 2) & 0x1
        if self.flag:
            with open("../switch_log/" + self.sw_name + ".log", "w") as logs:
                self.flag = False
                if record_type == 0:
                    logs.write("flowsize information collecting\n")
                else:
                    logs.write("interval information collecting\n")

        self.gen_per_packet_log(cpu)
        self.collect_log(cpu)
        if self.counter % 1000 == 0:
            self.gen_log()

    def gen_log(self):
        logs_info = open("../switch_log/" + self.sw_name + "_info.log", "a")
        logs_info.write("[flow number: " + str(len(self.flow)) + "]\n")
        change = lambda x: '.'.join(
            [str(x / (256**i) % 256) for i in range(3, -1, -1)])

        cnt = 0
        for i in self.flow:
            cnt += self.flow[i]["packnum"]
            tmp = i.split(":")
            tmp[0] = change(int(tmp[0]))
            tmp[1] = change(int(tmp[1]))
            tmp = " : ".join(tmp)
            logs_info.write("flow " + tmp + " ")

            logs_info.write(str(sorted(self.flow[i].items())))
            logs_info.write("\n")
        logs_info.write("[packet number sum:" + str(cnt) + "]\n\n")

        logs_info.close()

    def collect_log(self, cpu):
        flow_key = str(cpu.srcAddr) + ":" + str(cpu.dstAddr) + ":" + str(
            cpu.protocol) + ":" + str(cpu.srcPort) + ":" + str(cpu.dstPort)
        if self.flow.has_key(flow_key):
            self.flow[flow_key]["packnum"] += 1
            self.flow[flow_key][self.get_lev(cpu.delay)] += 1
        else:
            self.flow[flow_key]={"packnum":1,"0->1":0,"1->2":0,\
                "2->3":0,"3->4":0,"4->5":0,"5->6":0,"6->7":0\
                ,"7+":0}#"7->8":0,"8->9":0,"9+":0}
            self.flow[flow_key][self.get_lev(cpu.delay)] += 1

    def get_lev(self, delay):
        """Return the label of the delay level that delay falls into.

        Levels are 1000 wide: "0->1" for delays below 1000, "1->2" below
        2000 and so on up to "6->7"; anything from 7000 upwards is "7+".
        """
        time_interval = 1000
        bounded_levels = ("0->1", "1->2", "2->3", "3->4", "4->5", "5->6",
                          "6->7")
        # the first level whose upper bound exceeds the delay wins
        for multiple, label in enumerate(bounded_levels, 1):
            if delay < time_interval * multiple:
                return label
        return "7+"

    def gen_per_packet_log(self, cpu):
        logs = open("../switch_log/" + self.sw_name + ".log", "a")
        change = lambda x: '.'.join(
            [str(x / (256**i) % 256) for i in range(3, -1, -1)])

        srcAddr = change(cpu.srcAddr)
        dstAddr = change(cpu.dstAddr)
        tmp_delay = str(cpu.delay)
        delay = tmp_delay[-9:-6] + "s " + tmp_delay[-6:-3] + "ms " + tmp_delay[
            -3:] + "us"
        tmp_interval = str(cpu.interval)
        interval = tmp_interval[-9:-6] + "s " + tmp_interval[
            -6:-3] + "ms " + tmp_interval[-3:] + "us"
        sketch_fg = (cpu.flags >> 1) & 0x1
        has_SFH = cpu.flags & 0x1
        type = (cpu.flags >> 2) & 0x1

        logs.write('{"switch name":"' + self.sw_name + '",')
        logs.write('"packet number":"' + str(self.counter - 1) +
                   '","packet_info":{')
        logs.write('"srcAddr":"' + str(srcAddr) + '",')
        logs.write('"dstAddr":"' + str(dstAddr) + '",')
        logs.write('"protocol":"' + str(cpu.protocol) + '",')
        logs.write('"srcPort":"' + str(cpu.srcPort) + '",')
        logs.write('"dstPort":"' + str(cpu.dstPort) + '",')
        logs.write('"delay ":"' + delay + '",')
        logs.write('"interval":"' + interval)
        logs.write('"timestamp":' + str(time.time()))
        if type == 0:
            logs.write('",' + '"using sketch":"' + str(sketch_fg) + '",')
            logs.write('"bring SFH":' + str(bool(has_SFH)))
        else:
            logs.write('",' + '"using sketch":"' + str(sketch_fg) + '",')
            logs.write('"bring MIH":' + str(bool(has_SFH)))
        logs.write(" }}\n")
        logs.close()

    def run_cpu_port_loop(self):
        """Sniff the switch's cpu-port interface, passing every packet to
        recv_msg_cpu, and print whatever sniff returns once it ends."""
        reported_intf = self.topo.get_cpu_port_intf(self.sw_name)
        # the interface name has "eth0" swapped for "eth1" before sniffing
        cpu_port_intf = str(reported_intf.replace("eth0", "eth1"))
        #the cpu has two ports   could use two thread to sniff
        print(cpu_port_intf)
        print
        captured = sniff(iface=cpu_port_intf, prn=self.recv_msg_cpu)
        print(captured)

    def run(self):
        """Thread entry point: run the cpu-port sniffing loop."""
        self.run_cpu_port_loop()
class L2Controller(object):
    def __init__(self, sw_name):
        """Bind the controller to switch ``sw_name`` and initialise it.

        Loads the topology from "topology.db", looks up the switch's thrift
        port and CPU port index, creates a ``SimpleSwitchAPI`` for the thrift
        port and finally calls ``self.init()``.

        Args:
            sw_name: name of the switch as known to the topology.
        """
        topology = Topology(db="topology.db")
        self.topo = topology
        self.sw_name = sw_name

        # Per-switch settings resolved from the topology database
        self.thrift_port = topology.get_thrift_port(sw_name)
        self.cpu_port = topology.get_cpu_port_index(self.sw_name)

        self.controller = SimpleSwitchAPI(self.thrift_port)

        self.init()

    def init(self):
        self.controller.reset_state()
        self.add_boadcast_groups()
        self.add_mirror()

    def add_mirror(self):
        if self.cpu_port:
            self.controller.mirroring_add(100, self.cpu_port)

    def add_boadcast_groups(self):
        interfaces_to_port = self.topo[
            self.sw_name]["interfaces_to_port"].copy()
        # filter lo and cpu port
        interfaces_to_port.pop('lo', None)
        interfaces_to_port.pop(self.topo.get_cpu_port_intf(self.sw_name), None)

        mc_grp_id = 1
        rid = 0
        for ingress_port in interfaces_to_port.values():

            port_list = interfaces_to_port.values()[:]
            del (port_list[port_list.index(ingress_port)])

            #add multicast group
            self.controller.mc_mgrp_create(mc_grp_id)

            #add multicast node group
            handle = self.controller.mc_node_create(rid, port_list)

            #associate with mc grp
            self.controller.mc_node_associate(mc_grp_id, handle)

            #fill broadcast table
            self.controller.table_add("broadcast", "set_mcast_grp",
                                      [str(ingress_port)], [str(mc_grp_id)])

            mc_grp_id += 1
            rid += 1

    def learn_route(self, learning_data):
        for mac_addr, ingress_port in learning_data:
            print "mac: %012X ingress_port: %s " % (mac_addr, ingress_port)
            self.controller.table_add("smac", "NoAction", [str(mac_addr)])
            self.controller.table_add("dmac", "forward", [str(mac_addr)],
                                      [str(ingress_port)])

    def learn_connection(self, srcA, dstA, srcP, dstP):
        print("========== UPDATING CONNECTION ==========")
        connection = srcA
        connection = connection << 32
        connection = connection | dstA
        connection = connection << 16
        connection = connection | srcP
        connection = connection << 16
        connection = connection | dstP
        self.controller.table_add("tcp_forward", "NoAction", [str(connection)],
                                  [])

        connection = dstA
        connection = connection << 32
        connection = connection | srcA
        connection = connection << 16
        connection = connection | dstP
        connection = connection << 16
        connection = connection | srcP
        self.controller.table_add("tcp_forward", "NoAction", [str(connection)],
                                  [])

        print("========== UPDATE FINISHED ==========")

    def recv_msg_cpu(self, pkt):
        """Dispatch a packet received on the CPU port by its Ethernet type.

        Type 0x1234 payloads are parsed as ``CpuRoute`` and passed to
        ``learn_route``; type 0xF00D payloads are parsed as ``CpuCookie`` and
        passed to ``learn_connection``.  Any other type is ignored.

        Args:
            pkt: the received packet; it is re-parsed as an ``Ether`` frame
                from its ``str()`` form.
        """
        frame = Ether(str(pkt))

        # route-learning message
        if frame.type == 0x1234:
            route = CpuRoute(frame.payload)
            print("got a packet of type route")
            self.learn_route([(route.macAddr, route.ingress_port)])

        # connection ("cookie") message
        if frame.type == 0xF00D:
            cookie = CpuCookie(frame.payload)
            print("got a packet of type cookie")
            self.learn_connection(cookie.srcAddr, cookie.dstAddr,
                                  cookie.srcPort, cookie.dstPort)

    def run_cpu_port_loop(self):
        """Sniff the CPU-port interface, passing ``self.recv_msg_cpu`` as ``prn``.

        The interface name is taken from the topology for ``self.sw_name``
        with every "eth0" replaced by "eth1".
        """
        topo_intf = self.topo.get_cpu_port_intf(self.sw_name)
        sniff_intf = str(topo_intf.replace("eth0", "eth1"))
        sniff(iface=sniff_intf, prn=self.recv_msg_cpu)