# Example #1
# 0
class RoutingController(object):
    """Populates the forwarding tables of every P4 switch in the topology.

    The topology is loaded from ``topology.db``. One ``SimpleSwitchAPI``
    connection per switch is kept in ``self.controllers``, keyed by switch name.
    """

    def __init__(self):
        self.topo = Topology(db="topology.db")
        self.controllers = {}
        self.init()

    def init(self):
        """Connect to all switches, reset their state and set table defaults."""
        self.connect_to_switches()
        self.reset_states()
        self.set_table_defaults()

    def reset_states(self):
        """Call ``reset_state()`` on every connected switch."""
        # Plain loop: this is done purely for its side effect, so there is
        # no reason to build (and throw away) a list.
        for controller in self.controllers.values():
            controller.reset_state()

    def connect_to_switches(self):
        """Open a thrift connection to every P4 switch of the topology."""
        for p4switch in self.topo.get_p4switches():
            thrift_port = self.topo.get_thrift_port(p4switch)
            self.controllers[p4switch] = SimpleSwitchAPI(thrift_port)

    @staticmethod
    def _name_to_int(name):
        """Return the big-endian integer value of the bytes of *name*.

        Same result as Python 2's ``int(name.encode("hex"), 16)`` for a
        non-empty name, but also works on Python 3, where ``str.encode``
        has no "hex" codec. Text names are encoded as UTF-8; an empty
        name yields 0.
        """
        raw = name if isinstance(name, bytes) else name.encode("utf-8")
        value = 0
        for byte in bytearray(raw):
            value = (value << 8) | byte
        return value

    def set_table_defaults(self):
        """Install default actions on every switch.

        ``ipv4_lpm`` and ``ecmp_group_to_nhop`` default to ``drop``. The
        ``sw_name`` table defaults to ``set_swname`` with the switch's name
        encoded as an integer.
        """
        for sw_name, controller in self.controllers.items():
            controller.table_set_default("ipv4_lpm", "drop", [])
            controller.table_set_default("ecmp_group_to_nhop", "drop", [])
            controller.table_set_default(
                "sw_name", "set_swname", [str(self._name_to_int(sw_name))])

    def set_egress_type_table(self):
        """Fill ``dst_type_table`` on every switch.

        Each of the switch's ports gets ``set_dst_type``: 1 if the neighbour
        on that port is a host, 2 otherwise.
        """
        for node, controller in self.controllers.items():
            for neighbor in self.topo.get_neighbors(node):
                dst_type = 1 if self.topo.is_host(neighbor) else 2
                port_num = self.topo.node_to_node_port_num(node, neighbor)
                controller.table_add("dst_type_table", "set_dst_type",
                                     [str(port_num)], [str(dst_type)])

    def add_mirroring_ids(self):
        """Call ``mirroring_add(100, 1)`` on every switch."""
        # NOTE(review): presumably mirror session 100 -> port 1; confirm
        # against the switches' port layout.
        for controller in self.controllers.values():
            controller.mirroring_add(100, 1)

    def _table_add(self, sw_name, table, action, match_keys, action_params):
        """Log and add one entry to *table* on switch *sw_name*.

        Every match key and action parameter is converted with ``str``.
        """
        print("table_add at {}:".format(sw_name))
        self.controllers[sw_name].table_add(
            table, action, [str(key) for key in match_keys],
            [str(param) for param in action_params])

    def _add_local_host_routes(self, sw_name):
        """Add a /32 ``set_nhop`` entry for every host attached to *sw_name*."""
        for host in self.topo.get_hosts_connected_to(sw_name):
            sw_port = self.topo.node_to_node_port_num(sw_name, host)
            host_ip = self.topo.get_host_ip(host) + "/32"
            host_mac = self.topo.get_host_mac(host)
            self._table_add(sw_name, "ipv4_lpm", "set_nhop", [host_ip],
                            [host_mac, sw_port])

    def _add_remote_host_routes(self, sw_name, sw_dst, ecmp_groups):
        """Add /24 routes on *sw_name* to every host attached to *sw_dst*.

        With exactly one shortest path, a plain ``set_nhop`` entry is added.
        With several, the hosts are routed through an ECMP group. The group
        is created on first use and reused afterwards.

        Args:
            sw_name: switch being configured.
            sw_dst: remote switch the destination hosts are attached to.
            ecmp_groups: dict mapping a tuple of ``(mac, port)`` next hops to
                the group id already allocated on *sw_name*. Updated in place.
        """
        hosts = self.topo.get_hosts_connected_to(sw_dst)
        if not hosts:
            return
        paths = self.topo.get_shortest_paths_between_nodes(sw_name, sw_dst)
        if not paths:
            return

        for host in hosts:
            if len(paths) == 1:
                next_hop = paths[0][1]  # paths[0][0] is sw_name itself
                host_ip = self.topo.get_host_ip(host) + "/24"
                sw_port = self.topo.node_to_node_port_num(sw_name, next_hop)
                dst_sw_mac = self.topo.node_to_node_mac(next_hop, sw_name)
                self._table_add(sw_name, "ipv4_lpm", "set_nhop", [host_ip],
                                [dst_sw_mac, sw_port])
                continue

            next_hops = [path[1] for path in paths]
            dst_macs_ports = tuple(
                (self.topo.node_to_node_mac(next_hop, sw_name),
                 self.topo.node_to_node_port_num(sw_name, next_hop))
                for next_hop in next_hops)
            host_ip = self.topo.get_host_ip(host) + "/24"

            # A group is identified by its ordered set of next hops, so
            # destinations sharing the same next hops share one group.
            group_id = ecmp_groups.get(dst_macs_ports)
            if group_id is None:
                group_id = len(ecmp_groups) + 1  # ids start at 1
                ecmp_groups[dst_macs_ports] = group_id
                for i, (mac, port) in enumerate(dst_macs_ports):
                    self._table_add(sw_name, "ecmp_group_to_nhop", "set_nhop",
                                    [group_id, i], [mac, port])

            self._table_add(sw_name, "ipv4_lpm", "ecmp_group", [host_ip],
                            [group_id, len(dst_macs_ports)])

    def route(self):
        """Install IPv4 routes on every switch towards every host.

        Directly attached hosts get /32 entries. Hosts behind other switches
        get /24 entries, through an ECMP group when several shortest paths
        exist.
        """
        switch_ecmp_groups = {
            sw_name: {} for sw_name in self.topo.get_p4switches()
        }

        for sw_name in self.controllers:
            for sw_dst in self.topo.get_p4switches():
                if sw_name == sw_dst:
                    self._add_local_host_routes(sw_name)
                else:
                    self._add_remote_host_routes(
                        sw_name, sw_dst, switch_ecmp_groups[sw_name])

    def main(self):
        """Fill the egress-type table, add mirroring ids and install routes."""
        self.set_egress_type_table()
        self.add_mirroring_ids()
        self.route()
# Example #2
# 0
class RerouteController(object):
    """Controller for the fast rerouting exercise."""

    def __init__(self):
        """Initialize the topology, switch connections and forwarding state.

        Raises:
            Exception: if ``topology.db`` is not found in the working
                directory.
        """
        if not os.path.exists("topology.db"):
            print("Could not find topology object!\n")
            # Carry a message instead of a bare ``raise Exception``; the
            # exception type is unchanged for callers.
            raise Exception("Could not find topology object (topology.db).")

        self.topo = Topology(db="topology.db")
        self.controllers = {}
        self.connect_to_switches()
        self.reset_states()

        # Preconfigure all MAC addresses
        self.install_macs()

        # Install nexthop indices and populate registers.
        self.install_nexthop_indices()
        self.update_nexthops()

    def connect_to_switches(self):
        """Connects to all the switches in the topology."""
        for p4switch in self.topo.get_p4switches():
            thrift_port = self.topo.get_thrift_port(p4switch)
            self.controllers[p4switch] = SimpleSwitchAPI(thrift_port)

    def reset_states(self):
        """Resets registers, tables, etc."""
        for control in self.controllers.values():
            control.reset_state()

    def install_macs(self):
        """Install the port-to-mac map on all switches.

        You do not need to change this.

        Note: Real switches would rely on L2 learning to achieve this.
        """
        for switch, control in self.controllers.items():
            print("Installing MAC addresses for switch '%s'." % switch)
            print("=========================================\n")
            for neighbor in self.topo.get_neighbors(switch):
                mac = self.topo.node_to_node_mac(neighbor, switch)
                port = self.topo.node_to_node_port_num(switch, neighbor)
                control.table_add('rewrite_mac', 'rewriteMac', [str(port)],
                                  [str(mac)])

    def install_nexthop_indices(self):
        """Install the mapping from prefix to nexthop ids for all switches.

        ``ipv4_lpm`` is cleared first. It then gets one ``read_port`` entry per
        host, matching the host's net and carrying its nexthop index.
        """
        for switch, control in self.controllers.items():
            print("Installing nexthop indices for switch '%s'." % switch)
            print("===========================================\n")
            control.table_clear('ipv4_lpm')
            for host in self.topo.get_hosts():
                subnet = self.get_host_net(host)
                index = self.get_nexthop_index(host)
                control.table_add('ipv4_lpm', 'read_port', [subnet],
                                  [str(index)])

    def get_host_net(self, host):
        """Return ip and subnet of a host.

        Args:
            host (str): The host for which the net will be returned.

        Returns:
            str: IP and subnet in the format "address/mask".
        """
        gateway = self.topo.get_host_gateway_name(host)
        return self.topo[host][gateway]['ip']

    def get_nexthop_index(self, host):
        """Return the nexthop index for a destination.

        The index is the position of *host* among all host names sorted
        alphabetically.

        Args:
            host (str): Name of destination node (host).

        Returns:
            int: nexthop index, used to look up nexthop ports.

        Raises:
            ValueError: if *host* is not a host of the topology.
        """
        # For now, give each host an individual nexthop id.
        host_list = sorted(self.topo.get_hosts().keys())
        return host_list.index(host)

    def get_port(self, node, nexthop_node):
        """Return egress port for nexthop from the view of node.

        Args:
            node (str): Name of node for which the port is determined.
            nexthop_node (str): Name of node to reach.

        Returns:
            int: nexthop port
        """
        return self.topo.node_to_node_port_num(node, nexthop_node)

    def failure_notification(self, failures):
        """Called if a link fails.

        Args:
            failures (list(tuple(str, str))): List of failed links.
        """
        self.update_nexthops(failures=failures)

    # Helpers to update nexthops.
    # ===========================

    def dijkstra(self, failures=None):
        """Compute shortest paths and distances.

        Args:
            failures (list(tuple(str, str))): List of failed links. A link
                that is not (or no longer) in the graph is reported and
                skipped.

        Returns:
            tuple(dict, dict): First dict: distances, second: paths.
        """
        graph = self.topo.network_graph

        if failures is not None:
            # Work on a copy so the topology's own graph is left intact.
            graph = graph.copy()
            for failure in failures:
                # remove_edge() raises on a missing edge, e.g. when the
                # same link is reported more than once.
                if graph.has_edge(*failure):
                    graph.remove_edge(*failure)
                else:
                    print("WARNING: failed link %s is not in the graph."
                          % (failure,))

        # Compute the shortest paths from switches to hosts.
        dijkstra = dict(all_pairs_dijkstra(graph, weight='weight'))

        distances = {node: data[0] for node, data in dijkstra.items()}
        paths = {node: data[1] for node, data in dijkstra.items()}

        return distances, paths

    def compute_nexthops(self, failures=None):
        """Compute the best nexthops for all switches to each host.

        Optionally, links can be marked as failed. These links are excluded
        when computing the shortest paths.

        Args:
            failures (list(tuple(str, str))): List of failed links.

        Returns:
            dict(str, list(tuple(str, str))):
                Mapping from each switch to a list of (host, nexthop node)
                pairs. Hosts the switch cannot reach are omitted.
        """
        # Compute the shortest paths from switches to hosts.
        all_shortest_paths = self.dijkstra(failures=failures)[1]

        # Translate shortest paths to mapping from host to nexthop node
        # (per switch).
        results = {}
        for switch in self.controllers:
            switch_results = results[switch] = []
            for host in self.topo.network_graph.get_hosts():
                try:
                    path = all_shortest_paths[switch][host]
                except KeyError:
                    print("WARNING: The graph is not connected!")
                    print("'%s' cannot reach '%s'." % (switch, host))
                    continue
                nexthop = path[1]  # path[0] is the switch itself.
                switch_results.append((host, nexthop))

        return results

    # Update nexthops.
    # ================

    def update_nexthops(self, failures=None):
        """Install nexthops in all switches.

        For every reachable host, the egress port towards its nexthop is
        written into the ``primaryNH`` register at the host's nexthop index.
        """
        nexthops = self.compute_nexthops(failures=failures)

        for switch, destinations in nexthops.items():
            print("Updating nexthops for switch '%s'." % switch)
            control = self.controllers[switch]
            for host, nexthop in destinations:
                nexthop_id = self.get_nexthop_index(host)
                port = self.get_port(switch, nexthop)
                # Write the port in the nexthop lookup register.
                control.register_write('primaryNH', nexthop_id, port)

        #######################################################################
        # Compute loop-free alternate nexthops and install them below.
        #######################################################################

        pass
# Example #3
# 0
class PacketLossController(object):
    """Detects packets lost on links between P4 switches.

    For a link ``sw1 -> sw2``, the "um" registers of ``sw1`` (for the port
    facing ``sw2``) are compared with the "dm" registers of ``sw2`` (for the
    port facing ``sw1``). The differing cells are decoded back into flows,
    which are reported as dropped (see ``decode_meter_pair``).
    """

    def __init__(self, num_hashes=3):

        self.topo = Topology(db="topology.db")
        self.controllers = {}
        self.num_hashes = num_hashes
        # Register snapshot: {switch: {register_name: values}}.
        # Created *before* init(). Previously it was reset to {} right after
        # init(), discarding the snapshot that read_registers() had just taken.
        self.registers = {}

        # gets a controller API for each switch: {"s1": controller, "s2": controller...}
        self.connect_to_switches()
        # creates the local copies of the hashes that the p4 switch uses
        self.create_local_hashes()

        # initializes the switch:
        # resets all registers, configures the custom CRC hash functions,
        # reads the registers, populates the tables and the mirroring id
        self.init()

    def init(self):
        """Reset registers, configure hashes, snapshot registers, set tables."""
        self.reset_all_registers()
        self.set_crc_custom_hashes_all()
        self.read_registers()
        self.configure_switches()

    def connect_to_switches(self):
        """Open a thrift connection to every P4 switch of the topology."""
        for p4switch in self.topo.get_p4switches():
            thrift_port = self.topo.get_thrift_port(p4switch)
            self.controllers[p4switch] = SimpleSwitchAPI(thrift_port)

    def configure_switches(self):
        """Install the static configuration on every switch.

        This calls ``mirroring_add(100, 3)``, adds ``forwarding`` entries
        1 -> 2 and 2 -> 1, and adds a ``remove_loss_header`` entry for every
        port that faces a host.
        """
        for sw, controller in self.controllers.items():
            # adds cpu port
            controller.mirroring_add(100, 3)

            # set the basic forwarding rules
            controller.table_add("forwarding", "set_egress_port", ["1"], ["2"])
            controller.table_add("forwarding", "set_egress_port", ["2"], ["1"])

            # set the remove header rules when there is a host in a port
            for host in self.topo.get_hosts_connected_to(sw):
                port = self.topo.node_to_node_port_num(sw, host)
                controller.table_add("remove_loss_header", "remove_header",
                                     [str(port)], [])

    def set_crc_custom_hashes_all(self):
        """Configure the custom CRC32 hashes on every switch."""
        for sw_name in self.controllers:
            self.set_crc_custom_hashes(sw_name)

    def set_crc_custom_hashes(self, sw_name):
        """Set the CRC32 parameters of the custom hashes of *sw_name*.

        The switch's custom CRC calculations are sorted by name. The first
        ``num_hashes`` are the "um" hashes; the rest are the "dm" hashes. Both
        groups get polynomials ``crc32_polinomials[0], [1], ...`` in order, so
        the i-th um hash and the i-th dm hash are identical.
        """
        controller = self.controllers[sw_name]
        custom_calcs = sorted(controller.get_custom_crc_calcs().items())
        um_calcs = custom_calcs[:self.num_hashes]
        dm_calcs = custom_calcs[self.num_hashes:]
        for calcs in (um_calcs, dm_calcs):
            for i, (custom_crc32, _width) in enumerate(calcs):
                controller.set_crc32_parameters(
                    custom_crc32, crc32_polinomials[i], 0xffffffff,
                    0xffffffff, True, True)

    def create_local_hashes(self):
        """Build ``num_hashes`` local CRC32 objects matching the switches."""
        self.hashes = [
            Crc(32, crc32_polinomials[i], True, 0xffffffff, True, 0xffffffff)
            for i in range(self.num_hashes)
        ]

    def reset_all_registers(self):
        """Reset every register array on every switch."""
        for controller in self.controllers.values():
            for register in controller.get_register_arrays():
                controller.register_reset(register)

    @staticmethod
    def _register_range(port, batch_id):
        """Return ``(start, end)`` of the register region for *port*/*batch_id*.

        Each batch occupies REGISTER_BATCH_SIZE cells. Within a batch, each
        port (numbered from 1) occupies REGISTER_PORT_SIZE cells.
        """
        start = (batch_id * REGISTER_BATCH_SIZE) + ((port - 1) * REGISTER_PORT_SIZE)
        return start, start + REGISTER_PORT_SIZE

    def reset_registers(self, sw, stream, port, batch_id):
        """Zero the *port*/*batch_id* region of *sw*'s *stream* registers.

        A register belongs to *stream* ("um" or "dm") when *stream* is a
        substring of its name.
        """
        start, end = self._register_range(port, batch_id)
        controller = self.controllers[sw]
        for register in controller.get_register_arrays():
            if stream in register:
                controller.register_write(register, [start, end], 0)

    def flow_to_bytestream(self, flow):
        """Serialize *flow* into the byte string fed to the hash functions.

        Args:
            flow: ``(src_ip, dst_ip, src_port, dst_port, protocol, ip_id)``,
                with dotted-quad IP strings and integer fields.

        Returns:
            The two IPs (4 bytes each) followed by ``struct.pack(">HHBH",
            src_port, dst_port, protocol, ip_id)`` (15 bytes in total).
        """
        return socket.inet_aton(flow[0]) + socket.inet_aton(flow[1]) + struct.pack(">HHBH", flow[2], flow[3], flow[4], flow[5])

    def read_registers(self):
        """Snapshot every register of every switch into ``self.registers``."""
        self.registers = {
            sw: {register: controller.register_read(register)
                 for register in controller.get_register_arrays()}
            for sw, controller in self.controllers.items()
        }

    def extract_register_information(self, sw, stream, port, batch_id):
        """Return ``{register_name: values}`` for *port*/*batch_id*.

        Covers the registers of *sw* whose name contains *stream*. The values
        are sliced from the last ``read_registers`` snapshot.
        """
        start, end = self._register_range(port, batch_id)
        return {name: values[start:end]
                for name, values in self.registers[sw].items()
                if stream in name}

    def decode_meter_pair(self, um_registers, dm_registers):
        """Decode the flows lost between an upstream and a downstream meter.

        Counters are subtracted (um - dm) and the flow fields are XOR-ed
        cell by cell. Any cell whose counter difference is 1 is decoded as a
        single flow. That flow is then removed from every cell its hashes map
        to, and the process repeats.

        Returns:
            set of ``(src_ip, dst_ip, src_port, dst_port, protocol, ip_id)``
            tuples. If a counter goes negative, decoding stops and only the
            flows found so far are returned.
        """
        counters = [x - y for x, y in zip(um_registers['MyIngress.um_counter'], dm_registers['MyIngress.dm_counter'])]
        ip_src = [x ^ y for x, y in zip(um_registers['MyIngress.um_ip_src'], dm_registers['MyIngress.dm_ip_src'])]
        ip_dst = [x ^ y for x, y in zip(um_registers['MyIngress.um_ip_dst'], dm_registers['MyIngress.dm_ip_dst'])]
        ports_proto_id = [x ^ y for x, y in zip(um_registers['MyIngress.um_ports_proto_id'], dm_registers['MyIngress.dm_ports_proto_id'])]

        dropped_packets = set()
        while 1 in counters:
            i = counters.index(1)
            src_int = ip_src[i]
            dst_int = ip_dst[i]
            misc = ports_proto_id[i]
            flow = (
                socket.inet_ntoa(struct.pack("!I", src_int)),
                socket.inet_ntoa(struct.pack("!I", dst_int)),
                (misc >> 40) & 0xffff,  # source port
                (misc >> 24) & 0xffff,  # destination port
                (misc >> 16) & 0xff,    # protocol
                misc & 0xffff,          # ip id
            )

            # Cells this flow was hashed into. The number of hashes follows
            # num_hashes instead of being fixed at 3.
            flow_stream = self.flow_to_bytestream(flow)
            indexes = [crc.bit_by_bit_fast(flow_stream) % REGISTER_PORT_SIZE
                       for crc in self.hashes]

            # clean this entry everywhere and continue
            for index in indexes:
                counters[index] -= 1
                ip_src[index] ^= src_int
                ip_dst[index] ^= dst_int
                ports_proto_id[index] ^= misc

            # if there is a bad sync we skip this round
            # mainly the problem is the amount of buffer the switch allows
            if any(count < 0 for count in counters):
                return dropped_packets

            dropped_packets.add(flow)

        return dropped_packets

    def verify_link(self, sw1, sw2, batch_id):
        """Decode, clear and report the losses on link *sw1* -> *sw2*."""
        sw1_to_sw2_interface = self.topo.node_to_node_port_num(sw1, sw2)
        sw2_to_sw1_interface = self.topo.node_to_node_port_num(sw2, sw1)

        sw1_um = self.extract_register_information(sw1, 'um', sw1_to_sw2_interface, batch_id)
        sw2_dm = self.extract_register_information(sw2, 'dm', sw2_to_sw1_interface, batch_id)

        dropped_packets = self.decode_meter_pair(sw1_um, sw2_dm)

        # clean registers
        self.reset_registers(sw1, 'um', sw1_to_sw2_interface, batch_id)
        self.reset_registers(sw2, 'dm', sw2_to_sw1_interface, batch_id)

        # report
        if dropped_packets:
            print("Packets dropped: {} at link {}->{}:".format(len(dropped_packets), sw1, sw2))
            print("Details:")
            for packet in dropped_packets:
                print(packet)

    def check_sw_links(self, sw, batch_id):
        """Re-read the registers and verify every link from *sw* to a P4 neighbour."""
        # just in case for the delay
        # increase/decrease depending on the batch timing
        time.sleep(0.25)

        # read all registers since it's a small topo
        self.read_registers()

        # Process the right links and clean registers
        p4switches = self.topo.get_p4switches()
        neighboring_p4_switches = [x for x in self.topo.get_neighbors(sw)
                                   if x in p4switches]

        for neighboring_switch in neighboring_p4_switches:
            self.verify_link(sw, neighboring_switch, batch_id)

    # When a batch_id changes the controller gets triggered
    def recv_msg_cpu(self, pkt):
        """Handle a packet sniffed on a CPU interface.

        The switch name is taken from the interface name (the text before
        the first "-"). Packets with Ethernet type 0x1234 are parsed as
        ``LossHeader``, and the links of that switch are checked for the
        decoded batch id.
        """
        interface = pkt.sniffed_on
        print(interface)
        switch_name = interface.split("-")[0]
        # bytes(pkt) is the raw frame on both Python 2 (bytes is str) and
        # Python 3 (where str(pkt) would not be the raw bytes).
        packet = Ether(bytes(pkt))
        if packet.type == 0x1234:
            loss_header = LossHeader(packet.payload)
            # NOTE(review): presumably the batch id sits in the upper bits of
            # this field — confirm against the LossHeader definition.
            batch_id = loss_header.batch_id >> 7
            print("{} {}".format(switch_name, batch_id))
            self.check_sw_links(switch_name, batch_id)

    def run_cpu_port_loop(self):
        """Sniff the CPU interfaces of all switches, dispatching to recv_msg_cpu."""
        cpu_interfaces = [str(self.topo.get_cpu_port_intf(sw_name).replace("eth0", "eth1")) for sw_name in self.controllers]
        sniff(iface=cpu_interfaces, prn=self.recv_msg_cpu)