Beispiel #1
0
def ip2country(request):
    ip = str(request.REQUEST.get('ip'))
    
    geoip = GeoIP(settings.GEOIP, MEMORY_CACHE)
    c = geoip.country_name_by_addr(ip)
    c+="; "
    whois = os.popen("whois %s 2>&1" % ip)
    file.close
    for ln in whois:
        '''
        inetnum:      134.36.0.0 - 134.36.255.255
        descr:        University of Dundee
        descr:        Dundee DD1 4HN
        descr:        Scotland
        netname:      DUNDEE-UNIV
        descr:        University of Dundee
        country:      GB
        '''
        if ln.startswith("inetnum") or ln.startswith("netname") or ln.startswith("descr"):
            c+=ln.split(":")[1].strip()+"; "
        if ln.startswith("country"):
            c+=ln.split(":")[1].strip()+"."
            break
        if len(c) > 400:
            break
        
    return HttpResponse(c)
Beispiel #2
0
class GraphManager(object):
    """ Generates and processes the graph based on packets

    Nodes are addresses taken from the chosen layer: MAC addresses (layer
    2), IP addresses (layer 3) or ``"ip:port"`` strings (layer 4).  Every
    directed edge ``src -> dst`` stores the packets sent along it under
    ``'packets'`` plus the derived ``'layers'``, ``'transmitted'`` and
    ``'connections'`` values.  Per-node extras (IP and GeoIP country) are
    kept in ``self.data``.
    """
    def __init__(self,
                 packets,
                 layer=3,
                 geo_ip=os.path.expanduser('~/GeoIP.dat')):
        """Build the traffic graph.

        :param packets: iterable of scapy packets
        :param layer: 2, 3 or 4; selects which addresses become nodes
        :param geo_ip: path of a GeoIP country database; when it cannot be
            loaded, nodes are simply not annotated with a country
        :raises ValueError: if *layer* is not 2, 3 or 4
        """
        self.graph = DiGraph()
        self.layer = layer
        self.geo_ip = None
        self.data = {}

        try:
            self.geo_ip = GeoIP(geo_ip)
        except Exception:
            # Best effort: the graph is still useful without countries.
            logging.warning("could not load GeoIP data")

        if self.layer == 2:
            edges = map(self._layer_2_edge, packets)
        elif self.layer == 3:
            edges = map(self._layer_3_edge, packets)
        elif self.layer == 4:
            edges = map(self._layer_4_edge, packets)
        else:
            raise ValueError(
                "Other layers than 2,3 and 4 are not supported yet!")

        for edge in edges:
            if edge is None:
                # the packet lacks the layer needed to derive an edge
                continue
            src, dst, packet = edge
            if src in self.graph and dst in self.graph[src]:
                self.graph[src][dst]['packets'].append(packet)
            else:
                self.graph.add_edge(src, dst)
                self.graph[src][dst]['packets'] = [packet]

        for node in self.graph.nodes():
            self._retrieve_node_info(node)

        for src, dst in self.graph.edges():
            self._retrieve_edge_info(src, dst)

    def get_in_degree(self, print_stdout=True):
        """Return in-degrees as an OrderedDict, highest first."""
        unsorted_degrees = self.graph.in_degree()
        return self._sorted_results(unsorted_degrees, print_stdout)

    def get_out_degree(self, print_stdout=True):
        """Return out-degrees as an OrderedDict, highest first."""
        unsorted_degrees = self.graph.out_degree()
        return self._sorted_results(unsorted_degrees, print_stdout)

    @staticmethod
    def _sorted_results(unsorted_degrees, print_stdout):
        """Sort ``(node, degree)`` pairs by degree, descending.

        :param unsorted_degrees: iterable of ``(node, degree)`` pairs
        :param print_stdout: also print ``degree node`` for every entry
        :returns: OrderedDict ``node -> degree``
        """
        sorted_degrees = OrderedDict(
            sorted(unsorted_degrees, key=lambda t: t[1], reverse=True))
        if print_stdout:
            for node in sorted_degrees:
                print(sorted_degrees[node], node)
        return sorted_degrees

    def _retrieve_node_info(self, node):
        """Store the node's IP and GeoIP country in ``self.data[node]``.

        Only done for layers 3/4 when a GeoIP database is loaded.  If the
        lookup fails (address not understood by GeoIP) the node's entry is
        removed again; ``draw`` leaves such nodes unstyled.
        """
        self.data[node] = {}
        if self.layer >= 3 and self.geo_ip:
            if self.layer == 3:
                self.data[node]['ip'] = node
            elif self.layer == 4:
                # layer-4 nodes are "ip:port"
                self.data[node]['ip'] = node.split(':')[0]

            node_ip = self.data[node]['ip']
            try:
                country = self.geo_ip.country_name_by_addr(node_ip)
                self.data[node]['country'] = country if country else 'private'
            except Exception:
                # it seems like we are not dealing with valid IPs...
                # best effort approach: skip
                del self.data[node]
        #TODO layer 2 info?

    def _retrieve_edge_info(self, src, dst):
        """Derive protocol names, byte count and packet count for an edge."""
        edge = self.graph[src][dst]
        if edge:
            packets = edge['packets']
            edge['layers'] = set(itertools.chain.from_iterable(
                GraphManager.get_layers(p) for p in packets))
            edge['transmitted'] = sum(len(p) for p in packets)
            edge['connections'] = len(packets)

    @staticmethod
    def get_layers(packet):
        """Return the names of all layers of *packet*, outermost first."""
        return list(GraphManager.expand(packet))

    @staticmethod
    def expand(x):
        """Yield ``x.name`` and the names of all nested payloads."""
        yield x.name
        while x.payload:
            x = x.payload
            yield x.name

    @staticmethod
    def _layer_2_edge(packet):
        """Return ``(src, dst, packet)`` from the outermost layer."""
        return packet[0].src, packet[0].dst, packet

    @staticmethod
    def _layer_3_edge(packet):
        """Return ``(src_ip, dst_ip, packet)`` for IP packets, else None.

        The IP layer is looked up by type rather than by position, so frames
        with extra link-layer headers (e.g. VLAN tags) or captures without an
        Ethernet header are handled correctly.
        """
        if packet.haslayer(IP):
            ip_layer = packet[IP]
            return ip_layer.src, ip_layer.dst, packet
        return None

    @staticmethod
    def _layer_4_edge(packet):
        """Return ``("ip:sport", "ip:dport", packet)`` for TCP/UDP, else None.

        The outermost TCP/UDP layer is located by walking the payload chain,
        and the addresses are read from the layer directly beneath it, so the
        lookup does not depend on the transport sitting at index 2.
        """
        layer = packet
        while layer:
            if type(layer) in (TCP, UDP):
                network = layer.underlayer
                return ("%s:%i" % (network.src, layer.sport),
                        "%s:%i" % (network.dst, layer.dport),
                        packet)
            layer = layer.payload
        return None

    def draw(self, filename=None):
        """Lay the graph out with ``dot`` and render it to *filename*.

        Nodes with a public GeoIP country are labelled with it and filled
        blue; edges are labelled with bytes transmitted and protocol names.
        """
        graph = self.get_graphviz_format()
        for node in graph.nodes():
            if node not in self.data:
                # node might be deleted, because it's not legit etc.
                continue
            node.attr['shape'] = 'circle'
            node.attr['fontsize'] = '10'
            node.attr['width'] = '0.5'
            if 'country' in self.data[str(node)]:
                country_label = self.data[str(node)]['country']
                if country_label == 'private':
                    node.attr['label'] = str(node)
                else:
                    node.attr['label'] = "%s (%s)" % (str(node), country_label)
                    node.attr['color'] = 'blue'
                    node.attr['style'] = 'filled'
                    #TODO add color based on country or scan?
        node_count = len(self.graph.nodes())
        for edge in graph.edges():
            connection = self.graph[edge[0]][edge[1]]
            edge.attr['label'] = 'transmitted: %i bytes\n%s ' % (
                connection['transmitted'], ' | '.join(connection['layers']))
            edge.attr['fontsize'] = '8'
            edge.attr['minlen'] = '2'
            edge.attr['penwidth'] = min(
                connection['connections'] * 1.0 / node_count, 2.0)

        graph.layout(prog='dot')
        graph.draw(filename)

    def get_graphviz_format(self, filename=None):
        """Return the graph as a pygraphviz AGraph, optionally writing it.

        :param filename: if given, the dot source is written there
        """
        agraph = networkx.drawing.nx_agraph.to_agraph(self.graph)
        # remove packet information (blows up file size)
        for edge in agraph.edges():
            del edge.attr['packets']
        if filename:
            agraph.write(filename)
        return agraph
Beispiel #3
0
import re
import sys
from pygeoip import GeoIP

g = GeoIP('GeoIP.dat')

# Captures the byte count following a 20x status code and preceding a quote,
# i.e. `... 200 5120 "...` (presumably Apache/nginx combined log format —
# confirm against the logs fed to this script).
r = re.compile(r'20. ([0-9]+) "')

GIGABYTE = 1024 * 1024 * 1024

# country name -> total bytes of 2xx responses
d = {}

while True:
    line = sys.stdin.readline()
    if not line:
        break
    match = r.search(line)
    if not match:
        continue
    # the client address is the first space-separated field
    ip = line.split(' ')[0]
    try:
        country = g.country_name_by_addr(ip)
    except Exception:
        # The first field is not always an address GeoIP can parse (e.g. a
        # resolved hostname); count that traffic as Unknown instead of
        # aborting the whole report.  pygeoip raises different exception
        # types for this depending on its version.
        country = None
    if not country:
        country = 'Unknown'
    d[country] = int(match.group(1)) + d.get(country, 0)

# Largest totals first; sorted() is stable, so ties keep dict order exactly
# like the cmp-based list.sort() did.  Works on Python 2 and 3.
ds = sorted(d.items(), key=lambda item: item[1], reverse=True)
for k, v in ds:
    if v > GIGABYTE:
        print('%s %0.2fGb' % (k, v / float(GIGABYTE)))
Beispiel #4
0
class QueryManager(object):
    """
    Implements handling multiple threads used to speed up server queries.

    The constructor starts a daemon coordinator thread that reacts to string
    messages put on ``messageque`` (see ``coordinator``).  The public
    ``start*`` methods and ``lookup_server`` fill ``serverqueue`` and post
    the matching message; worker threads then query every server and hand
    the results to the GUI through ``gobject.idle_add``.

    The shared counters and the ``abort`` flag are guarded by ``gui_lock``,
    GeoIP lookups by ``geo_lock``.  Each lock is created once in the
    constructor and shared by all threads — a lock only provides mutual
    exclusion when every thread acquires the same object.
    """

    # number of worker threads spawned to fetch the server status
    WORKER_COUNT = 10

    def __init__(self):
        """
        Constructor -
        Initialises the queues, counters and locks, opens the GeoIP
        database and finally spawns the coordinator thread, which creates
        further threads to perform the master server query and the status
        updates for the servers.
        """
        self.serverqueue = Queue()
        self.messageque = Queue()
        self.pulsemessageque = Queue()

        self.threadcount = 0
        self.servercount = 0
        self.processedserver = 0
        self.filterdcount = 0

        # One lock object per purpose, shared by every thread.
        self.gui_lock = threading.RLock()
        self.geo_lock = threading.RLock()

        dbname = Globals.geoip_dir + '/GeoIP.dat'
        self.pygeoip = GeoIP(dbname, pygeoip.const.MMAP_CACHE)

        self.abort = False

        # Start the coordinator only once every attribute it (or the threads
        # it spawns) may read has been initialised.
        coord = Thread(target=self.coordinator)
        coord.daemon = True
        coord.start()

    @staticmethod
    def _start_daemon_thread(target, name=None):
        """Run *target* in a new daemon thread and return the thread."""
        thread = Thread(target=target, name=name)
        thread.daemon = True
        thread.start()
        return thread

    def start_serverlist_refresh(self, liststore, tab):
        """
        Refreshes the serverlist of a tab.

        @param liststore - the liststore which contains the servers to be
                           refreshed (the server object is read from
                           column 8 of every row)
        @param tab - the tab requesting the refresh
        """
        self.tab = tab
        self.filter = None

        row = liststore.iter_children(None)
        while row:
            self.serverqueue.put(liststore.get_value(row, 8))
            row = liststore.iter_next(row)

        self.servercount = self.serverqueue.qsize()

        gobject.idle_add(tab.clearServerList)
        self.messageque.put('serverlist_loaded')

    def startMasterServerQueryThread(self, filter, tab):
        """
        Starts the master server query.

        @param filter - filter to apply (also supplies show_empty/show_full
                        for the query)
        @param tab - tab requesting the serverlist
        """
        self.tab = tab
        self.filter = filter
        tab.clearServerList()

        # this message makes the coordinator start querying the master server
        self.messageque.put('start_master_server_query')

    def startRecentServersLoadingThread(self, tab):
        """
        Starts loading the recent servers list.

        @param tab - tab requesting the recent servers
        """
        fm = FileManager()
        self.tab = tab
        self.filter = None

        serverdict = fm.getRecentServers()
        for server in serverdict.values():
            self.serverqueue.put(server)

        self.servercount = len(serverdict)

        # notify the coordinator thread that the serverlist is loaded
        self.messageque.put('serverlist_loaded')

    def startFavoritesLoadingThread(self, tab):
        """
        Starts loading the favorites.

        @param tab - the tab requesting the favorites list
        """
        fm = FileManager()
        self.tab = tab
        self.filter = None

        serverlist = fm.getFavorites().values()
        for server in serverlist:
            self.serverqueue.put(server)

        self.servercount = len(serverlist)

        # notify the coordinator thread that the serverlist is loaded
        self.messageque.put('serverlist_loaded')

    def lookup_server(self, server, tab):
        """
        Starts the lookup of a certain server.

        @param server - the server to be looked up
        @param tab - the requesting tab
        """
        self.tab = tab
        self.filter = None

        self.serverqueue.put(server)
        self.servercount = 1
        self.messageque.put('serverlist_loaded')

    def coordinator(self):
        """
        Method that runs as coordinator thread.
        Spawns additional threads based on string messages in the message
        queue.

        Messages accepted:
        start_master_server_query - the start signal.  Spawns two threads:
                                    one pulses the progressbar on self.tab
                                    every 0.1 seconds, the other performs the
                                    master server query and, when done, puts
                                    serverlist_loaded into the message queue
                                    (the serverqueue is then filled).
        serverlist_loaded - stops the pulsing and spawns WORKER_COUNT (10)
                            worker threads that perform the get_status
                            request for the servers in the serverqueue
        finished - put by the last worker once every server was processed;
                   schedules self.tab.serverlist_loading_finished and
                   terminates the coordinator thread
        all_aborted - put by the last worker after an abort; schedules the
                      aborted message and self.tab.serverlist_loading_finished
                      and terminates the coordinator thread
        """
        Log.log.debug('Thread:Coordinator started...')
        while True:
            # A blocking get() without timeout waits for the next message and
            # never raises Empty.
            message = self.messageque.get()
            if message == 'start_master_server_query':
                Log.log.info('Thread:Coordinator - start_master_server_'
                             'query signal received')
                self._start_daemon_thread(self.pulse_progressbar_thread)
                self._start_daemon_thread(self.master_server_query_thread)
            elif message == 'serverlist_loaded':
                Log.log.info('Thread:Coordinator - received serverlist'
                             '_loaded signal. Queuesize is '
                             + str(self.serverqueue.qsize()))
                # stop the pulsing of the progressbar
                self.pulsemessageque.put('stop_pulse')
                # Register every worker *before* starting any of them.  If
                # each worker incremented threadcount itself, a fast worker
                # could drain the queue and see the count drop to 0 while the
                # others had not started yet, reporting 'finished' too early
                # (and then again later).
                with self.gui_lock:
                    self.threadcount += self.WORKER_COUNT
                for i in range(self.WORKER_COUNT):
                    self._start_daemon_thread(self.get_server_status_thread,
                                              name='Worker_' + str(i + 1))
            elif message == 'finished':
                Log.log.info('Thread:Coordinator - received the '
                             'finished signal')
                with self.gui_lock:
                    gobject.idle_add(self.tab.serverlist_loading_finished)
                break
            elif message == 'all_aborted':
                Log.log.info('Thread:Coordinator - received the '
                             'all_aborted signal')
                with self.gui_lock:
                    gobject.idle_add(self.set_progressbar_aborted)
                    gobject.idle_add(self.tab.serverlist_loading_finished)
                break

    def master_server_query_thread(self):
        """
        Runs as a thread: retrieves the list of servers from the master
        server, queues them and notifies the coordinator.
        """
        query = Q3ServerQuery()

        empty = self.filter.show_empty
        full = self.filter.show_full

        # query the urban terror master server
        serverlist = query.getServerList('master.urbanterror.info',
                                         27900,
                                         empty,
                                         full)
        # put all servers in the serverqueue
        for server in serverlist:
            self.serverqueue.put(server)
        self.servercount = len(serverlist)

        # notify the coordinator thread that the serverlist is loaded
        self.messageque.put('serverlist_loaded')

    def get_server_status_thread(self):
        """
        Runs as a worker thread: retrieves the status of the servers in the
        serverqueue until the queue is empty or an abort was requested.

        The coordinator has already added this worker to threadcount; the
        worker decrements it on exit, and the last one to exit notifies the
        coordinator ('finished' or 'all_aborted').
        """
        name = threading.current_thread().name
        Log.log.debug('Thread:' + name + ' started')

        while True:
            with self.gui_lock:
                if self.abort:
                    self.threadcount -= 1
                    Log.log.info('Thread:' + name +
                                 ' exiting due to abort signal')
                    if self.threadcount == 0:  # last thread reached
                        Log.log.info('Thread:' + name +
                                     '   notifying the coordinator thread '
                                     'that all threads was aborted')
                        self.messageque.put('all_aborted')
                    break

            try:
                server = self.serverqueue.get(False)
            except Empty:
                # no more servers in the queue: leave the loop
                with self.gui_lock:
                    self.threadcount -= 1
                    Log.log.debug('Thread:' + name +
                                  ' finishes working and exiting')
                    if self.threadcount == 0:  # last thread reached
                        Log.log.info('Thread:' + name +
                                     ' notifying the coordinator thread that '
                                     'the queue processing is finished')
                        self.messageque.put('finished')
                break

            # perform the status request (network I/O, outside the lock)
            server = Q3ServerQuery().getServerStatus(server)

            # add the server to the gui
            with self.gui_lock:
                self.set_location(server)
                self.processedserver += 1
                gobject.idle_add(self.set_progressbar_fraction)
                if self.filter is None or \
                        self.filter.does_filter_match_server(server):
                    gobject.idle_add(self.tab.addServer, server)
                else:
                    self.filterdcount += 1  # server is filtered, not added

    def pulse_progressbar_thread(self):
        """
        Runs as a background thread that schedules a pulse of the
        progressbar of self.tab every 0.1 seconds until 'stop_pulse' is
        received on pulsemessageque.
        """
        while True:
            try:
                message = self.pulsemessageque.get(True, 0.1)
            except Empty:
                with self.gui_lock:
                    gobject.idle_add(self.pulse_progressbar)
                continue
            if message == 'stop_pulse':
                break

    def set_progressbar_fraction(self):
        """
        Sets the progressbar fraction and text from the total servercount
        and the processed servercount.  Does nothing after an abort.
        """
        if self.abort:
            return
        progressbar = self.tab.statusbar.progressbar
        if self.processedserver >= self.servercount:
            # all servers processed (also covers an empty server list,
            # which previously raised ZeroDivisionError)
            bartext = 'finished getting server status - displaying ' \
                      + str(self.servercount - self.filterdcount) + \
                      ' servers (' + str(self.filterdcount) + ' filtered)'
            progressbar.set_fraction(0.0)
        else:
            fraction = float(self.processedserver) / float(self.servercount)
            bartext = 'fetching server status (' + str(self.processedserver) + \
                      ' / ' + str(self.servercount) + ') - ' + \
                      str(self.filterdcount) + ' servers filtered'
            progressbar.set_fraction(fraction)
        progressbar.set_text(bartext)

    def pulse_progressbar(self):
        """
        Pulse the progressbar; called via gobject.idle_add.
        """
        self.tab.statusbar.progressbar.set_text('fetching serverlist from master server')
        self.tab.statusbar.progressbar.pulse()

    def set_progressbar_aborted(self):
        """
        Sets the text of the progressbar to the aborted message and resets
        the fraction.
        """
        self.tab.statusbar.progressbar.set_text('task aborted')
        self.tab.statusbar.progressbar.set_fraction(0.0)

    def abort_current_task(self):
        """
        Stops the processing of the queue by setting the abort flag; each
        worker checks it before taking the next server.
        """
        with self.gui_lock:
            self.abort = True

    def set_location(self, server):
        """
        Determine the location of a server from its IP address and set it on
        the server object.

        GeoIP access is serialised through the shared geo_lock.

        @param server - the server object
        """
        with self.geo_lock:
            host = server.getHost()
            server.set_location(self.pygeoip.country_code_by_addr(host))
            server.set_location_name(self.pygeoip.country_name_by_addr(host))
Beispiel #5
0
def index():
    res_data = {}
    verbose = False
    prettyprint = False

    try:
        client_ip = request.remote_route[0]
    except:
        client_ip = request.environ['REMOTE_ADDR']
        pass

    args = request.query_string.split('&')
    if 'v' in args:
        verbose = True

    if 'pp' in args:
        prettyprint = True

    res_data['ip'] = client_ip

    try:
        ua = request.environ['HTTP_USER_AGENT']
        user_agent = user_agents.parse(ua)
        res_data['user_agent'] = str(user_agent)
    except Exception as e:
        user_agent = False
        pass

    if user_agent and not verbose:
        if user_agent.browser.family == 'Other':
            response.set_header('Content-type', 'text/plain')
            return res_data['ip'] + '\r\n'

    try:
        ip = GeoIP(config.get('app', 'geoip_database'))
        res_data['country'] = ip.country_name_by_addr(client_ip)
        res_data['city'] = ip.record_by_addr(client_ip)
        res_data['asn'] = ip.asn_by_addr(client_ip)
    except Exception as e:
        pass

    try:
        name = reversename.from_address(client_ip)
        answers = resolver.query(name, 'PTR')
        res_data['hostname'] = []
        for answer in answers:
            res_data['hostname'].append(str(answer).rstrip('.'))
    except Exception as e:
        pass

    if user_agent and verbose:
        if user_agent.browser.family == 'Other':
            response.set_header('Content-type', 'application/json')
            if prettyprint:
                return json.dumps(
                    res_data,
                    sort_keys=True,
                    indent=4,
                    separators=(',', ': ')
                ) + '\r\n'
            return json.dumps(res_data)

    return template(
        'index',
        ua_info=res_data,
        page_title='Your IP-address is: ' + res_data['ip'],
        verbose=verbose
    )
Beispiel #6
0
def hit(request):
    stable_omero_downloads = 'http://downloads.openmicroscopy.org/latest-stable/omero'
    agent = None
    try:
        agt = request.META.get('HTTP_USER_AGENT', '')
        if agt is not None and agt.startswith("OMERO."):
            try:
                agent = Agent.objects.get(agent_name=agt)
            except Agent.DoesNotExist:
                return HttpResponseRedirect(UPGRADE_CHECK_URL)
            except:
                logger.error(traceback.format_exc())
                return HttpResponseRedirect(UPGRADE_CHECK_URL)
        else:
                return HttpResponseRedirect(UPGRADE_CHECK_URL)
    except:
        logger.error(traceback.format_exc())
        return HttpResponseRedirect(UPGRADE_CHECK_URL)
    logger.debug("Agent %s" % agent)
    
    agent_version = ''
    update = None
    try:
        agent_version = request.REQUEST.get('version')
        ver = Version.objects.get(pk=1)
        if agent_version is not None:
            try:
                regex = re.compile("^.*?[-]?(\\d+[.]\\d+([.]\\d+)?)[-]?.*?$")

                agent_cleaned = regex.match(agent_version).group(1)
                agent_split = agent_cleaned.split(".")

                local_cleaned = regex.match(ver.version).group(1)
                local_split = local_cleaned.split(".")

                rv = (agent_split < local_split)
            except:
                rv = True
            if rv:
                update = 'Please upgrade to %s. See %s for the latest version.' % (ver, stable_omero_downloads)
        else:
            update = 'Please upgrade to %s. See %s for the latest version.' % (ver, stable_omero_downloads)
    except:
        logger.debug(traceback.format_exc())
    logger.debug("Agent version %s" % agent_version)
    
    ip = None
    try:
        real_ip = None
        try:
            # HTTP_X_FORWARDED_FOR can be a comma-separated list of IPs. The
            # client's IP will be the first one.
            # http://code.djangoproject.com/ticket/3872
            real_ip = request.META['HTTP_X_FORWARDED_FOR']
            logger.debug("HTTP_X_FORWARDED_FOR: %s" % real_ip) 
            real_ip = real_ip.split(",")[-1].strip()
        except KeyError:
            real_ip = request.META.get('REMOTE_ADDR')
            
        if real_ip is not None:
            try:
                ip = IP.objects.get(ip=real_ip)
            except IP.DoesNotExist:
                latitude = None
                longitude = None
                country = None
                geoip = GeoIP(settings.GEODAT, STANDARD)
                gir = geoip.record_by_addr(real_ip)
                if gir is not None:
                    latitude = gir["latitude"]
                    longitude = gir["longitude"]
                geoip = GeoIP(settings.GEOIP, MEMORY_CACHE)
                country = geoip.country_name_by_addr(real_ip)
                    
                logger.debug("IP: %s, latitude: '%s', longitude: '%s'" % (real_ip, latitude, longitude))
                ip = IP(ip=real_ip, latitude=latitude, longitude=longitude, country=country)
                ip.save()
    except Exception, x:
        logger.debug(traceback.format_exc())
        raise x
Beispiel #7
0
class GraphManager(object):
    """ Generates and processes the graph based on packets

    Nodes are addresses taken from the chosen layer: MAC addresses (layer
    2), IP addresses (layer 3) or ``"ip:port"`` strings (layer 4).  Every
    directed edge ``src -> dst`` stores the packets sent along it under
    ``'packets'`` plus the derived ``'layers'``, ``'transmitted'`` and
    ``'connections'`` values.  Per-node extras (IP and GeoIP country) are
    kept in ``self.data``.  Only API calls shared by networkx 1.x and 2.x
    are used.
    """

    def __init__(self, packets, layer=3, geo_ip=os.path.expanduser('~/GeoIP.dat')):
        """Build the traffic graph.

        :param packets: iterable of scapy packets
        :param layer: 2, 3 or 4; selects which addresses become nodes
        :param geo_ip: path of a GeoIP country database; when it cannot be
            loaded, nodes are simply not annotated with a country
        :raises ValueError: if *layer* is not 2, 3 or 4
        """
        self.graph = DiGraph()
        self.layer = layer
        self.geo_ip = None
        self.data = {}

        try:
            self.geo_ip = GeoIP(geo_ip)
        except Exception:
            # Best effort: the graph is still useful without countries.
            logging.warning("could not load GeoIP data")

        if self.layer == 2:
            edges = map(self._layer_2_edge, packets)
        elif self.layer == 3:
            edges = map(self._layer_3_edge, packets)
        elif self.layer == 4:
            edges = map(self._layer_4_edge, packets)
        else:
            raise ValueError("Other layers than 2,3 and 4 are not supported yet!")

        for edge in edges:
            if edge is None:
                # the packet lacks the layer needed to derive an edge
                continue
            src, dst, packet = edge
            if src in self.graph and dst in self.graph[src]:
                self.graph[src][dst]['packets'].append(packet)
            else:
                # Keyword attributes work on networkx 1.x and 2.x; 2.x
                # dropped the positional attribute-dict argument.
                self.graph.add_edge(src, dst, packets=[packet])

        for node in self.graph.nodes():
            self._retrieve_node_info(node)

        for src, dst in self.graph.edges():
            self._retrieve_edge_info(src, dst)

    def get_in_degree(self, print_stdout=True):
        """Return in-degrees as an OrderedDict, highest first."""
        unsorted_degrees = self.graph.in_degree()
        return self._sorted_results(unsorted_degrees, print_stdout)

    def get_out_degree(self, print_stdout=True):
        """Return out-degrees as an OrderedDict, highest first."""
        unsorted_degrees = self.graph.out_degree()
        return self._sorted_results(unsorted_degrees, print_stdout)

    @staticmethod
    def _sorted_results(unsorted_degrees, print_stdout):
        """Sort degrees descending.

        :param unsorted_degrees: ``{node: degree}`` mapping (networkx 1.x)
            or an iterable of ``(node, degree)`` pairs (networkx 2.x views)
        :param print_stdout: also print ``degree node`` for every entry
        :returns: OrderedDict ``node -> degree``
        """
        # dict() accepts both a mapping and an iterable of pairs.
        sorted_degrees = OrderedDict(
            sorted(dict(unsorted_degrees).items(), key=lambda t: t[1], reverse=True))
        if print_stdout:
            for node in sorted_degrees:
                print(sorted_degrees[node], node)
        return sorted_degrees

    def _retrieve_node_info(self, node):
        """Store the node's IP and GeoIP country in ``self.data[node]``.

        Only done for layers 3/4 when a GeoIP database is loaded.  If the
        lookup fails (address not understood by GeoIP) the node's entry is
        removed again; ``draw`` leaves such nodes unstyled.
        """
        self.data[node] = {}
        if self.layer >= 3 and self.geo_ip:
            if self.layer == 3:
                self.data[node]['ip'] = node
            elif self.layer == 4:
                # layer-4 nodes are "ip:port"
                self.data[node]['ip'] = node.split(':')[0]

            node_ip = self.data[node]['ip']
            try:
                country = self.geo_ip.country_name_by_addr(node_ip)
                self.data[node]['country'] = country if country else 'private'
            except Exception:
                # not a valid IP for GeoIP: best effort, skip this node
                del self.data[node]
        #TODO layer 2 info?

    def _retrieve_edge_info(self, src, dst):
        """Derive protocol names, byte count and packet count for an edge."""
        edge = self.graph[src][dst]
        if edge:
            packets = edge['packets']
            edge['layers'] = set(itertools.chain.from_iterable(
                GraphManager.get_layers(p) for p in packets))
            edge['transmitted'] = sum(len(p) for p in packets)
            edge['connections'] = len(packets)

    @staticmethod
    def get_layers(packet):
        """Return the names of all layers of *packet*, outermost first."""
        return list(GraphManager.expand(packet))

    @staticmethod
    def expand(x):
        """Yield ``x.name`` and the names of all nested payloads."""
        yield x.name
        while x.payload:
            x = x.payload
            yield x.name

    @staticmethod
    def _layer_2_edge(packet):
        """Return ``(src, dst, packet)`` from the outermost layer."""
        return packet[0].src, packet[0].dst, packet

    @staticmethod
    def _layer_3_edge(packet):
        """Return ``(src_ip, dst_ip, packet)`` for IP packets, else None.

        The IP layer is looked up by type rather than by position, so frames
        with extra link-layer headers (e.g. VLAN tags) or captures without an
        Ethernet header are handled correctly.
        """
        if packet.haslayer(IP):
            ip_layer = packet[IP]
            return ip_layer.src, ip_layer.dst, packet
        return None

    @staticmethod
    def _layer_4_edge(packet):
        """Return ``("ip:sport", "ip:dport", packet)`` for TCP/UDP, else None.

        The outermost TCP/UDP layer is located by walking the payload chain,
        and the addresses are read from the layer directly beneath it, so the
        lookup does not depend on the transport sitting at index 2.
        """
        layer = packet
        while layer:
            if type(layer) in (TCP, UDP):
                network = layer.underlayer
                return ("%s:%i" % (network.src, layer.sport),
                        "%s:%i" % (network.dst, layer.dport),
                        packet)
            layer = layer.payload
        return None

    def draw(self, filename=None, figsize=(50, 50)):
        """Lay the graph out with ``dot`` and render it to *filename*.

        Nodes with a public GeoIP country are labelled with it and filled
        blue; edges are labelled with bytes transmitted and protocol names.

        :param figsize: currently unused; kept for backward compatibility
        """
        graph = self.get_graphviz_format()

        for node in graph.nodes():
            if node not in self.data:
                # node info was dropped (e.g. address rejected by GeoIP)
                continue
            node.attr['shape'] = 'circle'
            node.attr['fontsize'] = '10'
            node.attr['width'] = '0.5'
            if 'country' in self.data[str(node)]:
                country_label = self.data[str(node)]['country']
                if country_label == 'private':
                    node.attr['label'] = str(node)
                else:
                    node.attr['label'] = "%s (%s)" % (str(node), country_label)
                    node.attr['color'] = 'blue'
                    node.attr['style'] = 'filled'
                    #TODO add color based on country or scan?
        node_count = len(self.graph.nodes())
        for edge in graph.edges():
            connection = self.graph[edge[0]][edge[1]]
            edge.attr['label'] = 'transmitted: %i bytes\n%s ' % (connection['transmitted'], ' | '.join(connection['layers']))
            edge.attr['fontsize'] = '8'
            edge.attr['minlen'] = '2'
            edge.attr['penwidth'] = min(connection['connections'] * 1.0 / node_count, 2.0)

        graph.layout(prog='dot')
        graph.draw(filename)

    #TODO do we need a .dot file export?
    def get_graphviz_format(self, filename=None):
        """Return the graph as a pygraphviz AGraph, optionally writing it.

        :param filename: if given, the dot source is written there
        """
        # nx_agraph.to_agraph exists in networkx 1.x and 2.x; the top-level
        # networkx.to_agraph alias is gone in 2.x.
        agraph = networkx.drawing.nx_agraph.to_agraph(self.graph)
        # remove packet information (blows up file size)
        for edge in agraph.edges():
            del edge.attr['packets']
        if filename:
            agraph.write(filename)
        return agraph