예제 #1
0
def test_format_endpoints():
    """Smoke-test FaucetProxy.format_endpoints on two groups of records.

    Each group holds one 'foo' and one 'bar' record keyed by 'ip-state'.
    The result is not asserted on; the call only has to complete.
    """
    # Every group gets its own freshly built dicts, so nothing is shared.
    data = [[{'ip-state': state} for state in ('foo', 'bar')]
            for _ in range(2)]
    output = FaucetProxy.format_endpoints(data)
예제 #2
0
def test_format_endpoints():
    """Smoke-test FaucetProxy.format_endpoints with a second argument.

    The first group carries only 'ip-state'; the second group also carries
    an 'ip-address' (one IPv4, one IPv6). The result is not asserted on.
    """
    state_only = [{'ip-state': 'foo'}, {'ip-state': 'bar'}]
    state_and_address = [
        {'ip-state': 'foo', 'ip-address': '0.0.0.0'},
        {'ip-state': 'bar', 'ip-address': '::1'},
    ]
    data = [state_only, state_and_address]
    output = FaucetProxy.format_endpoints(data, 'foo')
예제 #3
0
class SDNConnect(object):
    def __init__(self, controller):
        """Set up the SDN proxy, Redis connection and endpoint state.

        `controller` is the controller configuration mapping. It must
        contain 'trunk_ports', either as a JSON string or as an already
        decoded object.
        """
        self.controller = controller
        self.r = None
        self.first_time = True
        self.sdnc = None

        # Trunk ports may arrive JSON-encoded; decode only in that case.
        configured_trunks = self.controller['trunk_ports']
        self.trunk_ports = (json.loads(configured_trunks) if isinstance(
            configured_trunks, str) else configured_trunks)

        self.logger = logger
        self.get_sdn_context()
        self.endpoints = {}
        self.investigations = 0
        self.clear_filters()
        # The lock must exist before any Redis-backed call below runs.
        self.redis_lock = threading.Lock()
        self.connect_redis()
        self.default_endpoints()

    def clear_filters(self):
        ''' clear any existing filters. '''
        proxy = self.sdnc
        if isinstance(proxy, FaucetProxy):
            # Faucet mirrors are cleared through the config file parser.
            Parser().clear_mirrors(self.controller['CONFIG_FILE'])
            return
        if isinstance(proxy, BcfProxy):
            self.logger.debug('removing bcf filter rules')
            removed = proxy.remove_filter_rules()
            self.logger.debug('removed filter rules: {0}'.format(removed))

    def default_endpoints(self):
        ''' set endpoints to default state. '''
        self.get_stored_endpoints()
        for endpoint in self.endpoints.values():
            if not endpoint.ignore:
                if endpoint.state != 'inactive':
                    if endpoint.state == 'mirroring':
                        endpoint.p_next_state = 'mirror'
                    elif endpoint.state == 'reinvestigating':
                        endpoint.p_next_state = 'reinvestigate'
                    elif endpoint.state == 'queued':
                        endpoint.p_next_state = 'queue'
                    elif endpoint.state in ['known', 'abnormal']:
                        endpoint.p_next_state = endpoint.state
                    endpoint.endpoint_data['active'] = 0
                    endpoint.inactive()
                    endpoint.p_prev_states.append(
                        (endpoint.state, int(time.time())))
        self.store_endpoints()

    def get_stored_endpoints(self):
        ''' load existing endpoints from Redis. '''
        with self.redis_lock:
            endpoints = {}
            if self.r:
                try:
                    p_endpoints = self.r.get('p_endpoints')
                    if p_endpoints:
                        p_endpoints = ast.literal_eval(
                            p_endpoints.decode('ascii'))
                        for p_endpoint in p_endpoints:
                            endpoint = EndpointDecoder(
                                p_endpoint).get_endpoint()
                            endpoints[endpoint.name] = endpoint
                except Exception as e:  # pragma: no cover
                    self.logger.error(
                        'Unable to get existing endpoints from Redis because {0}'
                        .format(str(e)))
            self.endpoints = endpoints
        return

    def get_stored_metadata(self, hash_id):
        """Collect the Redis metadata recorded for one endpoint hash.

        Walks the members of the Redis set 'mac_addresses' and keeps every
        MAC whose hash has a 'poseidon_hash' field equal to `hash_id`. For
        each such MAC it gathers per-timestamp ML results and the OS stored
        for each of the endpoint's IP addresses.

        Returns a 3-tuple:
          * dict of MAC -> {timestamp -> {'labels', 'confidences',
            'behavior'}}
          * dict of IPv4 address -> {'os': ...} ('os' only when stored)
          * dict of IPv6 address -> {'os': ...} ('os' only when stored)

        Redis and parsing errors are logged and skipped; without a Redis
        handle all three results are empty.
        """
        mac_addresses = {}
        # One result dict per IP field; the return statement below reads the
        # 'ipv4' and 'ipv6' entries, so MACHINE_IP_FIELDS must provide both.
        ip_addresses = {ip_field: {} for ip_field in MACHINE_IP_FIELDS}

        if self.r:
            macs = []
            try:
                macs = self.r.smembers('mac_addresses')
            except Exception as e:  # pragma: no cover
                self.logger.error(
                    'Unable to get existing mac addresses from Redis because: {0}'
                    .format(str(e)))
            for mac in macs:
                try:
                    mac_info = self.r.hgetall(mac)
                    # Only MACs that belong to the requested endpoint hash.
                    if b'poseidon_hash' in mac_info and mac_info[
                            b'poseidon_hash'] == hash_id.encode('utf-8'):
                        mac_addresses[mac.decode('ascii')] = {}
                        if b'timestamps' in mac_info:
                            try:
                                timestamps = ast.literal_eval(
                                    mac_info[b'timestamps'].decode('ascii'))
                                for timestamp in timestamps:
                                    # ML results live in the hash named
                                    # '<mac>_<timestamp>'.
                                    ml_info = self.r.hgetall(
                                        mac.decode('ascii') + '_' +
                                        str(timestamp))
                                    labels = []
                                    if b'labels' in ml_info:
                                        labels = ast.literal_eval(
                                            ml_info[b'labels'].decode('ascii'))
                                    confidences = []
                                    if b'confidences' in ml_info:
                                        confidences = ast.literal_eval(
                                            ml_info[b'confidences'].decode(
                                                'ascii'))
                                    behavior = 'None'
                                    tmp = []
                                    # The per-hash entry is looked up under
                                    # the bytes key first, then the str key.
                                    if mac_info[b'poseidon_hash'] in ml_info:
                                        tmp = ast.literal_eval(
                                            ml_info[mac_info[b'poseidon_hash']]
                                            .decode('ascii'))
                                    elif mac_info[b'poseidon_hash'].decode(
                                            'ascii') in ml_info:
                                        tmp = ast.literal_eval(ml_info[
                                            mac_info[b'poseidon_hash'].decode(
                                                'ascii')].decode('ascii'))
                                    if 'decisions' in tmp and 'behavior' in tmp[
                                            'decisions']:
                                        behavior = tmp['decisions']['behavior']
                                    mac_addresses[mac.decode('ascii')][str(
                                        timestamp)] = {
                                            'labels': labels,
                                            'confidences': confidences,
                                            'behavior': behavior
                                        }
                            except Exception as e:  # pragma: no cover
                                self.logger.error(
                                    'Unable to get existing ML data from Redis because: {0}'
                                    .format(str(e)))
                        try:
                            # The endpoint's own hash holds its serialized
                            # endpoint_data, which carries the IP fields.
                            poseidon_info = self.r.hgetall(
                                mac_info[b'poseidon_hash'])
                            if b'endpoint_data' in poseidon_info:
                                endpoint_data = ast.literal_eval(
                                    poseidon_info[b'endpoint_data'].decode(
                                        'ascii'))
                                for ip_field in MACHINE_IP_FIELDS:
                                    try:
                                        raw_field = endpoint_data.get(
                                            ip_field, None)
                                        machine_ip = ipaddress.ip_address(
                                            raw_field)
                                    except ValueError:
                                        # Missing or malformed address.
                                        machine_ip = ''
                                    if machine_ip:
                                        try:
                                            # Per-IP data is stored in the
                                            # hash keyed by the raw address.
                                            ip_info = self.r.hgetall(raw_field)
                                            short_os = ip_info.get(
                                                b'short_os', None)
                                            ip_addresses[ip_field][
                                                raw_field] = {}
                                            if short_os:
                                                ip_addresses[ip_field][
                                                    raw_field][
                                                        'os'] = short_os.decode(
                                                            'ascii')
                                        except Exception as e:  # pragma: no cover
                                            self.logger.error(
                                                'Unable to get existing {0} data from Redis because: {1}'
                                                .format(ip_field, str(e)))
                        except Exception as e:  # pragma: no cover
                            self.logger.error(
                                'Unable to get existing endpoint data from Redis because: {0}'
                                .format(str(e)))
                except Exception as e:  # pragma: no cover
                    self.logger.error(
                        'Unable to get existing metadata for {0} from Redis because: {1}'
                        .format(mac, str(e)))
        return mac_addresses, ip_addresses['ipv4'], ip_addresses['ipv6']

    def get_sdn_context(self):
        if 'TYPE' in self.controller and self.controller['TYPE'] == 'bcf':
            try:
                self.sdnc = BcfProxy(self.controller)
            except Exception as e:  # pragma: no cover
                self.logger.error(
                    'BcfProxy could not connect to {0} because {1}'.format(
                        self.controller['URI'], e))
        elif 'TYPE' in self.controller and self.controller['TYPE'] == 'faucet':
            try:
                self.sdnc = FaucetProxy(self.controller)
            except Exception as e:  # pragma: no cover
                self.logger.error(
                    'FaucetProxy could not connect to {0} because {1}'.format(
                        self.controller['URI'], e))
        elif 'TYPE' in self.controller and self.controller['TYPE'] == 'None':
            self.sdnc = None
        else:
            if 'CONTROLLER_PASS' in self.controller:
                self.controller['CONTROLLER_PASS'] = '******'
            self.logger.error('Unknown SDN controller config: {0}'.format(
                self.controller))

    def endpoint_by_name(self, name):
        return self.endpoints.get(name, None)

    def endpoint_by_hash(self, hash_id):
        return self.endpoint_by_name(hash_id)

    def endpoints_by_ip(self, ip):
        endpoints = [
            endpoint for endpoint in self.endpoints.values()
            if ip == endpoint.endpoint_data.get('ipv4', None)
            or ip == endpoint.endpoint_data.get('ipv6', None)
        ]
        return endpoints

    def endpoints_by_mac(self, mac):
        endpoints = [
            endpoint for endpoint in self.endpoints.values()
            if mac == endpoint.endpoint_data['mac']
        ]
        return endpoints

    @staticmethod
    def _connect_rabbit():
        """Open a blocking RabbitMQ connection and declare the exchange.

        Returns (channel, exchange name, connection). The caller is
        responsible for closing the connection.
        """
        exchange_name = 'topic-poseidon-internal'

        parameters = pika.ConnectionParameters(host='RABBIT_SERVER')
        connection = pika.BlockingConnection(parameters)

        channel = connection.channel()
        channel.exchange_declare(exchange=exchange_name,
                                 exchange_type='topic')

        return channel, exchange_name, connection

    @staticmethod
    def publish_action(action, message):
        try:
            channel, exchange, connection = SDNConnect._connect_rabbit()
            channel.basic_publish(exchange=exchange,
                                  routing_key=action,
                                  body=message)
            connection.close()
        except Exception as e:  # pragma: no cover
            pass
        return

    def show_endpoints(self, arg):
        endpoints = []
        if arg == 'all':
            endpoints = list(self.endpoints.values())
        else:
            show_type, arg = arg.split(' ', 1)
            for endpoint in self.endpoints.values():
                if show_type == 'state':
                    if arg == 'active' and endpoint.state != 'inactive':
                        endpoints.append(endpoint)
                    elif arg == 'ignored' and endpoint.ignore:
                        endpoints.append(endpoint)
                    elif endpoint.state == arg:
                        endpoints.append(endpoint)
                elif show_type in ['os', 'behavior', 'role']:
                    # filter by device type or behavior
                    if 'mac_addresses' in endpoint.metadata and endpoint.endpoint_data[
                            'mac'] in endpoint.metadata['mac_addresses']:
                        timestamps = endpoint.metadata['mac_addresses'][
                            endpoint.endpoint_data['mac']]
                        newest = '0'
                        for timestamp in timestamps:
                            if timestamp > newest:
                                newest = timestamp
                        if newest is not '0':
                            if 'labels' in timestamps[newest]:
                                if arg.replace(
                                        '-', ' '
                                ) == timestamps[newest]['labels'][0].lower():
                                    endpoints.append(endpoint)
                            if 'behavior' in timestamps[newest]:
                                if arg == timestamps[newest]['behavior'].lower(
                                ):
                                    endpoints.append(endpoint)

                    # filter by operating system
                    for ip_field in MACHINE_IP_FIELDS:
                        ip_addresses_field = '_'.join((ip_field, 'addresses'))
                        ip_addresses = endpoint.metadata.get(
                            ip_addresses_field, None)
                        machine_ip = endpoint.endpoint_data.get(ip_field, None)
                        if machine_ip and ip_addresses and machine_ip in ip_addresses:
                            metadata = ip_addresses[machine_ip]
                            os = metadata.get('os', None)
                            if os and os.lower() == arg:
                                endpoints.append(endpoint)
        return endpoints

    def check_endpoints(self, messages=None):
        if not self.sdnc:
            return

        retval = {}
        retval['machines'] = None
        retval['resp'] = 'bad'

        current = None
        parsed = None

        try:
            current = self.sdnc.get_endpoints(messages=messages)
            parsed = self.sdnc.format_endpoints(current,
                                                self.controller['URI'])
            retval['machines'] = parsed
            retval['resp'] = 'ok'
        except Exception as e:  # pragma: no cover
            self.logger.error(
                'Could not establish connection to {0} because {1}.'.format(
                    self.controller['URI'], e))
            retval[
                'controller'] = 'Could not establish connection to {0}.'.format(
                    self.controller['URI'])

        self.find_new_machines(parsed)

        return

    def connect_redis(self, host='redis', port=6379, db=0):
        self.r = None
        try:
            self.r = StrictRedis(host=host,
                                 port=port,
                                 db=db,
                                 socket_connect_timeout=2)
        except Exception as e:  # pragma: no cover
            self.logger.error('Failed connect to Redis because: {0}'.format(
                str(e)))
        return

    @staticmethod
    def _diff_machine(machine_a, machine_b):
        def _machine_strlines(machine):
            return str(json.dumps(machine, indent=2)).splitlines()

        machine_a_strlines = _machine_strlines(machine_a)
        machine_b_strlines = _machine_strlines(machine_b)
        return '\n'.join(
            difflib.unified_diff(machine_a_strlines, machine_b_strlines, n=1))

    @staticmethod
    def _parse_machine_ip(machine):
        """Derive the normalized IP fields for `machine`.

        For every IP field in MACHINE_IP_FIELDS the raw value is parsed.
        A valid address yields its string form plus '<field>_rdns' and
        '<field>_subnet' entries; an invalid or missing one yields an empty
        string for the field. Any related field still unset afterwards is
        filled with NO_DATA.
        """
        parsed = {}
        for ip_field, fields in MACHINE_IP_FIELDS.items():
            try:
                address = ipaddress.ip_address(machine.get(ip_field, None))
                subnet = ipaddress.ip_network(address).supernet(
                    new_prefix=MACHINE_IP_PREFIXES[ip_field])
            except ValueError:
                # Covers both an unparsable address and a rejected prefix.
                address = None
                subnet = None

            parsed[ip_field] = ''
            if address:
                address_text = str(address)
                parsed[ip_field] = address_text
                parsed['_'.join((ip_field, 'rdns'))] = get_rdns_lookup(
                    address_text)
                parsed['_'.join((ip_field, 'subnet'))] = str(subnet)

            for field in fields:
                parsed.setdefault(field, NO_DATA)
        return parsed

    @staticmethod
    def merge_machine_ip(old_machine, new_machine):
        """Carry known IP data over into `new_machine`, in place.

        For each IP field that is empty in `new_machine` but set in
        `old_machine`, the address and every related field present in
        `old_machine` are copied across.
        """
        for ip_field, fields in MACHINE_IP_FIELDS.items():
            new_ip = new_machine.get(ip_field, None)
            old_ip = old_machine.get(ip_field, None)
            if new_ip or not old_ip:
                continue
            new_machine[ip_field] = old_ip
            carried = [field for field in fields if field in old_machine]
            for field in carried:
                new_machine[field] = old_machine[field]

    def find_new_machines(self, machines):
        '''parse switch structure to find new machines added to network
        since last call'''
        change_acls = False

        for machine in machines:
            machine['ether_vendor'] = get_ether_vendor(
                machine['mac'],
                '/poseidon/poseidon/metadata/nmap-mac-prefixes.txt')
            machine.update(self._parse_machine_ip(machine))
            if not 'controller_type' in machine:
                machine.update({'controller_type': 'none', 'controller': ''})
            trunk = False
            for sw in self.trunk_ports:
                if sw == machine['segment'] and self.trunk_ports[sw].split(
                        ',')[1] == str(
                            machine['port']) and self.trunk_ports[sw].split(
                                ',')[0] == machine['mac']:
                    trunk = True

            h = Endpoint.make_hash(machine, trunk=trunk)
            ep = self.endpoints.get(h, None)
            if ep is None:
                change_acls = True
                m = endpoint_factory(h)
                m.p_prev_states.append((m.state, int(time.time())))
                m.endpoint_data = deepcopy(machine)
                self.endpoints[m.name] = m
                self.logger.info('Detected new endpoint: {0}:{1}'.format(
                    m.name, machine))
            else:
                self.merge_machine_ip(ep.endpoint_data, machine)

            if ep and ep.endpoint_data != machine and not ep.ignore:
                diff_txt = self._diff_machine(ep.endpoint_data, machine)
                self.logger.info('Endpoint changed: {0}:{1}'.format(
                    h, diff_txt))
                change_acls = True
                ep.endpoint_data = deepcopy(machine)
                if ep.state == 'inactive' and machine['active'] == 1:
                    if ep.p_next_state in ['known', 'abnormal']:
                        ep.trigger(ep.p_next_state)
                    else:
                        ep.unknown()
                    ep.p_prev_states.append((ep.state, int(time.time())))
                elif ep.state != 'inactive' and machine['active'] == 0:
                    if ep.state in ['mirroring', 'reinvestigating']:
                        status = Actions(ep, self.sdnc).unmirror_endpoint()
                        if not status:
                            self.logger.warning(
                                'Unable to unmirror the endpoint: {0}'.format(
                                    ep.name))
                        if ep.state == 'mirroring':
                            ep.p_next_state = 'mirror'
                        elif ep.state == 'reinvestigating':
                            ep.p_next_state = 'reinvestigate'
                    if ep.state in ['known', 'abnormal']:
                        ep.p_next_state = ep.state
                    ep.inactive()
                    ep.p_prev_states.append((ep.state, int(time.time())))

        if change_acls and self.controller['AUTOMATED_ACLS']:
            status = Actions(None, self.sdnc).update_acls(
                rules_file=self.controller['RULES_FILE'],
                endpoints=self.endpoints.values())
            if isinstance(status, list):
                self.logger.info(
                    'Automated ACLs did the following: {0}'.format(status[1]))
                for item in status[1]:
                    machine = {
                        'mac': item[1],
                        'segment': item[2],
                        'port': item[3]
                    }
                    h = Endpoint.make_hash(machine)
                    ep = self.endpoints.get(h, None)
                    if ep:
                        ep.acl_data.append((item[0], item[4], item[5]),
                                           int(time.time()))
        self.store_endpoints()
        self.get_stored_endpoints()

    def store_endpoints(self):
        ''' store current endpoints in Redis. '''
        with self.redis_lock:
            if self.r:
                try:
                    serialized_endpoints = []
                    for endpoint in self.endpoints.values():
                        # set metadata
                        mac_addresses, ipv4_addresses, ipv6_addresses = self.get_stored_metadata(
                            str(endpoint.name))

                        #list of fields to make history entries for, along with entry type for that field
                        fields = [
                            {
                                'field_name': 'behavior',
                                'entry_type': HistoryTypes.PROPERTY_CHANGE
                            },
                            {
                                'field_name': 'ipv4_OS',
                                'entry_type': HistoryTypes.PROPERTY_CHANGE
                            },
                            {
                                'field_name': 'ipv6_OS',
                                'entry_type': HistoryTypes.PROPERTY_CHANGE
                            },
                        ]

                        #make history entries for any changed prop
                        prior = None
                        for timestamp in mac_addresses:
                            for field in fields:
                                record = mac_addresses[timestamp]
                                if field['field_name'] in record and prior and field['field_name'] in prior and \
                                   prior[field['field_name']] != record[field['field_name']]:
                                    endpoint.update_property_history(
                                        field['entry_type'],
                                        field['field_name'],
                                        endpoint.endpoint_data.
                                        mac_addresses['field_name'],
                                        record[field['field_name']])
                                prior = record

                        prior = None
                        for timestamp in ipv4_addresses:
                            for field in fields:
                                record = ipv4_addresses[timestamp]
                                if field['field_name'] in record and prior and field['field_name'] in prior and \
                                   prior[field['field_name']] != record[field['field_name']]:
                                    endpoint.update_property_history(
                                        field['entry_type'],
                                        field['field_name'],
                                        endpoint.endpoint_data.
                                        ipv4_addresses['field_name'],
                                        record[field['field_name']])
                                prior = record

                        prior = None
                        for timestamp in ipv6_addresses:
                            for field in fields:
                                record = ipv6_addresses[timestamp]
                                if field['field_name'] in record and prior and field['field_name'] in prior and \
                                   prior[field['field_name']] != record[field['field_name']]:
                                    endpoint.update_property_history(
                                        field['entry_type'],
                                        field['field_name'],
                                        endpoint.endpoint_data.
                                        ipv6_addresses['field_name'],
                                        record[field['field_name']])
                                prior = record

                        endpoint.metadata = {
                            'mac_addresses': mac_addresses,
                            'ipv4_addresses': ipv4_addresses,
                            'ipv6_addresses': ipv6_addresses
                        }
                        redis_endpoint_data = {
                            'name': str(endpoint.name),
                            'state': str(endpoint.state),
                            'ignore': str(endpoint.ignore),
                            'endpoint_data': str(endpoint.endpoint_data),
                            'next_state': str(endpoint.p_next_state),
                            'prev_states': str(endpoint.p_prev_states),
                            'acl_data': str(endpoint.acl_data),
                            'metadata': str(endpoint.metadata),
                        }
                        self.r.hmset(endpoint.name, redis_endpoint_data)
                        mac = endpoint.endpoint_data['mac']
                        self.r.hmset(mac,
                                     {'poseidon_hash': str(endpoint.name)})
                        if not self.r.sismember('mac_addresses', mac):
                            self.r.sadd('mac_addresses', mac)
                        for ip_field in MACHINE_IP_FIELDS:
                            try:
                                machine_ip = ipaddress.ip_address(
                                    endpoint.endpoint_data.get(ip_field, None))
                            except ValueError:
                                machine_ip = None
                            if machine_ip:
                                self.r.hmset(
                                    str(machine_ip),
                                    {'poseidon_hash': str(endpoint.name)})
                                if not self.r.sismember(
                                        'ip_addresses', str(machine_ip)):
                                    self.r.sadd('ip_addresses',
                                                str(machine_ip))
                        serialized_endpoints.append(endpoint.encode())
                    self.r.set('p_endpoints', str(serialized_endpoints))
                except Exception as e:  # pragma: no cover
                    self.logger.error(
                        'Unable to store endpoints in Redis because {0}'.
                        format(str(e)))
예제 #4
0
class SDNConnect(object):
    def __init__(self):
        ''' load the controller config, build the SDN proxy, connect Redis. '''
        self.r = None
        self.first_time = True
        self.sdnc = None
        self.controller = Config().get_config()
        # trunk_ports is accepted either as a JSON string or as an
        # already-parsed object.
        configured_trunks = self.controller['trunk_ports']
        self.trunk_ports = (json.loads(configured_trunks) if isinstance(
            configured_trunks, str) else configured_trunks)
        self.logger = logger
        # get_sdn_context() reads self.controller / self.logger set above.
        self.get_sdn_context()
        self.endpoints = []
        self.investigations = 0
        self.connect_redis()

    def get_stored_endpoints(self):
        # load existing endpoints if any
        if self.r:
            try:
                p_endpoints = self.r.get('p_endpoints')
                if p_endpoints:
                    p_endpoints = ast.literal_eval(p_endpoints.decode('ascii'))
                    self.endpoints = []
                    for endpoint in p_endpoints:
                        self.endpoints.append(
                            EndpointDecoder(endpoint).get_endpoint())
            except Exception as e:  # pragma: no cover
                self.logger.error(
                    'Unable to get existing endpoints from Redis because {0}'.
                    format(str(e)))
        return

    def get_stored_metadata(self, hash_id):
        ''' collect the Redis-stored metadata linked to one endpoint hash.

        Walks every member of the 'mac_addresses' set and keeps the MACs
        whose hash has a 'poseidon_hash' field equal to hash_id.

        Returns a tuple (mac_addresses, ipv4_addresses, ipv6_addresses):
          * mac_addresses maps each matching MAC to a dict keyed by
            str(timestamp) holding 'labels', 'confidences' and 'behavior';
          * ipv4_addresses / ipv6_addresses map the endpoint's address to a
            dict that gets an 'os' entry when 'short_os' is stored for it.

        All three dicts are empty when there is no Redis client. Every Redis
        or parsing failure is logged and skipped rather than raised.
        '''
        mac_addresses = {}
        ipv4_addresses = {}
        ipv6_addresses = {}
        if self.r:
            macs = []
            try:
                macs = self.r.smembers('mac_addresses')
            except Exception as e:  # pragma: no cover
                self.logger.error(
                    'Unable to get existing mac addresses from Redis because: {0}'
                    .format(str(e)))
            for mac in macs:
                try:
                    mac_info = self.r.hgetall(mac)
                    # only MACs that point back at the requested endpoint
                    if b'poseidon_hash' in mac_info and mac_info[
                            b'poseidon_hash'] == hash_id.encode('utf-8'):
                        mac_addresses[mac.decode('ascii')] = {}
                        if b'timestamps' in mac_info:
                            try:
                                timestamps = ast.literal_eval(
                                    mac_info[b'timestamps'].decode('ascii'))
                                for timestamp in timestamps:
                                    # per-timestamp results live under the
                                    # key '<mac>_<timestamp>'
                                    ml_info = self.r.hgetall(
                                        mac.decode('ascii') + '_' +
                                        str(timestamp))
                                    labels = []
                                    if b'labels' in ml_info:
                                        labels = ast.literal_eval(
                                            ml_info[b'labels'].decode('ascii'))
                                    confidences = []
                                    if b'confidences' in ml_info:
                                        confidences = ast.literal_eval(
                                            ml_info[b'confidences'].decode(
                                                'ascii'))
                                    behavior = 'None'
                                    tmp = []
                                    # the endpoint's entry is looked up under
                                    # the bytes hash first, then the str hash
                                    if mac_info[b'poseidon_hash'] in ml_info:
                                        tmp = ast.literal_eval(
                                            ml_info[mac_info[b'poseidon_hash']]
                                            .decode('ascii'))
                                    elif mac_info[b'poseidon_hash'].decode(
                                            'ascii') in ml_info:
                                        tmp = ast.literal_eval(ml_info[
                                            mac_info[b'poseidon_hash'].decode(
                                                'ascii')].decode('ascii'))
                                    if 'decisions' in tmp and 'behavior' in tmp[
                                            'decisions']:
                                        behavior = tmp['decisions']['behavior']
                                    mac_addresses[mac.decode('ascii')][str(
                                        timestamp)] = {
                                            'labels': labels,
                                            'confidences': confidences,
                                            'behavior': behavior
                                        }
                            except Exception as e:  # pragma: no cover
                                self.logger.error(
                                    'Unable to get existing ML data from Redis because: {0}'
                                    .format(str(e)))
                        try:
                            # the endpoint's own hash carries 'endpoint_data',
                            # which names its ipv4/ipv6 addresses
                            poseidon_info = self.r.hgetall(
                                mac_info[b'poseidon_hash'])
                            if b'endpoint_data' in poseidon_info:
                                endpoint_data = ast.literal_eval(
                                    poseidon_info[b'endpoint_data'].decode(
                                        'ascii'))
                                if 'ipv4' in endpoint_data and endpoint_data[
                                        'ipv4'] not in ['None', 0]:
                                    try:
                                        ipv4_info = self.r.hgetall(
                                            endpoint_data['ipv4'])
                                        ipv4_addresses[
                                            endpoint_data['ipv4']] = {}
                                        if ipv4_info and b'short_os' in ipv4_info:
                                            ipv4_addresses[endpoint_data[
                                                'ipv4']]['os'] = ipv4_info[
                                                    b'short_os'].decode(
                                                        'ascii')
                                    except Exception as e:  # pragma: no cover
                                        self.logger.error(
                                            'Unable to get existing ipv4 data from Redis because: {0}'
                                            .format(str(e)))
                                if 'ipv6' in endpoint_data and endpoint_data[
                                        'ipv6'] not in ['None', 0]:
                                    try:
                                        ipv6_info = self.r.hgetall(
                                            endpoint_data['ipv6'])
                                        ipv6_addresses[
                                            endpoint_data['ipv6']] = {}
                                        if ipv6_info and b'short_os' in ipv6_info:
                                            ipv6_addresses[endpoint_data[
                                                'ipv6']]['os'] = ipv6_info[
                                                    b'short_os'].decode(
                                                        'ascii')
                                    except Exception as e:  # pragma: no cover
                                        self.logger.error(
                                            'Unable to get existing ipv6 data from Redis because: {0}'
                                            .format(str(e)))
                        except Exception as e:  # pragma: no cover
                            self.logger.error(
                                'Unable to get existing endpoint data from Redis because: {0}'
                                .format(str(e)))
                except Exception as e:  # pragma: no cover
                    self.logger.error(
                        'Unable to get existing metadata for {0} from Redis because: {1}'
                        .format(mac, str(e)))
        return mac_addresses, ipv4_addresses, ipv6_addresses

    def get_sdn_context(self):
        if 'TYPE' in self.controller and self.controller['TYPE'] == 'bcf':
            try:
                self.sdnc = BcfProxy(self.controller)
            except Exception as e:  # pragma: no cover
                self.logger.error(
                    'BcfProxy could not connect to {0} because {1}'.format(
                        self.controller['URI'], e))
        elif 'TYPE' in self.controller and self.controller['TYPE'] == 'faucet':
            try:
                self.sdnc = FaucetProxy(self.controller)
            except Exception as e:  # pragma: no cover
                self.logger.error(
                    'FaucetProxy could not connect to {0} because {1}'.format(
                        self.controller['URI'], e))
        elif 'TYPE' in self.controller and self.controller['TYPE'] == 'None':
            self.sdnc = None
        else:
            if 'CONTROLLER_PASS' in self.controller:
                self.controller['CONTROLLER_PASS'] = '******'
            self.logger.error('Unknown SDN controller config: {0}'.format(
                self.controller))

    def endpoint_by_name(self, name):
        for endpoint in self.endpoints:
            if endpoint.machine.name.strip() == name:
                return endpoint
        return None

    def endpoint_by_hash(self, hash_id):
        for endpoint in self.endpoints:
            if endpoint.name == hash_id:
                return endpoint
        return None

    def endpoints_by_ip(self, ip):
        endpoints = []
        for endpoint in self.endpoints:
            if (('ipv4' in endpoint.endpoint_data
                 and ip == endpoint.endpoint_data['ipv4'])
                    or ('ipv6' in endpoint.endpoint_data
                        and ip == endpoint.endpoint_data['ipv6'])):
                endpoints.append(endpoint)
        return endpoints

    def endpoints_by_mac(self, mac):
        endpoints = []
        for endpoint in self.endpoints:
            if mac == endpoint.endpoint_data['mac']:
                endpoints.append(endpoint)
        return endpoints

    @staticmethod
    def _connect_rabbit():
        ''' open a blocking RabbitMQ connection and declare the exchange.

        Returns (channel, exchange, connection); the caller closes the
        connection.
        '''
        exchange = 'topic-poseidon-internal'
        exchange_type = 'topic'

        parameters = pika.ConnectionParameters(host='RABBIT_SERVER')
        connection = pika.BlockingConnection(parameters)

        channel = connection.channel()
        channel.exchange_declare(exchange=exchange,
                                 exchange_type=exchange_type)

        return channel, exchange, connection

    @staticmethod
    def publish_action(action, message):
        try:
            channel, exchange, connection = SDNConnect._connect_rabbit()
            channel.basic_publish(exchange=exchange,
                                  routing_key=action,
                                  body=message)
            connection.close()
        except Exception as e:  # pragma: no cover
            pass
        return

    def show_endpoints(self, arg):
        endpoints = []
        if arg == 'all':
            for endpoint in self.endpoints:
                endpoints.append(endpoint)
        else:
            show_type, arg = arg.split(' ', 1)
            for endpoint in self.endpoints:
                if show_type == 'state':
                    if arg == 'active' and endpoint.state != 'inactive':
                        endpoints.append(endpoint)
                    elif arg == 'ignored' and endpoint.ignore:
                        endpoints.append(endpoint)
                    elif endpoint.state == arg:
                        endpoints.append(endpoint)
                elif show_type in ['os', 'behavior', 'role']:
                    # filter by device type or behavior
                    if 'mac_addresses' in endpoint.metadata and endpoint.endpoint_data[
                            'mac'] in endpoint.metadata['mac_addresses']:
                        timestamps = endpoint.metadata['mac_addresses'][
                            endpoint.endpoint_data['mac']]
                        newest = '0'
                        for timestamp in timestamps:
                            if timestamp > newest:
                                newest = timestamp
                        if newest is not '0':
                            if 'labels' in timestamps[newest]:
                                if arg.replace(
                                        '-', ' '
                                ) == timestamps[newest]['labels'][0].lower():
                                    endpoints.append(endpoint)
                            if 'behavior' in timestamps[newest]:
                                if arg == timestamps[newest]['behavior'].lower(
                                ):
                                    endpoints.append(endpoint)

                    # filter by operating system
                    if 'ipv4_addresses' in endpoint.metadata and endpoint.endpoint_data[
                            'ipv4'] in endpoint.metadata['ipv4_addresses']:
                        metadata = endpoint.metadata['ipv4_addresses'][
                            endpoint.endpoint_data['ipv4']]
                        if 'os' in metadata:
                            if arg == metadata['os'].lower():
                                endpoints.append(endpoint)
                    if 'ipv6_addresses' in endpoint.metadata and endpoint.endpoint_data[
                            'ipv6'] in endpoint.metadata['ipv6_addresses']:
                        metadata = endpoint.metadata['ipv6_addresses'][
                            endpoint.endpoint_data['ipv6']]
                        if 'os' in metadata:
                            if arg == metadata['os'].lower():
                                endpoints.append(endpoint)
        return endpoints

    def check_endpoints(self, messages=None):
        if not self.sdnc:
            return

        retval = {}
        retval['machines'] = None
        retval['resp'] = 'bad'

        current = None
        parsed = None

        try:
            current = self.sdnc.get_endpoints(messages=messages)
            parsed = self.sdnc.format_endpoints(current,
                                                self.controller['URI'])
            retval['machines'] = parsed
            retval['resp'] = 'ok'
        except Exception as e:  # pragma: no cover
            self.logger.error(
                'Could not establish connection to {0} because {1}.'.format(
                    self.controller['URI'], e))
            retval[
                'controller'] = 'Could not establish connection to {0}.'.format(
                    self.controller['URI'])

        self.find_new_machines(parsed)

        return

    def connect_redis(self, host='redis', port=6379, db=0):
        ''' create the Redis client and keep it on self.r.

        self.r stays None, and the error is logged, when the client cannot
        be created.
        '''
        self.r = None
        try:
            client = StrictRedis(host=host,
                                 port=port,
                                 db=db,
                                 socket_connect_timeout=2)
        except Exception as e:  # pragma: no cover
            self.logger.error('Failed connect to Redis because: {0}'.format(
                str(e)))
        else:
            self.r = client

    def find_new_machines(self, machines):
        '''parse switch structure to find new machines added to network
        since last call'''
        for machine in machines:
            machine['ether_vendor'] = get_ether_vendor(
                machine['mac'],
                '/poseidon/poseidon/metadata/nmap-mac-prefixes.txt')
            if 'ipv4' in machine and machine['ipv4'] and machine[
                    'ipv4'] is not 'None' and machine['ipv4'] is not '0':
                machine['ipv4_rdns'] = get_rdns_lookup(machine['ipv4'])
                machine['ipv4_subnet'] = '.'.join(
                    machine['ipv4'].split('.')[:-1]) + '.0/24'
            else:
                machine['ipv4_rdns'] = 'NO DATA'
                machine['ipv4_subnet'] = 'NO DATA'
            if 'ipv6' in machine and machine['ipv6'] and machine[
                    'ipv6'] is not 'None' and machine['ipv6'] is not '0':
                machine['ipv6_rdns'] = get_rdns_lookup(machine['ipv6'])
                machine['ipv6_subnet'] = '.'.join(
                    machine['ipv6'].split(':')[0:4]) + '::0/64'
            else:
                machine['ipv6_rdns'] = 'NO DATA'
                machine['ipv6_subnet'] = 'NO DATA'
            if not 'controller_type' in machine:
                machine['controller_type'] = 'none'
                machine['controller'] = ''
            trunk = False
            for sw in self.trunk_ports:
                if sw == machine['segment'] and self.trunk_ports[sw].split(
                        ',')[1] == str(
                            machine['port']) and self.trunk_ports[sw].split(
                                ',')[0] == machine['mac']:
                    trunk = True

            h = Endpoint.make_hash(machine, trunk=trunk)
            ep = None
            for endpoint in self.endpoints:
                if h == endpoint.name:
                    ep = endpoint
            if ep is not None and ep.endpoint_data != machine and not ep.ignore:
                self.logger.info('Endpoint changed: {0}:{1}'.format(
                    h, machine))
                ep.endpoint_data = deepcopy(machine)
                if ep.state == 'inactive' and machine['active'] == 1:
                    if ep.p_next_state in ['known', 'abnormal']:
                        ep.trigger(ep.p_next_state)
                    else:
                        ep.unknown()
                    ep.p_prev_states.append((ep.state, int(time.time())))
                elif ep.state != 'inactive' and machine['active'] == 0:
                    if ep.state in ['mirroring', 'reinvestigating']:
                        status = Actions(ep, self.sdnc).unmirror_endpoint()
                        if not status:
                            self.logger.warning(
                                'Unable to unmirror the endpoint: {0}'.format(
                                    ep.name))
                        self.investigations -= 1
                        if ep.state == 'mirroring':
                            ep.p_next_state = 'mirror'
                        elif ep.state == 'reinvestigating':
                            ep.p_next_state = 'reinvestigate'
                    if ep.state in ['known', 'abnormal']:
                        ep.p_next_state = ep.state
                    ep.inactive()
                    ep.p_prev_states.append((ep.state, int(time.time())))
            elif ep is None:
                self.logger.info('Detected new endpoint: {0}:{1}'.format(
                    h, machine))
                m = Endpoint(h)
                m.p_prev_states.append((m.state, int(time.time())))
                m.endpoint_data = deepcopy(machine)
                self.endpoints.append(m)

        self.store_endpoints()
        return

    def store_endpoints(self):
        # store latest version of endpoints in redis
        if self.r:
            try:
                serialized_endpoints = []
                for endpoint in self.endpoints:
                    # set metadata
                    mac_addresses, ipv4_addresses, ipv6_addresses = self.get_stored_metadata(
                        str(endpoint.name))
                    endpoint.metadata = {
                        'mac_addresses': mac_addresses,
                        'ipv4_addresses': ipv4_addresses,
                        'ipv6_addresses': ipv6_addresses
                    }
                    redis_endpoint_data = {}
                    redis_endpoint_data['name'] = str(endpoint.name)
                    redis_endpoint_data['state'] = str(endpoint.state)
                    redis_endpoint_data['ignore'] = str(endpoint.ignore)
                    redis_endpoint_data['endpoint_data'] = str(
                        endpoint.endpoint_data)
                    redis_endpoint_data['next_state'] = str(
                        endpoint.p_next_state)
                    redis_endpoint_data['prev_states'] = str(
                        endpoint.p_prev_states)
                    redis_endpoint_data['metadata'] = str(endpoint.metadata)
                    self.r.hmset(endpoint.name, redis_endpoint_data)
                    mac = endpoint.endpoint_data['mac']
                    self.r.hmset(mac, {'poseidon_hash': str(endpoint.name)})
                    self.r.sadd('mac_addresses', mac)
                    if 'ipv4' in endpoint.endpoint_data and endpoint.endpoint_data[
                            'ipv4'] != 'None' and endpoint.endpoint_data[
                                'ipv4']:
                        self.r.hmset(endpoint.endpoint_data['ipv4'],
                                     {'poseidon_hash': str(endpoint.name)})
                        self.r.sadd('ip_addresses',
                                    endpoint.endpoint_data['ipv4'])
                    if 'ipv6' in endpoint.endpoint_data and endpoint.endpoint_data[
                            'ipv6'] != 'None' and endpoint.endpoint_data[
                                'ipv6']:
                        self.r.hmset(endpoint.endpoint_data['ipv6'],
                                     {'poseidon_hash': str(endpoint.name)})
                        self.r.sadd('ip_addresses',
                                    endpoint.endpoint_data['ipv6'])
                    serialized_endpoints.append(endpoint.encode())
                self.r.set('p_endpoints', str(serialized_endpoints))
            except Exception as e:  # pragma: no cover
                self.logger.error(
                    'Unable to store endpoints in Redis because {0}'.format(
                        str(e)))
        return
예제 #5
0
파일: main.py 프로젝트: x0rzkov/poseidon
class SDNConnect:
    def __init__(self, controller, first_time=True):
        ''' set up the SDN proxy and the Redis client for controller.

        On a first-time start the counters are reset, existing mirror
        filters are cleared and stored endpoints are put in their default
        state.
        '''
        self.controller = controller
        self.r = None
        self.first_time = first_time
        self.sdnc = None
        self.endpoints = {}
        # trunk_ports is accepted either as a JSON string or as an
        # already-parsed object.
        configured_trunks = self.controller['trunk_ports']
        self.trunk_ports = (json.loads(configured_trunks) if isinstance(
            configured_trunks, str) else configured_trunks)
        self.logger = logger
        self.get_sdn_context()
        self.prc = PoseidonRedisClient(self.logger)
        self.prc.connect()
        if not self.first_time:
            return
        self.endpoints = {}
        self.investigations = 0
        self.coprocessing = 0
        self.clear_filters()
        self.default_endpoints()

    def mirror_endpoint(self, endpoint):
        ''' mirror an endpoint; count it on success, warn on failure. '''
        mirrored = Actions(endpoint, self.sdnc).mirror_endpoint()
        if not mirrored:
            self.logger.warning('Unable to mirror the endpoint: {0}'.format(
                endpoint.name))
            return
        self.prc.inc_network_tools_counts()

    def unmirror_endpoint(self, endpoint):
        ''' unmirror an endpoint; log a warning when that fails. '''
        if Actions(endpoint, self.sdnc).unmirror_endpoint():
            return
        self.logger.warning('Unable to unmirror the endpoint: {0}'.format(
            endpoint.name))

    def clear_filters(self):
        ''' clear any existing mirrors when the controller is Faucet. '''
        if not isinstance(self.sdnc, FaucetProxy):
            return
        self.sdnc.clear_mirrors()

    def default_endpoints(self):
        ''' set endpoints to default state. '''
        self.get_stored_endpoints()
        for endpoint in self.endpoints.values():
            if not endpoint.ignore:
                if endpoint.state != 'inactive':
                    if endpoint.state == 'mirroring':
                        endpoint.p_next_state = 'mirror'
                    elif endpoint.state == 'reinvestigating':
                        endpoint.p_next_state = 'reinvestigate'
                    elif endpoint.state == 'queued':
                        endpoint.p_next_state = 'queue'
                    elif endpoint.state in ['known', 'abnormal']:
                        endpoint.p_next_state = endpoint.state
                    endpoint.endpoint_data['active'] = 0
                    endpoint.inactive()  # pytype: disable=attribute-error
                    endpoint.p_prev_states.append(
                        (endpoint.state, int(time.time())))
        self.store_endpoints()

    def get_stored_endpoints(self):
        ''' load existing endpoints from Redis. '''
        new_endpoints = self.prc.get_stored_endpoints()
        if new_endpoints:
            self.endpoints = new_endpoints

    def get_sdn_context(self):
        controller_type = self.controller.get('TYPE', None)
        if controller_type == 'faucet':
            self.sdnc = FaucetProxy(self.controller)
        elif controller_type == 'None':
            self.sdnc = None
        else:
            self.logger.error('Unknown SDN controller config: {0}'.format(
                self.controller))

    def endpoint_by_name(self, name):
        return self.endpoints.get(name, None)

    def endpoint_by_hash(self, hash_id):
        return self.endpoint_by_name(hash_id)

    def endpoints_by_ip(self, ip):
        endpoints = [
            endpoint for endpoint in self.endpoints.values()
            if ip == endpoint.endpoint_data.get('ipv4', None)
            or ip == endpoint.endpoint_data.get('ipv6', None)
        ]
        return endpoints

    def endpoints_by_mac(self, mac):
        endpoints = [
            endpoint for endpoint in self.endpoints.values()
            if mac == endpoint.endpoint_data['mac']
        ]
        return endpoints

    @staticmethod
    def _connect_rabbit():
        ''' open a blocking RabbitMQ connection and declare the exchange.

        Returns (channel, exchange, connection); the caller closes the
        connection.
        '''
        exchange = 'topic-poseidon-internal'
        exchange_type = 'topic'

        parameters = pika.ConnectionParameters(host='RABBIT_SERVER')
        connection = pika.BlockingConnection(parameters)

        channel = connection.channel()
        channel.exchange_declare(exchange=exchange,
                                 exchange_type=exchange_type)

        return channel, exchange, connection

    @staticmethod
    def publish_action(action, message):
        try:
            channel, exchange, connection = SDNConnect._connect_rabbit()
            channel.basic_publish(exchange=exchange,
                                  routing_key=action,
                                  body=message)
            connection.close()
        except Exception as e:  # pragma: no cover
            print(str(e))

    def show_endpoints(self, arg):
        endpoints = []
        if arg == 'all':
            endpoints = list(self.endpoints.values())
        else:
            show_type, arg = arg.split(' ', 1)
            for endpoint in self.endpoints.values():
                if show_type == 'state':
                    if arg == 'active' and endpoint.state != 'inactive':
                        endpoints.append(endpoint)
                    elif arg == 'ignored' and endpoint.ignore:
                        endpoints.append(endpoint)
                    elif endpoint.state == arg:
                        endpoints.append(endpoint)
                elif show_type in ['os', 'behavior', 'role']:
                    mac_addresses = endpoint.metadata.get(
                        'mac_addresses', None)
                    endpoint_mac = endpoint.endpoint_data['mac']
                    if endpoint_mac and mac_addresses and endpoint_mac in mac_addresses:
                        timestamps = mac_addresses[endpoint_mac]
                        try:
                            newest = sorted(
                                [timestamp for timestamp in timestamps])[-1]
                            newest = timestamps[newest]
                        except IndexError:
                            newest = None
                        if newest:
                            if 'labels' in newest:
                                if arg.replace(
                                        '-',
                                        ' ') == newest['labels'][0].lower():
                                    endpoints.append(endpoint)
                            if 'behavior' in newest:
                                if arg == newest['behavior'].lower():
                                    endpoints.append(endpoint)

                    # filter by operating system
                    for ip_field in MACHINE_IP_FIELDS:
                        ip_addresses_field = '_'.join((ip_field, 'addresses'))
                        ip_addresses = endpoint.metadata.get(
                            ip_addresses_field, None)
                        machine_ip = endpoint.endpoint_data.get(ip_field, None)
                        if machine_ip and ip_addresses and machine_ip in ip_addresses:
                            metadata = ip_addresses[machine_ip]
                            os = metadata.get('os', None)
                            if os and os.lower() == arg:
                                endpoints.append(endpoint)
        return endpoints

    def check_endpoints(self, messages=None):
        if not self.sdnc:
            return

        retval = {}
        retval['machines'] = None
        retval['resp'] = 'bad'

        current = None
        parsed = None

        try:
            current = self.sdnc.get_endpoints(messages=messages)
            parsed = self.sdnc.format_endpoints(current)
            retval['machines'] = parsed
            retval['resp'] = 'ok'
        except Exception as e:  # pragma: no cover
            self.logger.error(
                'Could not establish connection to controller because {0}.'.
                format(e))
            retval[
                'controller'] = 'Could not establish connection to controller'

        self.find_new_machines(parsed)

    @staticmethod
    def _diff_machine(machine_a, machine_b):
        def _machine_strlines(machine):
            return str(json.dumps(machine, indent=2)).splitlines()

        machine_a_strlines = _machine_strlines(machine_a)
        machine_b_strlines = _machine_strlines(machine_b)
        return '\n'.join(
            difflib.unified_diff(machine_a_strlines, machine_b_strlines, n=1))

    @staticmethod
    def _parse_machine_ip(machine):
        ''' build the normalized IP fields for machine.

        For every field in MACHINE_IP_FIELDS the raw value is parsed as an
        IP address. A valid address yields its string form plus
        '<field>_rdns' and '<field>_subnet'. An invalid or missing one
        yields an empty string. Any related field not produced is filled
        with NO_DATA.
        '''
        parsed = {}
        for ip_field, fields in MACHINE_IP_FIELDS.items():
            try:
                address = ipaddress.ip_address(machine.get(ip_field, None))
                subnet = ipaddress.ip_network(address).supernet(
                    new_prefix=MACHINE_IP_PREFIXES[ip_field])
            except ValueError:
                address = None
                subnet = None
            if address:
                address_text = str(address)
                parsed[ip_field] = address_text
                parsed['_'.join((ip_field, 'rdns'))] = get_rdns_lookup(
                    address_text)
                parsed['_'.join((ip_field, 'subnet'))] = str(subnet)
            else:
                parsed[ip_field] = ''
            for field in fields:
                parsed.setdefault(field, NO_DATA)
        return parsed

    @staticmethod
    def merge_machine_ip(old_machine, new_machine):
        ''' carry known IP data over into new_machine, in place.

        When new_machine has no value for an IP field but old_machine does,
        the old address and every related field present in old_machine are
        copied across.
        '''
        for ip_field, fields in MACHINE_IP_FIELDS.items():
            new_ip = new_machine.get(ip_field, None)
            previous_ip = old_machine.get(ip_field, None)
            if new_ip or not previous_ip:
                continue
            new_machine[ip_field] = previous_ip
            for field in fields:
                if field in old_machine:
                    new_machine[field] = old_machine[field]

    def find_new_machines(self, machines):
        '''parse switch structure to find new machines added to network
        since last call'''
        change_acls = False

        for machine in machines:
            machine['ether_vendor'] = get_ether_vendor(
                machine['mac'],
                '/poseidon/poseidon/metadata/nmap-mac-prefixes.txt')
            machine.update(self._parse_machine_ip(machine))
            if 'controller_type' not in machine:
                machine.update({'controller_type': 'none', 'controller': ''})
            trunk = False
            for sw in self.trunk_ports:
                if sw == machine['segment'] and self.trunk_ports[sw].split(
                        ',')[1] == str(
                            machine['port']) and self.trunk_ports[sw].split(
                                ',')[0] == machine['mac']:
                    trunk = True

            h = Endpoint.make_hash(machine, trunk=trunk)
            ep = self.endpoints.get(h, None)
            if ep is None:
                change_acls = True
                m = endpoint_factory(h)
                m.p_prev_states.append((m.state, int(time.time())))
                m.endpoint_data = deepcopy(machine)
                self.endpoints[m.name] = m
                self.logger.info('Detected new endpoint: {0}:{1}'.format(
                    m.name, machine))
            else:
                self.merge_machine_ip(ep.endpoint_data, machine)

            if ep and ep.endpoint_data != machine and not ep.ignore:
                diff_txt = self._diff_machine(ep.endpoint_data, machine)
                self.logger.info('Endpoint changed: {0}:{1}'.format(
                    h, diff_txt))
                change_acls = True
                ep.endpoint_data = deepcopy(machine)
                if ep.state == 'inactive' and machine['active'] == 1:
                    if ep.p_next_state in ['known', 'abnormal']:
                        # pytype: disable=attribute-error
                        ep.trigger(ep.p_next_state)
                    else:
                        ep.unknown()  # pytype: disable=attribute-error
                    ep.p_prev_states.append((ep.state, int(time.time())))
                elif ep.state != 'inactive' and machine['active'] == 0:
                    if ep.state in ['mirroring', 'reinvestigating']:
                        self.unmirror_endpoint(ep)
                        if ep.state == 'mirroring':
                            ep.p_next_state = 'mirror'
                        elif ep.state == 'reinvestigating':
                            ep.p_next_state = 'reinvestigate'
                    if ep.state in ['known', 'abnormal']:
                        ep.p_next_state = ep.state
                    ep.inactive()  # pytype: disable=attribute-error
                    ep.p_prev_states.append((ep.state, int(time.time())))

        if change_acls and self.controller['AUTOMATED_ACLS']:
            status = Actions(None, self.sdnc).update_acls(
                rules_file=self.controller['RULES_FILE'],
                endpoints=self.endpoints.values())
            if isinstance(status, list):
                self.logger.info(
                    'Automated ACLs did the following: {0}'.format(status[1]))
                for item in status[1]:
                    machine = {
                        'mac': item[1],
                        'segment': item[2],
                        'port': item[3]
                    }
                    h = Endpoint.make_hash(machine)
                    ep = self.endpoints.get(h, None)
                    if ep:
                        ep.acl_data.append(
                            ((item[0], item[4], item[5]), int(time.time())))
        self.refresh_endpoints()

    def store_endpoints(self):
        ''' store current endpoints in Redis. '''
        self.prc.store_endpoints(self.endpoints)

    def refresh_endpoints(self):
        ''' write the current endpoints to Redis, then read them back. '''
        self.logger.debug('refresh endpoints')
        # store first so the reload below sees this instance's changes
        self.store_endpoints()
        self.get_stored_endpoints()
예제 #6
0
class SDNConnect(object):
    def __init__(self):
        ''' load the controller config, pick an SDN proxy and create the
        Redis client. '''
        # plain bookkeeping state that needs no external service
        self.first_time = True
        self.investigations = 0
        self.endpoints = []
        self.r = None
        self.sdnc = None
        # get_sdn_context() reads both the controller config and the logger,
        # so both have to be in place before it runs
        self.controller = Config().get_config()
        self.logger = logger
        self.get_sdn_context()
        self.connect_redis()

    def get_stored_endpoints(self):
        # load existing endpoints if any
        if self.r:
            try:
                p_endpoints = self.r.get('p_endpoints')
                if p_endpoints:
                    p_endpoints = ast.literal_eval(p_endpoints.decode('ascii'))
                    self.endpoints = []
                    for endpoint in p_endpoints:
                        self.endpoints.append(
                            EndpointDecoder(endpoint).get_endpoint())
            except Exception as e:  # pragma: no cover
                self.logger.error(
                    'Unable to get existing endpoints from Redis because {0}'.
                    format(str(e)))
        return

    def get_sdn_context(self):
        if 'TYPE' in self.controller and self.controller['TYPE'] == 'bcf':
            try:
                self.sdnc = BcfProxy(self.controller)
            except BaseException as e:  # pragma: no cover
                self.logger.error(
                    'BcfProxy could not connect to {0} because {1}'.format(
                        self.controller['URI'], e))
        elif 'TYPE' in self.controller and self.controller['TYPE'] == 'faucet':
            try:
                self.sdnc = FaucetProxy(self.controller)
            except BaseException as e:  # pragma: no cover
                self.logger.error(
                    'FaucetProxy could not connect to {0} because {1}'.format(
                        self.controller['URI'], e))
        else:
            self.logger.error('Unknown SDN controller config: {0}'.format(
                self.controller))

    def endpoint_by_name(self, name):
        self.get_stored_endpoints()
        for endpoint in self.endpoints:
            if endpoint.machine.name.strip() == name:
                return endpoint
        return None

    def endpoint_by_hash(self, hash_id):
        self.get_stored_endpoints()
        for endpoint in self.endpoints:
            if endpoint.name == hash_id:
                return endpoint
        return None

    def endpoints_by_ip(self, ip):
        self.get_stored_endpoints()
        endpoints = []
        for endpoint in self.endpoints:
            if ip in [
                    endpoint.endpoint_data['ipv4'],
                    endpoint.endpoint_data['ipv6']
            ]:
                endpoints.append(endpoint)
        return endpoints

    def endpoints_by_mac(self, mac):
        self.get_stored_endpoints()
        endpoints = []
        for endpoint in self.endpoints:
            if mac == endpoint.endpoint_data['mac']:
                endpoints.append(endpoint)
        return endpoints

    def collect_on(self, endpoint):
        self.get_stored_endpoints()
        # TODO
        return

    def remove_inactive_endpoints(self):
        self.get_stored_endpoints()
        remove_list = []
        for endpoint in self.endpoints:
            if endpoint.state == 'inactive':
                remove_list.append(endpoint)
        for endpoint in remove_list:
            self.endpoints.remove(endpoint)
        self.store_endpoints()
        return remove_list

    def ignore_inactive_endpoints(self):
        self.get_stored_endpoints()
        for ep in self.endpoints:
            if ep.state == 'ignore':
                ep.ignore = True
        self.store_endpoints()
        return

    def ignore_endpoint(self, endpoint):
        self.get_stored_endpoints()
        for ep in self.endpoints:
            if ep.name == endpoint.name:
                ep.ignore = True
        self.store_endpoints()
        return

    def clear_ignored_endpoint(self, endpoint):
        self.get_stored_endpoints()
        for ep in self.endpoints:
            if ep.name == endpoint.name:
                ep.ignore = False
        self.store_endpoints()
        return

    def remove_endpoint(self, endpoint):
        self.get_stored_endpoints()
        if endpoint in self.endpoints:
            self.endpoints.remove(endpoint)
        self.store_endpoints()
        return

    def remove_ignored_endpoints(self):
        self.get_stored_endpoints()
        remove_list = []
        for endpoint in self.endpoints:
            if endpoint.ignore:
                remove_list.append(endpoint)
        for endpoint in remove_list:
            self.endpoints.remove(endpoint)
        self.store_endpoints()
        return remove_list

    def show_endpoints(self, state, type_filter, all_devices):
        self.get_stored_endpoints()
        endpoints = []
        for endpoint in self.endpoints:
            if all_devices:
                endpoints.append(endpoint)
            elif state:
                if endpoint.state == state:
                    endpoints.append(endpoint)
            elif type_filter:
                if type_filter == 'ignored':
                    if endpoint.ignore:
                        endpoints.append(endpoint)
                # TODO
        return endpoints

    def check_endpoints(self, messages=None):
        retval = {}
        retval['machines'] = None
        retval['resp'] = 'bad'

        current = None
        parsed = None

        try:
            current = self.sdnc.get_endpoints(messages=messages)
            parsed = self.sdnc.format_endpoints(current)
            retval['machines'] = parsed
            retval['resp'] = 'ok'
        except BaseException as e:  # pragma: no cover
            self.logger.error(
                'Could not establish connection to {0} because {1}.'.format(
                    self.controller['URI'], e))
            retval[
                'controller'] = 'Could not establish connection to {0}.'.format(
                    self.controller['URI'])

        self.find_new_machines(parsed)

        return

    def connect_redis(self, host='redis', port=6379, db=0):
        ''' (re)create the Redis client on self.r.

        self.r is reset to None first and stays None when the client cannot
        be constructed; the error is logged rather than raised.
        '''
        self.r = None
        options = {
            'host': host,
            'port': port,
            'db': db,
            'socket_connect_timeout': 2,
        }
        try:
            self.r = StrictRedis(**options)
        except Exception as err:  # pragma: no cover
            self.logger.error('Failed connect to Redis because: {0}'.format(
                str(err)))

    def find_new_machines(self, machines):
        '''parse switch structure to find new machines added to network
        since last call

        The stored endpoints are reloaded first.  Each entry of machines is
        hashed with Endpoint.make_hash() and matched against the stored
        endpoints by name:

        - no match: a new Endpoint is created from the entry and appended;
        - match that is not ignored and whose endpoint_data differs from the
          entry: the data is replaced and, depending on machine['active'],
          the endpoint is moved out of or into the 'inactive' state;
        - anything else (unchanged or ignored endpoint) is left alone.

        The endpoints are stored again once all entries are processed.
        Returns None.
        '''
        self.get_stored_endpoints()
        for machine in machines:
            h = Endpoint.make_hash(machine)
            ep = None
            # no break: when several endpoints share the hash, the last wins
            for endpoint in self.endpoints:
                if h == endpoint.name:
                    ep = endpoint
            if ep is not None and ep.endpoint_data != machine and not ep.ignore:
                self.logger.info('Endpoint changed: {0}:{1}'.format(
                    h, machine))
                # copy, so later changes to machine do not leak into the endpoint
                ep.endpoint_data = deepcopy(machine)
                if ep.state == 'inactive' and machine['active'] == 1:
                    # endpoint came back: resume the state remembered in
                    # p_next_state when it was 'known'/'abnormal', otherwise
                    # fall back to unknown()
                    if ep.p_next_state in ['known', 'abnormal']:
                        ep.trigger(ep.p_next_state)
                    else:
                        ep.unknown()
                    # record the state reached together with a unix timestamp
                    ep.p_prev_states.append((ep.state, int(time.time())))
                elif ep.state != 'inactive' and machine['active'] == 0:
                    # endpoint went away while in some non-inactive state
                    if ep.state in ['mirroring', 'reinvestigating']:
                        status = Actions(ep, self.sdnc).unmirror_endpoint()
                        if not status:
                            self.logger.warning(
                                'Unable to unmirror the endpoint: {0}'.format(
                                    ep.name))
                        # decremented whether or not the unmirror succeeded
                        self.investigations -= 1
                        if ep.state == 'mirroring':
                            ep.p_next_state = 'mirror'
                        elif ep.state == 'reinvestigating':
                            ep.p_next_state = 'reinvestigate'
                    # p_next_state must be set before inactive() changes ep.state
                    if ep.state in ['known', 'abnormal']:
                        ep.p_next_state = ep.state
                    ep.inactive()
                    ep.p_prev_states.append((ep.state, int(time.time())))
            elif ep is None:
                self.logger.info('Detected new endpoint: {0}:{1}'.format(
                    h, machine))
                m = Endpoint(h)
                # record the new endpoint's starting state with a timestamp
                m.p_prev_states.append((m.state, int(time.time())))
                m.endpoint_data = deepcopy(machine)
                self.endpoints.append(m)

        self.store_endpoints()
        return

    def store_endpoints(self):
        # store latest version of endpoints in redis
        if self.r:
            try:
                serialized_endpoints = []
                for endpoint in self.endpoints:
                    redis_endpoint_data = {}
                    redis_endpoint_data['name'] = str(endpoint.name)
                    redis_endpoint_data['state'] = str(endpoint.state)
                    redis_endpoint_data['ignore'] = str(endpoint.ignore)
                    redis_endpoint_data['endpoint_data'] = str(
                        endpoint.endpoint_data)
                    redis_endpoint_data['next_state'] = str(
                        endpoint.p_next_state)
                    redis_endpoint_data['prev_states'] = str(
                        endpoint.p_next_state)
                    self.r.hmset(endpoint.name, redis_endpoint_data)
                    mac = endpoint.endpoint_data['mac']
                    self.r.hmset(mac, {'poseidon_hash': str(endpoint.name)})
                    self.r.sadd('mac_addresses', mac)
                    if 'ipv4' in endpoint.endpoint_data and endpoint.endpoint_data[
                            'ipv4'] != 'None' and endpoint.endpoint_data[
                                'ipv4']:
                        self.r.hmset(endpoint.endpoint_data['ipv4'],
                                     {'poseidon_hash': str(endpoint.name)})
                        self.r.sadd('ip_addresses',
                                    endpoint.endpoint_data['ipv4'])
                    if 'ipv6' in endpoint.endpoint_data and endpoint.endpoint_data[
                            'ipv6'] != 'None' and endpoint.endpoint_data[
                                'ipv6']:
                        self.r.hmset(endpoint.endpoint_data['ipv6'],
                                     {'poseidon_hash': str(endpoint.name)})
                        self.r.sadd('ip_addresses',
                                    endpoint.endpoint_data['ipv6'])
                    serialized_endpoints.append(endpoint.encode())
                self.r.set('p_endpoints', str(serialized_endpoints))
            except Exception as e:  # pragma: no cover
                self.logger.error(
                    'Unable to store endpoints in Redis because {0}'.format(
                        str(e)))
        return