class RoutingController(object):
    """Populates the IPv4 forwarding and ECMP tables of every P4 switch.

    On construction it loads ``topology.db``, opens a thrift connection to
    every P4 switch, resets their state and installs table defaults.
    ``main`` then fills the egress-type table, adds a mirroring session and
    installs the routing rules.
    """

    def __init__(self):
        self.topo = Topology(db="topology.db")
        # switch name -> SimpleSwitchAPI connection
        self.controllers = {}
        self.init()

    def init(self):
        """Connect to all switches, reset them and set the table defaults."""
        self.connect_to_switches()
        self.reset_states()
        self.set_table_defaults()

    def reset_states(self):
        """Call ``reset_state()`` on every connected switch."""
        # Plain loop: this is done only for its side effects.
        for controller in self.controllers.values():
            controller.reset_state()

    def connect_to_switches(self):
        """Open a SimpleSwitchAPI connection to each P4 switch's thrift port."""
        for p4switch in self.topo.get_p4switches():
            thrift_port = self.topo.get_thrift_port(p4switch)
            self.controllers[p4switch] = SimpleSwitchAPI(thrift_port)

    @staticmethod
    def _sw_name_id(sw_name):
        """Return the switch name's bytes, read as a big-endian int, in decimal.

        Example: "s1" -> hex "7331" -> "29489". This matches the former
        ``str(int(sw_name.encode("hex"), 16))`` for ASCII names, but does not
        depend on the Python-2-only "hex" codec.
        """
        import binascii
        return str(int(binascii.hexlify(sw_name.encode()), 16))

    def set_table_defaults(self):
        """Install default actions: drop on routing tables, set the switch name."""
        for sw_name, controller in self.controllers.items():
            controller.table_set_default("ipv4_lpm", "drop", [])
            controller.table_set_default("ecmp_group_to_nhop", "drop", [])
            controller.table_set_default(
                "sw_name", "set_swname", [self._sw_name_id(sw_name)])

    def set_egress_type_table(self):
        """Map each switch port to the type of its neighbour (1 host, 2 other)."""
        for node, controller in self.controllers.items():
            for neighbor in self.topo.get_neighbors(node):
                dst_type = 1 if self.topo.is_host(neighbor) else 2
                port_num = self.topo.node_to_node_port_num(node, neighbor)
                controller.table_add("dst_type_table", "set_dst_type",
                                     [str(port_num)], [str(dst_type)])

    def add_mirroring_ids(self):
        """Call ``mirroring_add(100, 1)`` on every switch.

        NOTE(review): presumably mirroring session 100 towards port 1 — confirm.
        """
        for controller in self.controllers.values():
            controller.mirroring_add(100, 1)

    def route(self):
        """Install IPv4 routes (and ECMP groups) on every switch.

        Hosts attached to the switch itself get a /32 ``set_nhop`` rule to
        their own MAC and port. Hosts attached to another switch get a /24
        rule. With a single shortest path, that rule is a plain ``set_nhop``
        towards the next hop. With several, it is an ``ecmp_group`` rule; ECMP
        groups are shared per switch between destinations with the same list
        of (MAC, port) next hops.
        """
        switch_ecmp_groups = {
            sw_name: {} for sw_name in self.topo.get_p4switches().keys()
        }
        for sw_name in self.controllers:
            for sw_dst in self.topo.get_p4switches():
                if sw_name == sw_dst:
                    self._add_direct_host_routes(sw_name)
                else:
                    self._add_remote_host_routes(
                        sw_name, sw_dst, switch_ecmp_groups[sw_name])

    def _add_direct_host_routes(self, sw_name):
        """Install /32 routes on ``sw_name`` for its directly attached hosts."""
        controller = self.controllers[sw_name]
        for host in self.topo.get_hosts_connected_to(sw_name):
            sw_port = self.topo.node_to_node_port_num(sw_name, host)
            host_ip = self.topo.get_host_ip(host) + "/32"
            host_mac = self.topo.get_host_mac(host)
            print("table_add at {}:".format(sw_name))
            controller.table_add("ipv4_lpm", "set_nhop", [str(host_ip)],
                                 [str(host_mac), str(sw_port)])

    def _add_remote_host_routes(self, sw_name, sw_dst, ecmp_groups):
        """Install /24 routes on ``sw_name`` for the hosts attached to ``sw_dst``.

        Args:
            sw_name (str): Switch that receives the rules.
            sw_dst (str): Switch the destination hosts are attached to.
            ecmp_groups (dict): ``sw_name``'s ECMP groups, keyed by the tuple
                of (MAC, port) next hops; updated in place.
        """
        hosts = self.topo.get_hosts_connected_to(sw_dst)
        if not hosts:
            return
        paths = self.topo.get_shortest_paths_between_nodes(sw_name, sw_dst)
        if not paths:
            # No path: nothing is installed.
            return
        controller = self.controllers[sw_name]
        # Element 1 of each path is used as the next hop (element 0 is
        # presumably sw_name itself).
        next_hops = [path[1] for path in paths]
        for host in hosts:
            host_ip = self.topo.get_host_ip(host) + "/24"
            if len(next_hops) == 1:
                next_hop = next_hops[0]
                sw_port = self.topo.node_to_node_port_num(sw_name, next_hop)
                dst_sw_mac = self.topo.node_to_node_mac(next_hop, sw_name)
                print("table_add at {}:".format(sw_name))
                controller.table_add("ipv4_lpm", "set_nhop", [str(host_ip)],
                                     [str(dst_sw_mac), str(sw_port)])
            else:
                dst_macs_ports = [
                    (self.topo.node_to_node_mac(next_hop, sw_name),
                     self.topo.node_to_node_port_num(sw_name, next_hop))
                    for next_hop in next_hops
                ]
                ecmp_group_id = self._get_or_create_ecmp_group(
                    sw_name, dst_macs_ports, ecmp_groups)
                print("table_add at {}:".format(sw_name))
                controller.table_add(
                    "ipv4_lpm", "ecmp_group", [str(host_ip)],
                    [str(ecmp_group_id), str(len(dst_macs_ports))])

    def _get_or_create_ecmp_group(self, sw_name, dst_macs_ports, ecmp_groups):
        """Return the ECMP group id for ``dst_macs_ports``, creating it if needed.

        A group is identified by its list of (MAC, port) next hops. New groups
        get ids 1, 2, ... per switch. Their members are written to
        ``ecmp_group_to_nhop`` as (group id, member index) -> (MAC, port).
        """
        key = tuple(dst_macs_ports)
        ecmp_group_id = ecmp_groups.get(key)
        if ecmp_group_id is not None:
            return ecmp_group_id
        ecmp_group_id = len(ecmp_groups) + 1
        ecmp_groups[key] = ecmp_group_id
        controller = self.controllers[sw_name]
        for i, (mac, port) in enumerate(dst_macs_ports):
            print("table_add at {}:".format(sw_name))
            controller.table_add("ecmp_group_to_nhop", "set_nhop",
                                 [str(ecmp_group_id), str(i)],
                                 [str(mac), str(port)])
        return ecmp_group_id

    def main(self):
        """Install egress types, the mirroring session and all routes."""
        self.set_egress_type_table()
        self.add_mirroring_ids()
        self.route()
class RerouteController(object):
    """Controller for the fast rerouting exercise."""

    def __init__(self):
        """Initializes the topology and data structures.

        Raises:
            Exception: if ``topology.db`` is not in the working directory.
        """
        if not os.path.exists("topology.db"):
            print("Could not find topology object!\n")
            raise Exception("Could not find topology object (topology.db)")

        self.topo = Topology(db="topology.db")
        # switch name -> SimpleSwitchAPI connection
        self.controllers = {}
        self.connect_to_switches()
        self.reset_states()

        # Preconfigure all MAC addresses
        self.install_macs()

        # Install nexthop indices and populate registers.
        self.install_nexthop_indices()
        self.update_nexthops()

    def connect_to_switches(self):
        """Connects to all the switches in the topology."""
        for p4switch in self.topo.get_p4switches():
            thrift_port = self.topo.get_thrift_port(p4switch)
            self.controllers[p4switch] = SimpleSwitchAPI(thrift_port)

    def reset_states(self):
        """Resets registers, tables, etc."""
        for control in self.controllers.values():
            control.reset_state()

    def install_macs(self):
        """Install the port-to-mac map on all switches.

        You do not need to change this.

        Note: Real switches would rely on L2 learning to achieve this.
        """
        for switch, control in self.controllers.items():
            print("Installing MAC addresses for switch '%s'." % switch)
            print("=========================================\n")
            for neighbor in self.topo.get_neighbors(switch):
                mac = self.topo.node_to_node_mac(neighbor, switch)
                port = self.topo.node_to_node_port_num(switch, neighbor)
                control.table_add('rewrite_mac', 'rewriteMac',
                                  [str(port)], [str(mac)])

    def install_nexthop_indices(self):
        """Install the mapping from prefix to nexthop ids for all switches."""
        for switch, control in self.controllers.items():
            print("Installing nexthop indices for switch '%s'." % switch)
            print("===========================================\n")
            control.table_clear('ipv4_lpm')
            for host in self.topo.get_hosts():
                subnet = self.get_host_net(host)
                index = self.get_nexthop_index(host)
                control.table_add('ipv4_lpm', 'read_port',
                                  [subnet], [str(index)])

    def get_host_net(self, host):
        """Return ip and subnet of a host.

        Args:
            host (str): The host for which the net will be returned.

        Returns:
            str: IP and subnet in the format "address/mask".
        """
        gateway = self.topo.get_host_gateway_name(host)
        return self.topo[host][gateway]['ip']

    def get_nexthop_index(self, host):
        """Return the nexthop index for a destination.

        Args:
            host (str): Name of destination node (host).

        Returns:
            int: nexthop index, used to look up nexthop ports.
        """
        # For now, give each host an individual nexthop id: its position in
        # the sorted list of host names.
        host_list = sorted(list(self.topo.get_hosts().keys()))
        return host_list.index(host)

    def get_port(self, node, nexthop_node):
        """Return egress port for nexthop from the view of node.

        Args:
            node (str): Name of node for which the port is determined.
            nexthop_node (str): Name of node to reach.

        Returns:
            int: nexthop port
        """
        return self.topo.node_to_node_port_num(node, nexthop_node)

    def failure_notification(self, failures):
        """Called if a link fails.

        Args:
            failures (list(tuple(str, str))): List of failed links.
        """
        self.update_nexthops(failures=failures)

    # Helpers to update nexthops.
    # ===========================

    def dijkstra(self, failures=None):
        """Compute shortest paths and distances between all pairs of nodes.

        Args:
            failures (list(tuple(str, str))): List of failed links. They are
                removed from a copy of the graph; the topology is unchanged.

        Returns:
            tuple(dict, dict): First dict: distances, second: paths, both
                indexed as ``[source][target]``.
        """
        graph = self.topo.network_graph
        if failures is not None:
            graph = graph.copy()
            for failure in failures:
                graph.remove_edge(*failure)

        # All-pairs shortest paths; callers select the entries they need.
        dijkstra = dict(all_pairs_dijkstra(graph, weight='weight'))
        distances = {node: data[0] for node, data in dijkstra.items()}
        paths = {node: data[1] for node, data in dijkstra.items()}
        return distances, paths

    def compute_nexthops(self, failures=None):
        """Compute the best nexthops for all switches to each host.

        Optionally, links can be marked as failed. These links will be
        excluded when computing the shortest paths.

        Args:
            failures (list(tuple(str, str))): List of failed links.

        Returns:
            dict(str, list(tuple(str, str))): Mapping from each switch to a
                list of ``(host, nexthop node)`` pairs. Hosts the switch cannot
                reach are omitted, and a warning is printed.
        """
        # Compute the shortest paths from switches to hosts.
        all_shortest_paths = self.dijkstra(failures=failures)[1]

        # Translate shortest paths to mapping from host to nexthop node
        # (per switch).
        results = {}
        for switch in self.controllers:
            switch_results = results[switch] = []
            for host in self.topo.network_graph.get_hosts():
                try:
                    path = all_shortest_paths[switch][host]
                except KeyError:
                    print("WARNING: The graph is not connected!")
                    print("'%s' cannot reach '%s'." % (switch, host))
                    continue
                nexthop = path[1]  # path[0] is the switch itself.
                switch_results.append((host, nexthop))

        return results

    # Update nexthops.
    # ================

    def update_nexthops(self, failures=None):
        """Install nexthops in all switches.

        Writes the egress port towards each host's nexthop into the
        ``primaryNH`` register, at the host's nexthop index.
        """
        nexthops = self.compute_nexthops(failures=failures)

        for switch, destinations in nexthops.items():
            print("Updating nexthops for switch '%s'." % switch)
            control = self.controllers[switch]
            for host, nexthop in destinations:
                nexthop_id = self.get_nexthop_index(host)
                port = self.get_port(switch, nexthop)
                # Write the port in the nexthop lookup register.
                control.register_write('primaryNH', nexthop_id, port)

        #######################################################################
        # Compute loop-free alternate nexthops and install them below.
        #######################################################################

        pass
class PacketLossController(object):
    """Detects packets lost on links between neighbouring P4 switches.

    Each switch keeps "um" and "dm" register sets (counters plus XOR-ed IP
    source, IP destination and ports/proto/id fields) per port and batch.
    For a link sw1->sw2, sw1's "um" region for its port towards sw2 is
    combined with sw2's "dm" region for its port towards sw1. The flows that
    appear only on the "um" side are then peeled out and reported as dropped.
    """

    def __init__(self, num_hashes=3):
        self.topo = Topology(db="topology.db")
        self.controllers = {}
        self.num_hashes = num_hashes
        # Set before init(): read_registers() fills this dict, and assigning
        # it afterwards would discard that first snapshot.
        self.registers = {}

        # gets a controller API for each switch: {"s1": controller, "s2": controller...}
        self.connect_to_switches()
        # creates the local hashes used to decode the switch registers
        self.create_local_hashes()

        # initializes the switch:
        # resets all registers, configures the um and dm hash functions,
        # reads the registers, populates the tables and mirroring id
        self.init()

    def init(self):
        """Reset registers, configure hashes, snapshot registers, add rules."""
        self.reset_all_registers()
        self.set_crc_custom_hashes_all()
        self.read_registers()
        self.configure_switches()

    def connect_to_switches(self):
        """Open a SimpleSwitchAPI connection to each P4 switch's thrift port."""
        for p4switch in self.topo.get_p4switches():
            thrift_port = self.topo.get_thrift_port(p4switch)
            self.controllers[p4switch] = SimpleSwitchAPI(thrift_port)

    def configure_switches(self):
        """Add the mirroring session, forwarding and header-removal rules."""
        for sw, controller in self.controllers.items():
            # adds the cpu port (mirroring session 100 -> port 3)
            controller.mirroring_add(100, 3)

            # set the basic forwarding rules: port 1 <-> port 2
            controller.table_add("forwarding", "set_egress_port", ["1"], ["2"])
            controller.table_add("forwarding", "set_egress_port", ["2"], ["1"])

            # set the remove header rules when there is a host in a port
            for host in self.topo.get_hosts_connected_to(sw):
                port = self.topo.node_to_node_port_num(sw, host)
                controller.table_add("remove_loss_header", "remove_header",
                                     [str(port)], [])

    def set_crc_custom_hashes_all(self):
        """Configure the custom CRC32 hashes of every switch."""
        for sw_name in self.controllers:
            self.set_crc_custom_hashes(sw_name)

    def set_crc_custom_hashes(self, sw_name):
        """Configure the custom CRC32 calculations of one switch.

        Calculations are sorted by name. The first ``num_hashes`` are used for
        the um and the remaining ones for the dm. Both groups take their
        polynomials from index 0 of ``crc32_polinomials``, because the um and
        dm hashes have to be the same.
        """
        controller = self.controllers[sw_name]
        custom_calcs = sorted(controller.get_custom_crc_calcs().items())
        um_calcs = custom_calcs[:self.num_hashes]
        dm_calcs = custom_calcs[self.num_hashes:]
        for calcs in (um_calcs, dm_calcs):
            for i, (custom_crc32, _width) in enumerate(calcs):
                controller.set_crc32_parameters(
                    custom_crc32, crc32_polinomials[i],
                    0xffffffff, 0xffffffff, True, True)

    def create_local_hashes(self):
        """Build ``num_hashes`` local CRC32 objects matching the switch hashes."""
        self.hashes = [
            Crc(32, crc32_polinomials[i], True, 0xffffffff, True, 0xffffffff)
            for i in range(self.num_hashes)
        ]

    def reset_all_registers(self):
        """Reset every register array on every switch."""
        for controller in self.controllers.values():
            for register in controller.get_register_arrays():
                controller.register_reset(register)

    def _port_region(self, port, batch_id):
        """Return ``(start, end)`` of the register region for a port and batch."""
        start = (batch_id * REGISTER_BATCH_SIZE) + ((port - 1) * REGISTER_PORT_SIZE)
        return start, start + REGISTER_PORT_SIZE

    def reset_registers(self, sw, stream, port, batch_id):
        """Zero the region of every register whose name contains ``stream``.

        NOTE(review): the region is passed as ``[start, end]`` to
        ``register_write``, while ``extract_register_information`` reads the
        half-open slice ``[start:end]``. Confirm whether the API treats the
        range end as inclusive.
        """
        start, end = self._port_region(port, batch_id)
        controller = self.controllers[sw]
        for register in controller.get_register_arrays():
            if stream in register:
                controller.register_write(register, [start, end], 0)

    def flow_to_bytestream(self, flow):
        """Serialise a flow tuple into the byte string fed to the hashes."""
        # flow fields are: srcip, dstip, srcport, dstport, protocol, ip id
        return (socket.inet_aton(flow[0]) + socket.inet_aton(flow[1]) +
                struct.pack(">HHBH", flow[2], flow[3], flow[4], flow[5]))

    def read_registers(self):
        """Snapshot every register array of every switch into ``self.registers``."""
        self.registers = {sw: {} for sw in self.controllers.keys()}
        for sw, controller in self.controllers.items():
            for register in controller.get_register_arrays():
                self.registers[sw][register] = controller.register_read(register)

    def extract_register_information(self, sw, stream, port, batch_id):
        """Return the port/batch slice of each cached register matching ``stream``."""
        start, end = self._port_region(port, batch_id)
        res = {}
        for name, values in self.registers[sw].items():
            if stream in name:
                res[name] = values[start:end]
        return res

    def decode_meter_pair(self, um_registers, dm_registers):
        """Recover the flows counted on the um side but not on the dm side.

        Counters are subtracted and the flow fields XOR-ed cell by cell. Cells
        whose counter is 1 are then repeatedly decoded, and each decoded flow
        is removed from all of its hash cells.

        Returns:
            set: flow tuples ``(src, dst, src_port, dst_port, proto, ip_id)``.
            If decoding hits an inconsistency (a negative counter, or a cell
            that does not hash back to itself), decoding stops and the flows
            found so far are returned.
        """
        counters = [x - y for x, y in zip(um_registers['MyIngress.um_counter'],
                                          dm_registers['MyIngress.dm_counter'])]
        ip_src = [x ^ y for x, y in zip(um_registers['MyIngress.um_ip_src'],
                                        dm_registers['MyIngress.dm_ip_src'])]
        ip_dst = [x ^ y for x, y in zip(um_registers['MyIngress.um_ip_dst'],
                                        dm_registers['MyIngress.dm_ip_dst'])]
        ports_proto_id = [x ^ y for x, y in zip(
            um_registers['MyIngress.um_ports_proto_id'],
            dm_registers['MyIngress.dm_ports_proto_id'])]

        dropped_packets = set()
        while 1 in counters:
            i = counters.index(1)
            tmp_src = ip_src[i]
            tmp_dst = ip_dst[i]
            src = socket.inet_ntoa(struct.pack("!I", tmp_src))
            dst = socket.inet_ntoa(struct.pack("!I", tmp_dst))
            misc = ports_proto_id[i]
            # bit layout: src_port(16) | dst_port(16) | proto(8) | ip id(16)
            ip_id = misc & 0xffff
            proto = misc >> 16 & 0xff
            dst_port = misc >> 24 & 0xffff
            src_port = misc >> 40 & 0xffff

            flow = (src, dst, src_port, dst_port, proto, ip_id)

            # get the cell indexes of this flow, one per hash
            flow_stream = self.flow_to_bytestream(flow)
            indexes = [h.bit_by_bit_fast(flow_stream) % REGISTER_PORT_SIZE
                       for h in self.hashes]

            # A decodable cell holds a flow that hashes back to it. If it does
            # not, the registers are inconsistent and nothing would ever
            # decrement counters[i], so the loop would never end.
            if i not in indexes:
                return dropped_packets

            # clean this entry everywhere and continue
            for index in indexes:
                counters[index] -= 1
                ip_src[index] ^= tmp_src
                ip_dst[index] ^= tmp_dst
                ports_proto_id[index] ^= misc

            # A negative counter means the um and dm registers were badly
            # synchronised (e.g. because of switch buffering); skip this round.
            if any(x < 0 for x in counters):
                return dropped_packets

            dropped_packets.add(flow)
        return dropped_packets

    def verify_link(self, sw1, sw2, batch_id):
        """Decode, reset and report packet loss on the link sw1 -> sw2."""
        sw1_to_sw2_interface = self.topo.node_to_node_port_num(sw1, sw2)
        sw2_to_sw1_interface = self.topo.node_to_node_port_num(sw2, sw1)

        sw1_um = self.extract_register_information(sw1, 'um', sw1_to_sw2_interface, batch_id)
        sw2_dm = self.extract_register_information(sw2, 'dm', sw2_to_sw1_interface, batch_id)

        dropped_packets = self.decode_meter_pair(sw1_um, sw2_dm)

        # clean registers
        self.reset_registers(sw1, 'um', sw1_to_sw2_interface, batch_id)
        self.reset_registers(sw2, 'dm', sw2_to_sw1_interface, batch_id)

        # report
        if dropped_packets:
            print("Packets dropped: {} at link {}->{}:".format(len(dropped_packets), sw1, sw2))
            print("Details:")
            for packet in dropped_packets:
                print(packet)

    def check_sw_links(self, sw, batch_id):
        """Check every link from ``sw`` to a neighbouring P4 switch for ``batch_id``."""
        # just in case for the delay
        # increase or decrease depending on the batch timing
        time.sleep(0.25)

        # read all registers since it's a small topo
        self.read_registers()

        # Process the right links and clean registers
        p4switches = self.topo.get_p4switches()
        neighboring_p4_switches = [x for x in self.topo.get_neighbors(sw)
                                   if x in p4switches]

        for neighboring_switch in neighboring_p4_switches:
            self.verify_link(sw, neighboring_switch, batch_id)

    # When a batch_id changes the controller gets triggered
    def recv_msg_cpu(self, pkt):
        """Handle a packet sniffed on a CPU interface (EtherType 0x1234)."""
        interface = pkt.sniffed_on
        print(interface)
        # interface names are expected to look like "<switch>-..."
        switch_name = interface.split("-")[0]
        # NOTE(review): str(pkt) yields raw bytes only on Python 2.
        packet = Ether(str(pkt))
        if packet.type == 0x1234:
            loss_header = LossHeader(packet.payload)
            batch_id = loss_header.batch_id >> 7
            print("%s %s" % (switch_name, batch_id))
            self.check_sw_links(switch_name, batch_id)

    def run_cpu_port_loop(self):
        """Sniff on every switch's CPU interface (eth0 renamed to eth1), forever."""
        cpu_interfaces = [str(self.topo.get_cpu_port_intf(sw_name).replace("eth0", "eth1"))
                          for sw_name in self.controllers]
        sniff(iface=cpu_interfaces, prn=self.recv_msg_cpu)