class SimpleHttpServer:
    """Configures and runs the RECCE7 report HTTP server.

    Configuration file paths come from the ``RECCE7_PLUGIN_CONFIG`` and
    ``RECCE7_GLOBAL_CONFIG`` environment variables, falling back to
    ``config/plugins.cfg`` and ``config/global.cfg``.
    """

    def __init__(self):
        plugin_cfg_path = os.getenv('RECCE7_PLUGIN_CONFIG') or 'config/plugins.cfg'
        global_cfg_path = os.getenv('RECCE7_GLOBAL_CONFIG') or 'config/global.cfg'
        self.g_config = GlobalConfig(plugin_cfg_path, global_cfg_path)
        self.g_config.read_plugin_config()
        self.g_config.read_global_config()
        self.host = self.g_config.get_report_server_host()
        self.port = self.g_config.get_report_server_port()
        log_path = self.g_config['ReportServer']['reportserver.logName']
        log_level = self.g_config['ReportServer']['reportserver.logLevel']
        self.log = Logger(log_path, log_level).get(
            'reportserver.server.SimpleHTTPServer.SimpleHTTPServer')

    def setupAndStart(self):
        """Bind to (host, port) and serve requests until Ctrl+C is pressed.

        The listening socket is closed in a ``finally`` clause, so it is
        released even if ``serve_forever`` exits with an unexpected exception.
        """
        server_address = (self.host, self.port)
        request_handler = RestRequestHandler
        # instantiate a server object
        httpd = HTTPServer(server_address, request_handler)
        print(time.asctime(), "Server Starting - %s:%s" % (self.host, self.port))
        try:
            # start serving pages
            httpd.serve_forever()
        except KeyboardInterrupt:
            pass  # Ctrl+C is the normal way to stop the server.
        finally:
            httpd.server_close()
        print(time.asctime(), "Server Stopped - %s:%s" % (self.host, self.port))
def test_get_json_by_time(self):
    """Rows returned by a ``>= date`` query must all have eventDateTime >= date.

    For each plugin port, queries its table in ``TestDB.sqlite`` once per week
    over 500 weeks starting 1999-12-31T23:59:59 and checks every returned row.
    """
    plugin_cfg_path = os.getenv("RECCE7_PLUGIN_CONFIG") or "config/plugins.cfg"
    global_cfg_path = os.getenv("RECCE7_GLOBAL_CONFIG") or "config/global.cfg"
    global_config = GlobalConfig(plugin_cfg_path, global_cfg_path, True)
    global_config.read_global_config()
    global_config.read_plugin_config()
    test_start_date = datetime.datetime(1999, month=12, day=31, hour=23, minute=59, second=59)
    # NOTE(review): the previous version looped over range(0, 2) with an
    # unreachable branch for port 8023 (test_telnet), so only 8082 and 8083 were
    # ever queried. That set is kept here; confirm whether 8023 should be added.
    for portnumber in (8082, 8083):
        # The table name depends only on the port, so look it up once per port.
        table_name = global_config.get_plugin_config(portnumber)["table"]
        for week in range(500):
            query_date_iso = (test_start_date + datetime.timedelta(weeks=week)).isoformat()
            query_string = "SELECT * FROM %s where (eventDateTime >= '%s')" % (table_name, query_date_iso)
            json_query = DatabaseHandler().query_db(query_string, db="TestDB.sqlite")
            # Check every returned row, including the last one.
            for row in json_query:
                self.assertGreaterEqual(row.get("eventDateTime"), query_date_iso)
class Table_Insert_test(unittest.TestCase):
    """Tests for ``Table_Insert.prepare_data_for_insertion`` on the test_telnet table."""

    def setUp(self):
        self.test_db_dir = '/tests/database/test_database'
        self.test_db_file = '/tests/database/test_database/honeyDB.sqlite'
        # test configuration files
        self.plugins_config_file = 'tests/database/test_config/plugins.cfg'
        self.plugins_config_diff_file = 'tests/database/test_config/plugins_diff.cfg'
        self.plugins_config_diff_table_file = 'tests/database/test_config/plugins_diff_table.cfg'
        self.global_config_file = 'tests/database/test_config/global.cfg'
        # create global config instance
        self.gci = GlobalConfig(self.plugins_config_file, self.global_config_file)
        self.gci.read_global_config()
        self.gci.read_plugin_config()
        # one complete telnet row, and the same row without peerAddress/localAddress
        self.telnet_data = {'test_telnet': {'session': 'abcdefghijklmnop',
                                            'eventDateTime': '01-02-2016 11:22:33.123',
                                            'peerAddress': '24.33.21.123',
                                            'localAddress': '192.168.0.55',
                                            'input_type': 'a string',
                                            'user_input': 'another string'}}
        self.telnet_data_noip = {'test_telnet': {'session': 'abcdefghijklmnop',
                                                 'eventDateTime': '01-02-2016 11:22:33.123',
                                                 'input_type': 'a string',
                                                 'user_input': 'another string'}}

    def _insert_and_fetch_first_row(self, data):
        """Create the default DB, insert *data*, and return the first test_telnet row.

        The sqlite connection is closed before returning, and the test database
        directory is removed at cleanup even if a later assertion fails.
        """
        db = database.Database()
        db.create_default_database()
        self.addCleanup(shutil.rmtree, os.getcwd() + self.test_db_dir)
        validator = datavalidator.DataValidator()
        Table_Insert.prepare_data_for_insertion(validator.get_schema(), data)
        connection = sqlite3.connect(self.gci['Database']['path'])
        try:
            return connection.execute('select * from test_telnet;').fetchall()[0]
        finally:
            connection.close()

    @patch.object(Logger, '__new__')
    def test_telnet_all_values(self, log):
        row = self._insert_and_fetch_first_row(self.telnet_data)
        check_list = self.helper_get_values_out(self.telnet_data)
        # strict superset: the stored row also holds values not supplied in the input
        self.assertTrue(set(row) > set(check_list))

    @patch.object(Logger, '__new__')
    def test_telnet_missing_non_required_values(self, log):
        row = self._insert_and_fetch_first_row(self.telnet_data_noip)
        check_list = self.helper_get_values_out(self.telnet_data_noip)
        bad_check_list = self.helper_get_values_out(self.telnet_data)
        self.assertTrue(set(row) > set(check_list))
        # the omitted address values must not have been stored
        self.assertFalse(set(row) > set(bad_check_list))

    def helper_get_values_out(self, dictionary):
        """Return the column values of the table entry selected from *dictionary*.

        The entry is chosen by ``util.get_first_key_value_of_dictionary``;
        values are returned in the inner dict's iteration order.
        """
        inner_dict = dictionary[util.get_first_key_value_of_dictionary(dictionary)]
        return list(inner_dict.values())
class datamanager_test(unittest.TestCase):
    """Checks that constructing a DataManager lays out the test database on disk."""

    def setUp(self):
        # Database directory and file, as suffixes appended to the working directory.
        self.test_db_dir = "/tests/database/test_database"
        self.test_db_file = "/tests/database/test_database/honeyDB.sqlite"
        # Configuration files used across the database tests.
        self.plugins_config_file = "tests/database/test_config/plugins.cfg"
        self.plugins_config_diff_file = "tests/database/test_config/plugins_diff.cfg"
        self.plugins_config_diff_table_file = "tests/database/test_config/plugins_diff_table.cfg"
        self.global_config_file = "tests/database/test_config/global.cfg"
        # Load the shared configuration from the test files.
        config = GlobalConfig(self.plugins_config_file, self.global_config_file)
        self.gci = config
        config.read_global_config()
        config.read_plugin_config()

    @patch.object(Logger, "__new__")
    def test_datamanager_init(self, log):
        manager = datamanager.DataManager()
        db_dir_path = os.getcwd() + self.test_db_dir
        db_file_path = os.getcwd() + self.test_db_file
        # Construction alone must create the directory, the file and the queue.
        self.assertTrue(os.path.isdir(db_dir_path))
        self.assertTrue(os.path.isfile(db_file_path))
        self.assertIsInstance(manager.q, dataqueue.DataQueue)
        shutil.rmtree(os.getcwd() + self.test_db_dir)
def test_connect(self):
    """connect() yields None for bogus paths and a truthy connection for the configured DB."""
    # Negative testing: none of these paths refer to a usable database.
    bogus_paths = (
        "database",
        "database.db",
        "asdl;kfjeiei",
        "./honeyDB/honeyDB.sqllite",
        "./honeyDB/honeyDB.db",
        " ",
        "",
    )
    for bogus in bogus_paths:
        self.assertIsNone(DatabaseHandler().connect(bogus))

    # Testing for correct DB: build it from the report-server test config first.
    config = GlobalConfig("tests/reportserver/testconfig/plugins.cfg",
                          "tests/reportserver/testconfig/global.cfg",
                          True)
    config.read_global_config()
    config.read_plugin_config()
    honey_db = Database()
    honey_db.create_db_dir()
    honey_db.create_db()
    configured_path = config["Database"]["path"]
    self.assertTrue(sqlite3.connect(configured_path))
    self.assertTrue(DatabaseHandler().connect(configured_path))
    # Passing None must also produce a connection (presumably to a default DB — confirm).
    self.assertTrue(DatabaseHandler().connect(None))
def setUp(self):
    """Create ``TestDB.sqlite`` with three tables of 500 weekly rows each.

    Rows start at 1999-12-31T23:59:59 and advance one week per row. Each table
    stores a fixed port number (8082, 8083, 8023) and the text ``'TEXT'``.
    The connection is always closed, even if table creation fails.
    """
    # Testing for correct DB
    plugin_cfg_path = "tests/reportserver/testconfig/plugins.cfg"
    global_cfg_path = "tests/reportserver/testconfig/global.cfg"
    global_config = GlobalConfig(plugin_cfg_path, global_cfg_path, True)
    global_config.read_global_config()
    global_config.read_plugin_config()

    test_start_date = datetime.datetime(1999, month=12, day=31, hour=23, minute=59, second=59)
    conn = sqlite3.connect("TestDB.sqlite")
    try:
        c = conn.cursor()
        c.execute("""CREATE TABLE test_http (port int, data text, eventDateTime text)""")
        c.execute("""CREATE TABLE test_http2 (port int, data text, eventDateTime text)""")
        c.execute("""CREATE TABLE test_telnet (port int, data text, eventDateTime text)""")
        for x in range(500):
            insert_date_iso = (test_start_date + datetime.timedelta(weeks=x)).isoformat()
            # Bound parameters instead of string interpolation; rows stay interleaved
            # across tables exactly as before.
            c.execute("INSERT INTO test_http VALUES (8082,'TEXT',?)", (insert_date_iso,))
            c.execute("INSERT INTO test_http2 VALUES (8083,'TEXT',?)", (insert_date_iso,))
            c.execute("INSERT INTO test_telnet VALUES (8023,'TEXT',?)", (insert_date_iso,))
        conn.commit()
    finally:
        conn.close()
class datavalidator_test(unittest.TestCase):
    """Tests for DataValidator against the default test database schema."""

    def setUp(self):
        self.test_db_dir = '/tests/database/test_database'
        self.test_db_file = '/tests/database/test_database/honeyDB.sqlite'
        # test configuration files
        self.plugins_config_file = 'tests/database/test_config/plugins.cfg'
        self.plugins_config_diff_file = 'tests/database/test_config/plugins_diff.cfg'
        self.plugins_config_diff_table_file = 'tests/database/test_config/plugins_diff_table.cfg'
        self.global_config_file = 'tests/database/test_config/global.cfg'
        # create global config instance
        self.gci = GlobalConfig(self.plugins_config_file, self.global_config_file)
        self.gci.read_global_config()
        self.gci.read_plugin_config()
        # Tables expected in the default test database.
        self.tables_test = ['p0f', 'ipInfo', 'sessions', 'test_http', 'test_http2', 'test_telnet']
        # Expected schema per table. The tuples appear to follow sqlite's
        # PRAGMA table_info layout (cid, name, type, notnull, default, pk) — confirm.
        self.table_schema_test = {
            'sessions': [(0, 'session', 'TEXT', 0, None, 1),
                         (1, 'table_name', 'TEXT', 1, None, 2)],
            'test_http2': [(0, 'ID', 'INTEGER', 1, None, 1),
                           (1, 'session', 'TEXT', 0, None, 0),
                           (2, 'eventDateTime', 'TEXT', 0, None, 0),
                           (3, 'peerAddress', 'TEXT', 0, None, 0),
                           (4, 'localAddress', 'TEXT', 0, None, 0),
                           (5, 'command', 'TEXT', 0, None, 0),
                           (6, 'path', 'TEXT', 0, None, 0),
                           (7, 'headers', 'TEXT', 0, None, 0),
                           (8, 'body', 'TEXT', 0, None, 0)],
            'test_http': [(0, 'ID', 'INTEGER', 1, None, 1),
                          (1, 'session', 'TEXT', 0, None, 0),
                          (2, 'eventDateTime', 'TEXT', 0, None, 0),
                          (3, 'peerAddress', 'TEXT', 0, None, 0),
                          (4, 'localAddress', 'TEXT', 0, None, 0),
                          (5, 'command', 'TEXT', 0, None, 0),
                          (6, 'path', 'TEXT', 0, None, 0),
                          (7, 'headers', 'TEXT', 0, None, 0),
                          (8, 'body', 'TEXT', 0, None, 0)],
            'p0f': [(0, 'session', 'TEXT', 1, None, 1),
                    (1, 'first_seen', 'TEXT', 0, None, 0),
                    (2, 'last_seen', 'TEXT', 0, None, 0),
                    (3, 'uptime', 'INTEGER', 0, None, 0),
                    (4, 'last_nat', 'TEXT', 0, None, 0),
                    (5, 'last_chg', 'TEXT', 0, None, 0),
                    (6, 'distance', 'INTEGER', 0, None, 0),
                    (7, 'bad_sw', 'INTEGER', 0, None, 0),
                    (8, 'os_name', 'TEXT', 0, None, 0),
                    (9, 'os_flavor', 'TEXT', 0, None, 0),
                    (10, 'os_match_q', 'INTEGER', 0, None, 0),
                    (11, 'http_name', 'TEXT', 0, None, 0),
                    (12, 'http_flavor', 'TEXT', 0, None, 0),
                    (13, 'total_conn', 'INTEGER', 0, None, 0),
                    (14, 'link_type', 'TEXT', 0, None, 0),
                    (15, 'language', 'TEXT', 0, None, 0)],
            'ipInfo': [(0, 'ip', 'TEXT', 1, None, 1),
                       (1, 'plugin_instance', 'TEXT', 1, None, 2),
                       (2, 'timestamp', 'TEXT', 1, None, 0),
                       (3, 'hostname', 'TEXT', 0, None, 0),
                       (4, 'city', 'TEXT', 0, None, 0),
                       (5, 'region', 'TEXT', 0, None, 0),
                       (6, 'country', 'TEXT', 0, None, 0),
                       (7, 'lat', 'REAL', 0, None, 0),
                       (8, 'long', 'REAL', 0, None, 0),
                       (9, 'org', 'TEXT', 0, None, 0),
                       (10, 'postal', 'TEXT', 0, None, 0)],
            'test_telnet': [(0, 'ID', 'INTEGER', 1, None, 1),
                            (1, 'session', 'TEXT', 0, None, 0),
                            (2, 'eventDateTime', 'TEXT', 0, None, 0),
                            (3, 'peerAddress', 'TEXT', 0, None, 0),
                            (4, 'localAddress', 'TEXT', 0, None, 0),
                            (5, 'input_type', 'TEXT', 0, None, 0),
                            (6, 'user_input', 'TEXT', 0, None, 0)],
        }

    def _create_validator(self):
        """Build the default test database and return ``(db, validator)``.

        The database directory is removed at cleanup, even if an assertion fails.
        """
        db = database.Database()
        db.create_default_database()
        self.addCleanup(shutil.rmtree, os.getcwd() + self.test_db_dir)
        return db, datavalidator.DataValidator()

    def _assert_good_then_bad(self, db, check, good_dict, bad_dict):
        """Assert *check* accepts *good_dict* without logging and rejects *bad_dict* with an error log.

        With ``Logger.__new__`` patched, every ``Logger(...).get(...)`` returns the
        same mock, so ``db.log`` is also the validator's logger. It is reset
        first so that only calls made by *check* are counted.
        """
        db.log.error.reset_mock()
        self.assertTrue(check(good_dict))
        self.assertFalse(db.log.error.called)
        self.assertFalse(check(bad_dict))
        self.assertTrue(db.log.error.called)

    @patch.object(Logger, '__new__')
    def test_get_schema_from_database(self, log):
        # Constructing the validator is what populates .tables and .table_schema.
        _, validator = self._create_validator()
        self.assertEqual(set(validator.tables), set(self.tables_test))
        self.assertEqual(validator.table_schema, self.table_schema_test)

    @patch.object(Logger, '__new__')
    def test_get_tables(self, log):
        _, validator = self._create_validator()
        self.assertIsInstance(validator.get_tables(), list)
        self.assertEqual(set(validator.get_tables()), set(self.tables_test))

    @patch.object(Logger, '__new__')
    def test_get_schema(self, log):
        _, validator = self._create_validator()
        self.assertIsInstance(validator.get_schema(), dict)
        self.assertEqual(validator.get_schema(), self.table_schema_test)

    @patch.object(Logger, '__new__')
    def test_check_value_len(self, log):
        good_dict = {'table1': {'col1': 'val1', 'col2': 'val2'}}
        bad_dict = {'table1': {'col1': 'val1', 'col2': 'val2'},
                    'table2': {'col3': 'val3', 'col4': 'val4'}}
        db, validator = self._create_validator()
        self._assert_good_then_bad(db, validator.check_value_len, good_dict, bad_dict)

    @patch.object(Logger, '__new__')
    def test_check_value_is_dict(self, log):
        good_dict = {'table1': {'col1': 'val1', 'col2': 'val2'}}
        bad_dict = ['im a list']
        db, validator = self._create_validator()
        self._assert_good_then_bad(db, validator.check_value_is_dict, good_dict, bad_dict)

    @patch.object(Logger, '__new__')
    def test_check_key_in_dict_string(self, log):
        good_dict = {'table1': {'col1': 'val1', 'col2': 'val2'}}
        bad_dict = {1: {'col1': 'val1'}}
        db, validator = self._create_validator()
        self._assert_good_then_bad(db, validator.check_key_in_dict_string, good_dict, bad_dict)

    @patch.object(Logger, '__new__')
    def test_check_key_is_valid_table_name(self, log):
        good_dict = {'test_http': {'col1': 'val1', 'col2': 'val2'}}
        bad_dict = {'test_table': {'col1': 'val1'}}
        db, validator = self._create_validator()
        self._assert_good_then_bad(db, validator.check_key_is_valid_table_name, good_dict, bad_dict)

    @patch.object(Logger, '__new__')
    def test_check_row_value_is_dict(self, log):
        good_dict = {'test_table': {'col1': 'val1', 'col2': 'val2'}}
        bad_dict = {'test_table': 'i am not a dictionary'}
        db, validator = self._create_validator()
        self._assert_good_then_bad(db, validator.check_row_value_is_dict, good_dict, bad_dict)

    @patch.object(Logger, '__new__')
    def test_check_all_col_names_strings(self, log):
        good_dict = {'test_table': {'col1': 'val1', 'col2': 'val2'}}
        bad_dict = {'test_table': {1: 'val1', 'string': 'val2'}}
        _, validator = self._create_validator()
        self.assertTrue(validator.check_all_col_names_strings(good_dict))
        self.assertFalse(validator.check_all_col_names_strings(bad_dict))

    @patch.object(Logger, '__new__')
    def test_check_all_col_exist(self, log):
        good_dict = {'test_http': {'session': 'val1', 'eventDateTime': 'val2', 'peerAddress': 'val3',
                                   'localAddress': 'val4', 'command': 'val5', 'path': 'val6',
                                   'headers': 'val7', 'body': 'val8'}}
        bad_dict = {'test_http2': {'session': 'val1', 'eventDateTime': 'val2', 'peerAddress': 'val3',
                                   'XYZ': 'val4', 'command': 'val5', 'path': 'val6',
                                   'headers': 'val7', 'body': 'val8'}}
        missing_col = {'test_http': {'session': 'val1', 'eventDateTime': 'val2', 'peerAddress': 'val3',
                                     'localAddress': 'val4', 'path': 'val6',
                                     'headers': 'val7', 'body': 'val8'}}
        db, validator = self._create_validator()
        db.log.error.reset_mock()
        self.assertTrue(validator.check_all_col_exist(good_dict))
        self.assertFalse(db.log.error.called)
        # A row that omits a known column ('command') is still accepted silently.
        self.assertTrue(validator.check_all_col_exist(missing_col))
        self.assertFalse(db.log.error.called)
        # A row with an unknown column ('XYZ') is rejected and logged.
        self.assertFalse(validator.check_all_col_exist(bad_dict))
        self.assertTrue(db.log.error.called)
class WorldmapServiceHandler():
    """Report-server handler that renders recently seen IP locations as a PNG world map."""

    def __init__(self):
        self.log = Logger().get('reportserver.manager.WorldmapServiceManager.py')
        self.global_config = GlobalConfig()
        self.global_config.read_plugin_config()
        self.global_config.read_global_config()

    def process(self, rqst, path_tokens, query_tokens):
        """Handle a worldmap request.

        :param rqst: request handler used to write the response
        :param path_tokens: URL path components; five or more is a bad request
        :param query_tokens: query parameters passed to validate_time_period
        """
        if not have_basemap:
            # have_basemap is a module-level flag (presumably false when the
            # Basemap library is unavailable): reply with a help page instead.
            err_msg = \
                ('<html><head><title>WorldMap</title></head><body>'
                 'To enable WorldMap generation, please visit '
                 '<a href="https://recce7.github.io/">the documentation</a> and '
                 'follow the directions for installing the Basemap library.'
                 '</body></html>')
            # Content-Length counts bytes, so measure the encoded body.
            body = bytes(err_msg, "utf-8")
            rqst.send_response(200)
            # todo make this configurable for allow-origin
            rqst.send_header("Access-Control-Allow-Origin", "http://localhost:8000")
            rqst.send_header('Content-Type', 'text/html')
            rqst.send_header('Content-Length', len(body))
            rqst.end_headers()
            rqst.flush_headers()
            rqst.wfile.write(body)
            rqst.wfile.flush()
            return

        uom = None
        units = None
        self.log.info("processing ipaddress request:" + str(path_tokens) + str(query_tokens))
        try:
            time_period = utilities.validate_time_period(query_tokens)
            uom = time_period[0]
            units = time_period[1]
        except ValueError:
            # NOTE(review): units is still None here, so badRequest receives None.
            # Kept as-is; confirm what argument badRequest expects.
            rqst.badRequest(units)
            return

        if len(path_tokens) >= 5:
            rqst.badRequest()
            return
        self.construct_worldmap(rqst, uom, units)

    def construct_worldmap(self, rqst, uom, units):
        """Plot the points from get_point_list, annotate the image and send it as PNG.

        The image is written to, re-read from, and served from
        ``reportserver/worldmap.png``.
        """
        # NOTE: pickle.loads can execute arbitrary code. pickle_bytes is a
        # module-level value not visible here and must come from a trusted,
        # locally generated source — never from request data.
        ip_map = pickle.loads(pickle_bytes)
        pts = self.get_point_list(uom, units)
        for srclat, srclong in pts:
            x, y = ip_map(srclong, srclat)
            plt.plot(x, y, 'o', color='#ff0000', ms=2.7, markeredgewidth=1.0)
        plt.savefig('reportserver/worldmap.png', dpi=600)

        img = Image.open('reportserver/worldmap.png')
        draw = ImageDraw.Draw(img)
        font_path = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
        font = ImageFont.truetype(font_path, 175)
        draw.text((50, 50), "Unique IP addresses: last %s %s" % (units, uom), (0, 0, 0), font=font)
        # NOTE(review): len(pts) counts distinct (lat, long, timestamp, ip) rows
        # returned by get_point_list, which is not strictly the number of unique IPs.
        font = ImageFont.truetype(font_path, 125)
        draw.text((50, 325), "Total: %s" % (len(pts)), (0, 0, 0), font=font)
        img.save("reportserver/worldmap.png")
        rqst.sendPngResponse("reportserver/worldmap.png", 200)

    def get_point_list(self, uom, units):
        """Return ``(lat, long)`` rows from ipInfo whose timestamp is after the window start.

        The window start comes from ``dateTimeUtility.get_begin_date_iso(uom, units)``.
        Rows with a NULL lat or long are excluded. The connection is always closed.
        """
        begin_date = dateTimeUtility.get_begin_date_iso(uom, units)
        query_string = ('select lat,long '
                        'from ('
                        'select distinct lat,long,timestamp, ip '
                        'from ipInfo '
                        'where lat is not null '
                        'and long is not null '
                        'and datetime(timestamp) > datetime(?)'
                        ');')
        connection = sqlite3.connect(self.global_config['Database']['path'])
        try:
            return connection.execute(query_string, (begin_date,)).fetchall()
        finally:
            connection.close()
class GlobalConfig_Test(unittest.TestCase):
    """Exercises GlobalConfig against the test plugin and global config files."""

    def setUp(self):
        # refresh=True requests a freshly built configuration from the test files.
        self.gconfig = GlobalConfig(test_cfg_path, test_global_cfg_path, refresh=True)
        for read_section in (self.gconfig.read_plugin_config, self.gconfig.read_global_config):
            read_section()

    def test_getInstance(self):
        second = GlobalConfig()
        self.assertEqual(str(self.gconfig), str(second), "these 2 objects should equal")
        third = GlobalConfig()
        self.assertEqual(str(self.gconfig), str(third), "these 2 objects should equal")
        self.assertEqual(str(second), str(third), "these 2 objects should equal")

    def test_getPorts(self):
        ports = self.gconfig.get_ports()
        port_count = len(ports)
        self.assertEqual(port_count, 2, "expected 2 ports in test.cfg found: " + str(port_count))
        for found_port in ports:
            print("found: " + str(found_port))

    def test_getReportServerConfig(self):
        host = self.gconfig.get_report_server_host()
        port = self.gconfig.get_report_server_port()
        self.assertEqual(host, "", "expected host to be ''")
        self.assertEqual(port, 8080, "expected port to be '8080' ")

    def test_getReportServerHost(self):
        self.assertEqual("", self.gconfig.get_report_server_host())

    def test_getReportServerPort(self):
        self.assertEqual(8080, self.gconfig.get_report_server_port())

    def test_refresh_instance(self):
        rebuilt = GlobalConfig(test_cfg_path, test_global_cfg_path, refresh=True)
        self.assertNotEqual(str(self.gconfig), str(rebuilt),
                            "these 2 objects should NOT equal when refresh set to True")

    def test_refresh_instance_same(self):
        messages = ("these 2 objects should equal when False is set for Refresh",
                    "these 2 objects should equal with default of False")
        for message in messages:
            shared = GlobalConfig()
            self.assertEqual(str(self.gconfig), str(shared), message)

    def test_get_date_time_name(self):
        database_section = self.gconfig['Database']
        self.assertEqual("eventDateTime", database_section['datetime.name'])

    def test_get_db_peerAddress_nameself(self):
        database_section = self.gconfig['Database']
        self.assertEqual("peerAddress", database_section['peerAddress.name'])

    def test_get_db_localAddress_name(self):
        database_section = self.gconfig['Database']
        self.assertEqual("localAddress", database_section['localAddress.name'])
class _Framework: def __init__(self, plugin_cfg_path, global_cfg_path): self._global_config = GlobalConfig(plugin_cfg_path, global_cfg_path) self._plugin_imports = {} self._listener_list= {} self._running_plugins_list = [] self._data_manager = None self._shutting_down = False self._log = None self._pid = os.getpid() def start(self): self.set_shutdown_hook() print('Press Ctrl+C to exit.') if not self.drop_permissions(): return self._global_config.read_global_config() self.start_logging() self._global_config.read_plugin_config() self._data_manager = DataManager() self._data_manager.start() self.start_listeners() def start_logging(self): log_path = self._global_config['Framework']['logName'] log_level = self._global_config['Framework']['logLevel'] self._log = Logger(log_path, log_level).get('framework.frmwork.Framework') self._log.info('RECCE7 started (PID %d)' % self._pid) @staticmethod def drop_permissions(): if os.getuid() != 0: return True dist_name = os.getenv('RECCE7_OS_DIST') users_dict = { 'centos': ('nobody', 'nobody'), 'debian': ('nobody', 'nogroup') } if dist_name not in users_dict: print( 'Unable to lower permission level - not continuing as\n' 'superuser. Please set the environment variable\n' 'RECCE7_OS_DIST to one of:\n\tcentos\n\tdebian\n' 'or rerun as a non-superuser.') return False lowperm_user = users_dict[dist_name] nobody_uid = pwd.getpwnam(lowperm_user[0]).pw_uid nogroup_gid = grp.getgrnam(lowperm_user[1]).gr_gid os.setgroups([]) os.setgid(nogroup_gid) os.setuid(nobody_uid) os.umask(0o077) return True def create_import_entry(self, port, name, clsname): imp = import_module('plugins.' 
+ name) self._plugin_imports[port] = getattr(imp, clsname) def start_listeners(self): ports = self._global_config.get_ports() for port in ports: plugin_config = self._global_config.get_plugin_config(port) module = plugin_config['module'] clsname = plugin_config['moduleClass'] self.create_import_entry(port, module, clsname) address = self._global_config['Framework']['listeningAddress'] listener = NetworkListener(address, plugin_config, self) listener.start() self._listener_list[port] = listener def set_shutdown_hook(self): signal.signal(signal.SIGINT, self.shutdown) def shutdown(self, *args): self._shutting_down = True self._log.debug('Shutting down network listeners') for listener in self._listener_list.values(): listener.shutdown() self._log.debug('Shutting down plugins') for plugin in self._running_plugins_list: plugin.shutdown() self._log.debug('Shutting down data manager') self._data_manager.shutdown() print('Goodbye!') # # Framework API # def get_config(self, port): """ Returns the configuration dictionary for the plugin running on the specified port. :param port: a port number associated with a loaded plugin :return: a plugin configuration dictionary """ return self._global_config.get_plugin_config(port) def spawn(self, socket, config): """ Spawns the plugin configured by 'config' with the provided (accepted) socket. :param socket: an open, accepted socket returned by socket.accept() :param config: the plugin configuration dictionary describing the plugin to spawn :return: a reference to the plugin that was spawned """ # ToDo Throw exception if plugin class not found plugin_class = self._plugin_imports[config['port']] plugin = plugin_class(socket, config, self) plugin.start() self._running_plugins_list.append(plugin) return plugin def insert_data(self, data): """ Inserts the provided data into the data queue so that it can be pushed to the database. 
:param data: data object to add to the database """ self._data_manager.insert_data(data) def plugin_stopped(self, plugin): """ Tells the framework that the specified plugin has stopped running and doesn't need to be shutdown explicitly on program exit. :param plugin: a reference to a plugin """ if self._shutting_down: return self._running_plugins_list.remove(plugin)
class database_test(unittest.TestCase):
    """Tests for database.Database directory, file and schema management."""

    def setUp(self):
        self.test_db_dir = '/tests/database/test_database'
        self.test_db_file = '/tests/database/test_database/honeyDB.sqlite'
        # test configuration files
        self.plugins_config_file = 'tests/database/test_config/plugins.cfg'
        self.plugins_config_diff_file = 'tests/database/test_config/plugins_diff.cfg'
        self.plugins_config_diff_table_file = 'tests/database/test_config/plugins_diff_table.cfg'
        self.global_config_file = 'tests/database/test_config/global.cfg'
        # create global config instance
        self.gci = GlobalConfig(self.plugins_config_file, self.global_config_file)
        self.gci.read_global_config()
        self.gci.read_plugin_config()

    def _remove_db_dir_at_cleanup(self):
        """Schedule removal of the test database directory, even if the test fails."""
        self.addCleanup(shutil.rmtree, os.getcwd() + self.test_db_dir)

    def _load_plugin_config(self, plugins_config_file):
        """Rebuild the GlobalConfig (refresh=True) from *plugins_config_file* and read it."""
        self.gci = GlobalConfig(plugins_config_file, self.global_config_file, True)
        self.gci.read_global_config()
        self.gci.read_plugin_config()

    # patch the Logger new method so that it doesn't create the log file, we do not need to test that the logging
    # works just that the log statements are called.
    @patch.object(Logger, '__new__')
    def test_database_init(self, log):
        db = database.Database()
        self.assertIsInstance(db.global_config, GlobalConfig._GlobalConfig)
        self.assertTrue(log.called)

    @patch.object(Logger, '__new__')
    def test_database_create_default_database(self, log):
        db = database.Database()
        db.create_default_database()
        self._remove_db_dir_at_cleanup()
        validator = datavalidator.DataValidator()
        # check that the directory and the database file exist
        self.assertTrue(os.path.isdir(os.getcwd() + self.test_db_dir))
        self.assertTrue(os.path.isfile(os.getcwd() + self.test_db_file))
        # table names in the database vs. user-defined tables from the configuration file
        schema_table_list = validator.get_tables()
        config_table_list = util.get_config_table_list(self.gci.get_ports(), self.gci.get_plugin_dictionary())
        # the non user defined tables must exist
        self.assertIn('p0f', schema_table_list)
        self.assertIn('ipInfo', schema_table_list)
        self.assertIn('sessions', schema_table_list)
        # the user defined tables are a strict subset of the tables in the database schema
        self.assertTrue(set(config_table_list) < set(schema_table_list))

    @patch.object(Logger, '__new__')
    def test_database_create_db_dir(self, log):
        db = database.Database()
        db.create_db_dir()
        self._remove_db_dir_at_cleanup()
        self.assertTrue(os.path.isdir(os.getcwd() + self.test_db_dir))
        self.assertTrue(log.called)

    @patch.object(Logger, '__new__')
    def test_database_create_db_dir_already_exists(self, log):
        os.mkdir(os.getcwd() + self.test_db_dir)
        self._remove_db_dir_at_cleanup()
        self.assertTrue(os.path.isdir(os.getcwd() + self.test_db_dir))
        db = database.Database()
        log.reset_mock()
        db.create_db_dir()
        # no new Logger may be constructed when the directory already exists
        self.assertFalse(log.called)

    @patch.object(Logger, '__new__')
    def test_database_create_db(self, log):
        db = database.Database()
        db.create_db_dir()
        self._remove_db_dir_at_cleanup()
        db.create_db()
        self.assertTrue(os.path.isfile(os.getcwd() + self.test_db_file))
        self.assertTrue(log.called)

    @patch.object(Logger, '__new__')
    def test_database_update_schema(self, log):
        db = database.Database()
        db.create_db_dir()
        self._remove_db_dir_at_cleanup()
        db.create_db()
        util.run_db_scripts(self.gci)
        db.update_schema()
        schema = datavalidator.DataValidator().get_schema()
        self.assertEqual(schema['test_http'][5][1], 'command')
        self.assertEqual(schema['test_http2'][6][1], 'path')
        self.assertEqual(len(schema['test_telnet']), 7)

        # Switch to the config with differing columns. The normal config is
        # restored at cleanup (before the directory is removed) even if an
        # assertion below fails, so later tests never see the diff config.
        self.addCleanup(self._load_plugin_config, self.plugins_config_file)
        self._load_plugin_config(self.plugins_config_diff_file)
        db2 = database.Database()
        db2.update_schema()
        schema2 = datavalidator.DataValidator().get_schema()
        self.assertEqual(schema2['test_http'][5][1], 'unit_test_data_1')
        self.assertEqual(schema2['test_http2'][6][1], 'unit_test_data_2')
        self.assertEqual(len(schema2['test_telnet']), 8)
        self.assertEqual(schema2['test_telnet'][7][1], 'unit_test_data_3')
class dataqueue_test(unittest.TestCase):
    """Tests for DataQueue insertion, ordering and validation."""

    def setUp(self):
        self.test_db_dir = '/tests/database/test_database'
        self.test_db_file = '/tests/database/test_database/honeyDB.sqlite'
        # test configuration files
        self.plugins_config_file = 'tests/database/test_config/plugins.cfg'
        self.plugins_config_diff_file = 'tests/database/test_config/plugins_diff.cfg'
        self.plugins_config_diff_table_file = 'tests/database/test_config/plugins_diff_table.cfg'
        self.global_config_file = 'tests/database/test_config/global.cfg'
        # create global config instance
        self.gci = GlobalConfig(self.plugins_config_file, self.global_config_file)
        self.gci.read_global_config()
        self.gci.read_plugin_config()

    def _create_database(self):
        """Create the default test database and schedule removal of its directory."""
        db = database.Database()
        db.create_default_database()
        self.addCleanup(shutil.rmtree, os.getcwd() + self.test_db_dir)
        return db

    @patch.object(Logger, '__new__')
    def test_dataqueue_init(self, log):
        self._create_database()
        dq = dataqueue.DataQueue()
        self.assertIsInstance(dq.dataQueue, queue.Queue)

    @patch.object(Logger, '__new__')
    def test_insert_into_dataqueue(self, log):
        insert_dict = {'test_telnet': {'session': 'abc', 'eventDateTime': '02-03-2016 04:15:37.037',
                                       'peerAddress': '41.26.134.3', 'localAddress': '192.168.0.51',
                                       'input_type': 'test', 'user_input': 'test input'}}
        insert_dict2 = {'test_telnet': {'session': 'abcd', 'eventDateTime': '02-04-2016 04:15:37.037',
                                        'peerAddress': '41.27.234.3', 'localAddress': '192.168.0.42',
                                        'input_type': 'test1', 'user_input': 'test input1'}}
        self._create_database()
        dq = dataqueue.DataQueue()
        self.assertTrue(dq.check_empty())
        dq.insert_into_data_queue(insert_dict)
        self.assertFalse(dq.check_empty())
        dq.insert_into_data_queue(insert_dict2)
        self.assertEqual(dq.dataQueue.qsize(), 2)
        # Items must come back in insertion order. assertEqual is required here:
        # assertTrue(expected, result) treats result as a message and always passes.
        result = dq.get_next_item()
        self.assertEqual(insert_dict, result)
        self.assertEqual(dq.dataQueue.qsize(), 1)
        result2 = dq.get_next_item()
        self.assertEqual(insert_dict2, result2)
        self.assertTrue(dq.check_empty())

    @patch.object(Logger, '__new__')
    def test_insert_into_dataqueue_bad_value(self, log):
        # 'fake_column' is not part of the test_telnet schema, so insertion must be refused.
        insert_dict = {'test_telnet': {'session': 'abc', 'eventDateTime': '02-03-2016 04:15:37.037',
                                       'peerAddress': '41.26.134.3', 'localAddress': '192.168.0.51',
                                       'input_type': 'test', 'user_input': 'test input',
                                       'fake_column': 'bad data'}}
        self._create_database()
        dq = dataqueue.DataQueue()
        self.assertFalse(dq.insert_into_data_queue(insert_dict))