示例#1
0
class SimpleHttpServer:
    """Report server that loads RECCE7 configuration and serves REST requests.

    Configuration file paths come from the ``RECCE7_PLUGIN_CONFIG`` and
    ``RECCE7_GLOBAL_CONFIG`` environment variables, falling back to
    ``config/plugins.cfg`` and ``config/global.cfg``.
    """

    def __init__(self):
        """Read plugin/global configuration and set up report-server logging."""
        plugin_cfg_path = os.getenv('RECCE7_PLUGIN_CONFIG') or 'config/plugins.cfg'
        global_cfg_path = os.getenv('RECCE7_GLOBAL_CONFIG') or 'config/global.cfg'
        self.g_config = GlobalConfig(plugin_cfg_path, global_cfg_path)
        self.g_config.read_plugin_config()
        self.g_config.read_global_config()
        self.host = self.g_config.get_report_server_host()
        self.port = self.g_config.get_report_server_port()
        log_path = self.g_config['ReportServer']['reportserver.logName']
        log_level = self.g_config['ReportServer']['reportserver.logLevel']
        self.log = Logger(log_path, log_level).get('reportserver.server.SimpleHTTPServer.SimpleHTTPServer')

    def setupAndStart(self):
        """Bind an HTTPServer on (host, port) and serve until interrupted.

        Ctrl+C (``KeyboardInterrupt``) stops the server normally. The
        listening socket is closed in every case, including when
        ``serve_forever`` raises some other exception.
        """
        server_address = (self.host, self.port)
        httpd = HTTPServer(server_address, RestRequestHandler)
        print(time.asctime(), "Server Starting - %s:%s" % (self.host, self.port))

        try:
            # start serving pages
            httpd.serve_forever()
        except KeyboardInterrupt:
            # Ctrl+C is the expected way to stop the server.
            pass
        finally:
            # Always release the socket, not only after KeyboardInterrupt.
            httpd.server_close()
            print(time.asctime(), "Server Stopped - %s:%s" % (self.host, self.port))
示例#2
0
 def test_get_json_by_time(self):
     """Rows returned for an eventDateTime lower bound must all satisfy it.

     For every seeded plugin port (8082, 8083, 8023), query that plugin's
     table with 500 lower bounds one week apart. Each returned row's
     eventDateTime must be >= the bound. Same-format ISO-8601 strings
     compare lexicographically in chronological order.
     """
     plugin_cfg_path = os.getenv("RECCE7_PLUGIN_CONFIG") or "config/plugins.cfg"
     global_cfg_path = os.getenv("RECCE7_GLOBAL_CONFIG") or "config/global.cfg"
     global_config = GlobalConfig(plugin_cfg_path, global_cfg_path, True)
     global_config.read_global_config()
     global_config.read_plugin_config()
     test_start_date = datetime.datetime(1999, month=12, day=31, hour=23, minute=59, second=59)
     # Cover all three seeded ports, including 8023 (test_telnet).
     for portnumber in (8082, 8083, 8023):
         # The table name depends only on the port, so look it up once.
         tableName = global_config.get_plugin_config(portnumber)["table"]
         for x in range(0, 500):
             query_date_iso = (test_start_date + datetime.timedelta(weeks=x)).isoformat()
             query_string = "SELECT * FROM %s where (eventDateTime >= '%s')" % (tableName, query_date_iso)
             json_query = DatabaseHandler().query_db(query_string, db="TestDB.sqlite")
             # Check every returned row, including the last one.
             for row in json_query:
                 self.assertGreaterEqual(row.get("eventDateTime"), query_date_iso)
示例#3
0
class Table_Insert_test(unittest.TestCase):
    """Tests for Table_Insert.prepare_data_for_insertion on the default database."""

    def setUp(self):
        self.test_db_dir = '/tests/database/test_database'
        self.test_db_file = '/tests/database/test_database/honeyDB.sqlite'
        # test configuration files
        self.plugins_config_file = 'tests/database/test_config/plugins.cfg'
        self.plugins_config_diff_file = 'tests/database/test_config/plugins_diff.cfg'
        self.plugins_config_diff_table_file = 'tests/database/test_config/plugins_diff_table.cfg'
        self.global_config_file = 'tests/database/test_config/global.cfg'
        # create global config instance
        self.gci = GlobalConfig(self.plugins_config_file, self.global_config_file)
        self.gci.read_global_config()
        self.gci.read_plugin_config()
        # Every test creates the default database. Remove its directory after
        # each test, even one that fails part-way, so no state leaks between tests.
        self.addCleanup(shutil.rmtree, os.getcwd() + self.test_db_dir, ignore_errors=True)
        self.telnet_data = {'test_telnet':{'session':'abcdefghijklmnop','eventDateTime':'01-02-2016 11:22:33.123',
                                           'peerAddress':'24.33.21.123', 'localAddress':'192.168.0.55',
                                           'input_type':'a string','user_input':'another string'}}
        self.telnet_data_noip = {'test_telnet':{'session':'abcdefghijklmnop','eventDateTime':'01-02-2016 11:22:33.123',
                                           'input_type':'a string','user_input':'another string'}}

    def _insert_and_fetch_first_row(self, data):
        """Create the default DB, insert *data*, and return the first test_telnet row."""
        db = database.Database()
        db.create_default_database()
        validator = datavalidator.DataValidator()
        Table_Insert.prepare_data_for_insertion(validator.get_schema(), data)
        connection = sqlite3.connect(self.gci['Database']['path'])
        try:
            return connection.execute('select * from test_telnet;').fetchall()[0]
        finally:
            connection.close()

    @patch.object(Logger,'__new__')
    def test_telnet_all_values(self,log):
        row = self._insert_and_fetch_first_row(self.telnet_data)
        check_list = self.helper_get_values_out(self.telnet_data)
        # The stored row must contain every inserted value plus extra columns.
        self.assertTrue(set(row) > set(check_list))

    @patch.object(Logger,'__new__')
    def test_telnet_missing_non_required_values(self,log):
        row = self._insert_and_fetch_first_row(self.telnet_data_noip)
        check_list = self.helper_get_values_out(self.telnet_data_noip)
        bad_check_list = self.helper_get_values_out(self.telnet_data)
        self.assertTrue(set(row) > set(check_list))
        # The omitted address values must not have appeared in the row.
        self.assertFalse(set(row) > set(bad_check_list))

    def helper_get_values_out(self,dictionary):
        """Return the column values of the single table entry in *dictionary*."""
        inner_dict = dictionary[util.get_first_key_value_of_dictionary(dictionary)]
        return list(inner_dict.values())
示例#4
0
class datamanager_test(unittest.TestCase):
    """Tests that constructing a DataManager creates the database and its queue."""

    def setUp(self):
        self.test_db_dir = "/tests/database/test_database"
        self.test_db_file = "/tests/database/test_database/honeyDB.sqlite"
        # test configuration files
        self.plugins_config_file = "tests/database/test_config/plugins.cfg"
        self.plugins_config_diff_file = "tests/database/test_config/plugins_diff.cfg"
        self.plugins_config_diff_table_file = "tests/database/test_config/plugins_diff_table.cfg"
        self.global_config_file = "tests/database/test_config/global.cfg"
        # create global config instance
        self.gci = GlobalConfig(self.plugins_config_file, self.global_config_file)
        self.gci.read_global_config()
        self.gci.read_plugin_config()
        # DataManager() creates the test database directory. Remove it after
        # every test, including one whose assertions fail before the end.
        self.addCleanup(shutil.rmtree, os.getcwd() + self.test_db_dir, ignore_errors=True)

    @patch.object(Logger, "__new__")
    def test_datamanager_init(self, log):
        dm = datamanager.DataManager()
        self.assertTrue(os.path.isdir(os.getcwd() + self.test_db_dir))
        self.assertTrue(os.path.isfile(os.getcwd() + self.test_db_file))
        self.assertIsInstance(dm.q, dataqueue.DataQueue)
示例#5
0
    def test_connect(self):
        """DatabaseHandler.connect: None for bad paths, truthy for the real DB."""
        # Negative testing
        bad_paths = ("database",
                     "database.db",
                     "asdl;kfjeiei",
                     "./honeyDB/honeyDB.sqllite",
                     "./honeyDB/honeyDB.db",
                     " ",
                     "")
        for bad_path in bad_paths:
            # subTest reports every failing path instead of stopping at the first.
            with self.subTest(path=bad_path):
                self.assertIsNone(DatabaseHandler().connect(bad_path))

        # Testing for correct DB
        plugin_cfg_path = "tests/reportserver/testconfig/plugins.cfg"
        global_cfg_path = "tests/reportserver/testconfig/global.cfg"
        global_config = GlobalConfig(plugin_cfg_path, global_cfg_path, True)
        global_config.read_global_config()
        global_config.read_plugin_config()
        db = Database()
        db.create_db_dir()
        db.create_db()
        db_path = global_config["Database"]["path"]

        # Keep a handle on the direct connection so it can be closed.
        connection = sqlite3.connect(db_path)
        self.addCleanup(connection.close)
        self.assertTrue(connection)
        self.assertTrue(DatabaseHandler().connect(db_path))
        self.assertTrue(DatabaseHandler().connect(None))
示例#6
0
    def setUp(self):
        """Load the report-server test config and seed ``TestDB.sqlite``.

        Creates test_http (port 8082), test_http2 (port 8083) and test_telnet
        (port 8023). Each table gets 500 rows whose eventDateTime values are
        ISO-8601 strings one week apart, starting at 1999-12-31T23:59:59.
        """
        # Testing for correct DB
        plugin_cfg_path = "tests/reportserver/testconfig/plugins.cfg"
        global_cfg_path = "tests/reportserver/testconfig/global.cfg"
        # Not used below. Built with refresh=True, presumably so that later
        # GlobalConfig() lookups see the test configuration.
        global_config = GlobalConfig(plugin_cfg_path, global_cfg_path, True)
        global_config.read_global_config()
        global_config.read_plugin_config()

        test_start_date = datetime.datetime(1999, month=12, day=31, hour=23, minute=59, second=59)
        event_dates = [(test_start_date + datetime.timedelta(weeks=x)).isoformat()
                       for x in range(0, 500)]
        tables = (("test_http", 8082), ("test_http2", 8083), ("test_telnet", 8023))

        conn = sqlite3.connect("TestDB.sqlite")
        try:
            c = conn.cursor()
            # Table names are fixed constants, so formatting them in is safe.
            # Row values go through parameter binding.
            for table, _port in tables:
                c.execute("CREATE TABLE %s (port int, data text, eventDateTime text)" % table)
            for table, port in tables:
                c.executemany("INSERT INTO %s VALUES (?, 'TEXT', ?)" % table,
                              [(port, event_date) for event_date in event_dates])
            conn.commit()
        finally:
            # Release the file handle so later tests (or cleanup) can use it.
            conn.close()
示例#7
0
class datavalidator_test(unittest.TestCase):
    """Tests for datavalidator.DataValidator against the default test database.

    Each test patches Logger.__new__. The error-logging assertions read
    ``db.log.error``, which relies on the database and the validator sharing
    the same mocked logger.
    """

    def setUp(self):
        self.test_db_dir = '/tests/database/test_database'
        self.test_db_file = '/tests/database/test_database/honeyDB.sqlite'
        # test configuration files
        self.plugins_config_file = 'tests/database/test_config/plugins.cfg'
        self.plugins_config_diff_file = 'tests/database/test_config/plugins_diff.cfg'
        self.plugins_config_diff_table_file = 'tests/database/test_config/plugins_diff_table.cfg'
        self.global_config_file = 'tests/database/test_config/global.cfg'
        # create global config instance
        self.gci = GlobalConfig(self.plugins_config_file,self.global_config_file)
        self.gci.read_global_config()
        self.gci.read_plugin_config()
        # Every test creates the default database. Remove its directory after
        # each test, even one that fails before reaching its last line.
        self.addCleanup(shutil.rmtree, os.getcwd() + self.test_db_dir, ignore_errors=True)
        self.tables_test = ['p0f', 'ipInfo', 'sessions', 'test_http', 'test_http2', 'test_telnet']
        # Expected column tuples. These appear to follow PRAGMA table_info
        # order (cid, name, type, notnull, dflt_value, pk) -- confirm against
        # DataValidator.
        self.table_schema_test = {'sessions': [(0, 'session', 'TEXT', 0, None, 1),
                                               (1, 'table_name', 'TEXT', 1, None, 2)],
                                  'test_http2': [(0, 'ID', 'INTEGER', 1, None, 1),
                                                 (1, 'session', 'TEXT', 0, None, 0),
                                                 (2, 'eventDateTime', 'TEXT', 0, None, 0),
                                                 (3, 'peerAddress', 'TEXT', 0, None, 0),
                                                 (4, 'localAddress', 'TEXT', 0, None, 0),
                                                 (5, 'command', 'TEXT', 0, None, 0),
                                                 (6, 'path', 'TEXT', 0, None, 0),
                                                 (7, 'headers', 'TEXT', 0, None, 0),
                                                 (8, 'body', 'TEXT', 0, None, 0)],
                                  'test_http': [(0, 'ID', 'INTEGER', 1, None, 1),
                                                (1, 'session', 'TEXT', 0, None, 0),
                                                (2, 'eventDateTime', 'TEXT', 0, None, 0),
                                                (3, 'peerAddress', 'TEXT', 0, None, 0),
                                                (4, 'localAddress', 'TEXT', 0, None, 0),
                                                (5, 'command', 'TEXT', 0, None, 0),
                                                (6, 'path', 'TEXT', 0, None, 0),
                                                (7, 'headers', 'TEXT', 0, None, 0),
                                                (8, 'body', 'TEXT', 0, None, 0)],
                                  'p0f': [(0, 'session', 'TEXT', 1, None, 1),
                                          (1, 'first_seen', 'TEXT', 0, None, 0),
                                          (2, 'last_seen', 'TEXT', 0, None, 0),
                                          (3, 'uptime', 'INTEGER', 0, None, 0),
                                          (4, 'last_nat', 'TEXT', 0, None, 0),
                                          (5, 'last_chg', 'TEXT', 0, None, 0),
                                          (6, 'distance', 'INTEGER', 0, None, 0),
                                          (7, 'bad_sw', 'INTEGER', 0, None, 0),
                                          (8, 'os_name', 'TEXT', 0, None, 0),
                                          (9, 'os_flavor', 'TEXT', 0, None, 0),
                                          (10, 'os_match_q', 'INTEGER', 0, None, 0),
                                          (11, 'http_name', 'TEXT', 0, None, 0),
                                          (12, 'http_flavor', 'TEXT', 0, None, 0),
                                          (13, 'total_conn', 'INTEGER', 0, None, 0),
                                          (14, 'link_type', 'TEXT', 0, None, 0),
                                          (15, 'language', 'TEXT', 0, None, 0)],
                                  'ipInfo': [(0, 'ip', 'TEXT', 1, None, 1),
                                             (1, 'plugin_instance', 'TEXT', 1, None, 2),
                                             (2, 'timestamp', 'TEXT', 1, None, 0),
                                             (3, 'hostname', 'TEXT', 0, None, 0),
                                             (4, 'city', 'TEXT', 0, None, 0),
                                             (5, 'region', 'TEXT', 0, None, 0),
                                             (6, 'country', 'TEXT', 0, None, 0),
                                             (7, 'lat', 'REAL', 0, None, 0),
                                             (8, 'long', 'REAL', 0, None, 0),
                                             (9, 'org', 'TEXT', 0, None, 0),
                                             (10, 'postal', 'TEXT', 0, None, 0)],
                                  'test_telnet': [(0, 'ID', 'INTEGER', 1, None, 1),
                                                  (1, 'session', 'TEXT', 0, None, 0),
                                                  (2, 'eventDateTime', 'TEXT', 0, None, 0),
                                                  (3, 'peerAddress', 'TEXT', 0, None, 0),
                                                  (4, 'localAddress', 'TEXT', 0, None, 0),
                                                  (5, 'input_type', 'TEXT', 0, None, 0),
                                                  (6, 'user_input', 'TEXT', 0, None, 0)]}

    def _create_validator(self):
        """Create the default database and return ``(db, validator)``."""
        db = database.Database()
        db.create_default_database()
        return db, datavalidator.DataValidator()

    @patch.object(Logger,'__new__')
    def test_get_schema_from_database(self,log):
        # Constructing DataValidator is what populates .tables/.table_schema.
        db, validator = self._create_validator()
        self.assertEqual(set(validator.tables), set(self.tables_test))
        self.assertEqual(validator.table_schema,self.table_schema_test)

    @patch.object(Logger,'__new__')
    def test_get_tables(self,log):
        db, validator = self._create_validator()
        self.assertIsInstance(validator.get_tables(),list)
        self.assertEqual(set(validator.get_tables()), set(self.tables_test))

    @patch.object(Logger,'__new__')
    def test_get_schema(self,log):
        db, validator = self._create_validator()
        self.assertIsInstance(validator.get_schema(),dict)
        self.assertEqual(validator.get_schema(),self.table_schema_test)

    @patch.object(Logger,'__new__')
    def test_check_value_len(self, log):
        good_dict = {'table1':{'col1':'val1','col2':'val2'}}
        bad_dict = {'table1':{'col1':'val1','col2':'val2'},'table2':{'col3':'val3','col4':'val4'}}
        db, validator = self._create_validator()
        self.assertTrue(validator.check_value_len(good_dict))
        self.assertFalse(db.log.error.called)
        self.assertFalse(validator.check_value_len(bad_dict))
        self.assertTrue(db.log.error.called)

    @patch.object(Logger,'__new__')
    def test_check_value_is_dict(self,log):
        good_dict = {'table1':{'col1':'val1','col2':'val2'}}
        bad_dict = ['im a list']
        db, validator = self._create_validator()
        self.assertTrue(validator.check_value_is_dict(good_dict))
        self.assertFalse(db.log.error.called)
        self.assertFalse(validator.check_value_is_dict(bad_dict))
        self.assertTrue(db.log.error.called)

    @patch.object(Logger,'__new__')
    def test_check_key_in_dict_string(self,log):
        good_dict = {'table1':{'col1':'val1','col2':'val2'}}
        bad_dict = {1:{'col1':'val1'}}
        db, validator = self._create_validator()
        self.assertTrue(validator.check_key_in_dict_string(good_dict))
        self.assertFalse(db.log.error.called)
        self.assertFalse(validator.check_key_in_dict_string(bad_dict))
        self.assertTrue(db.log.error.called)

    @patch.object(Logger,'__new__')
    def test_check_key_is_valid_table_name(self,log):
        good_dict = {'test_http':{'col1':'val1','col2':'val2'}}
        bad_dict = {'test_table':{'col1':'val1'}}
        db, validator = self._create_validator()
        self.assertTrue(validator.check_key_is_valid_table_name(good_dict))
        self.assertFalse(db.log.error.called)
        self.assertFalse(validator.check_key_is_valid_table_name(bad_dict))
        self.assertTrue(db.log.error.called)

    @patch.object(Logger,'__new__')
    def test_check_row_value_is_dict(self,log):
        good_dict = {'test_table':{'col1':'val1','col2':'val2'}}
        bad_dict = {'test_table':'i am not a dictionary'}
        db, validator = self._create_validator()
        self.assertTrue(validator.check_row_value_is_dict(good_dict))
        self.assertFalse(db.log.error.called)
        self.assertFalse(validator.check_row_value_is_dict(bad_dict))
        self.assertTrue(db.log.error.called)

    @patch.object(Logger,'__new__')
    def test_check_all_col_names_strings(self,log):
        good_dict = {'test_table':{'col1':'val1','col2':'val2'}}
        bad_dict = {'test_table':{1:'val1','string':'val2'}}
        db, validator = self._create_validator()
        self.assertTrue(validator.check_all_col_names_strings(good_dict))
        self.assertFalse(validator.check_all_col_names_strings(bad_dict))

    @patch.object(Logger,'__new__')
    def test_check_all_col_exist(self,log):
        good_dict = {'test_http':{'session': 'val1','eventDateTime': 'val2','peerAddress': 'val3',
                                  'localAddress': 'val4','command': 'val5','path': 'val6',
                                  'headers': 'val7','body': 'val8'}}
        bad_dict = {'test_http2':{'session': 'val1','eventDateTime': 'val2','peerAddress': 'val3',
                                  'XYZ': 'val4','command': 'val5','path': 'val6',
                                  'headers': 'val7','body': 'val8'}}
        missing_col = {'test_http':{'session': 'val1','eventDateTime': 'val2','peerAddress': 'val3',
                                    'localAddress': 'val4', 'path': 'val6', 'headers': 'val7', 'body': 'val8'}}
        db, validator = self._create_validator()
        self.assertTrue(validator.check_all_col_exist(good_dict))
        self.assertFalse(db.log.error.called)
        # A subset of the real columns is still valid.
        self.assertTrue(validator.check_all_col_exist(missing_col))
        self.assertFalse(db.log.error.called)
        # An unknown column ('XYZ') is rejected and logged.
        self.assertFalse(validator.check_all_col_exist(bad_dict))
        self.assertTrue(db.log.error.called)
示例#8
0
class WorldmapServiceHandler():
    """Serves a PNG world map plotting the locations of recently seen IPs."""

    def __init__(self):
        self.log = Logger().get('reportserver.manager.WorldmapServiceManager.py')
        self.global_config = GlobalConfig()
        self.global_config.read_plugin_config()
        self.global_config.read_global_config()

    def process(self, rqst, path_tokens, query_tokens):
        """Handle a world-map request.

        If the module-level ``have_basemap`` flag is false, reply 200 with an
        HTML page explaining how to enable the feature. Otherwise validate the
        requested time period and render the map.

        :param rqst: request handler used to write the response
        :param path_tokens: URL path components; 5 or more is a bad request
        :param query_tokens: query parameters, validated by
                             ``utilities.validate_time_period``
        """
        # have_basemap is a module-level flag that is only read here, so no
        # ``global`` declaration is needed.
        if not have_basemap:
            err_msg = \
                ('<html><head><title>WorldMap</title></head><body>'
                'To enable WorldMap generation, please visit '
                '<a href="https://recce7.github.io/">the documentation</a> and '
                'follow the directions for installing the Basemap library.'
                '</body></html>')
            body = bytes(err_msg, "utf-8")
            rqst.send_response(200)

            # TODO: make this configurable for allow-origin
            rqst.send_header("Access-Control-Allow-Origin","http://localhost:8000")
            rqst.send_header('Content-Type', 'text/html')
            # Content-Length counts bytes of the encoded body, not characters.
            rqst.send_header('Content-Length', len(body))
            rqst.end_headers()
            rqst.flush_headers()

            rqst.wfile.write(body)
            rqst.wfile.flush()
            return

        self.log.info("processing ipaddress request:" + str(path_tokens) + str(query_tokens))

        try:
            time_period = utilities.validate_time_period(query_tokens)
            uom = time_period[0]
            units = time_period[1]
        except ValueError:
            # Nothing useful to pass along: the time period never parsed.
            rqst.badRequest()
            return

        if len(path_tokens) >= 5:
            rqst.badRequest()
            return
        self.construct_worldmap(rqst, uom, units)

    def construct_worldmap(self, rqst, uom, units):
        """Plot recent IP locations on the world map and send it as a PNG.

        :param rqst: request handler; must provide ``sendPngResponse``
        :param uom: time unit of the look-back window (presumably e.g. days)
        :param units: number of ``uom`` to look back
        """
        # Call to construct port list, find unique IPs by port, merge the
        # results together, then build the map. See PortsServiceHandler.py or
        # IpsServiceHandler.py for the patterns to follow.
        # NOTE(review): pickle.loads must only ever see trusted data. Confirm
        # pickle_bytes is produced locally and never from external input.
        ip_map = pickle.loads(pickle_bytes)

        pts = self.get_point_list(uom, units)
        for srclat, srclong in pts:
            # ip_map is assumed to be a Basemap projection: (lon, lat) -> (x, y).
            x, y = ip_map(srclong, srclat)
            plt.plot(x, y, 'o', color='#ff0000', ms=2.7, markeredgewidth=1.0)

        plt.savefig('reportserver/worldmap.png', dpi=600)

        img = Image.open('reportserver/worldmap.png')
        draw = ImageDraw.Draw(img)

        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 175)
        draw.text((50, 50), "Unique IP addresses: last %s %s" % (units, uom),
                  (0, 0, 0), font=font)

        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 125)
        draw.text((50, 325), "Total: %s" % (len(pts)),
                  (0, 0, 0), font=font)

        img.save("reportserver/worldmap.png")

        rqst.sendPngResponse("reportserver/worldmap.png", 200)

    def get_point_list(self, uom, units):
        """Return (lat, long) rows from ipInfo seen after the window start.

        :return: list of (lat, long) tuples with both values non-null
        """
        begin_date = dateTimeUtility.get_begin_date_iso(uom, units)
        # begin_date is bound as a parameter rather than spliced into the SQL.
        query_string = ('select lat,long '
                        'from ('
                            'select distinct lat,long,timestamp, ip '
                            'from ipInfo '
                            'where lat is not null '
                            'and long is not null '
                            'and datetime(timestamp) > datetime(?)'
                            ');')
        connection = sqlite3.connect(self.global_config['Database']['path'])
        try:
            return connection.execute(query_string, (begin_date,)).fetchall()
        finally:
            connection.close()
示例#9
0
class GlobalConfig_Test(unittest.TestCase):
    """Tests for GlobalConfig: singleton reuse, refresh, and config accessors."""

    def setUp(self):
        # refresh=True builds a fresh instance from the test config files.
        cfg = GlobalConfig(test_cfg_path, test_global_cfg_path, refresh=True)
        self.gconfig = cfg
        cfg.read_plugin_config()
        cfg.read_global_config()

    def _db_setting(self, key):
        """Return the value stored under *key* in the [Database] section."""
        return self.gconfig['Database'][key]

    def test_getInstance(self):
        same_msg = "these 2 objects should equal"

        second = GlobalConfig()
        self.assertEqual(str(self.gconfig), str(second), same_msg)

        third = GlobalConfig()
        self.assertEqual(str(self.gconfig), str(third), same_msg)
        self.assertEqual(str(second), str(third), same_msg)

    def test_getPorts(self):
        ports = self.gconfig.get_ports()
        count = len(ports)
        self.assertEqual(count, 2, "expected 2 ports in test.cfg found: " + str(count))
        for each in ports:
            print("found: " + str(each))

    def test_getReportServerConfig(self):
        host = self.gconfig.get_report_server_host()
        port = self.gconfig.get_report_server_port()
        self.assertEqual(host, "", "expected host to be ''")
        self.assertEqual(port, 8080, "expected port to be '8080' ")

    def test_getReportServerHost(self):
        actual = self.gconfig.get_report_server_host()
        self.assertEqual("", actual)

    def test_getReportServerPort(self):
        actual = self.gconfig.get_report_server_port()
        self.assertEqual(8080, actual)

    def test_refresh_instance(self):
        # refresh=True must hand back a different instance than setUp's.
        fresh = GlobalConfig(test_cfg_path, test_global_cfg_path, refresh=True)
        self.assertNotEqual(str(self.gconfig), str(fresh), "these 2 objects should NOT equal when refresh set to True")

    def test_refresh_instance_same(self):
        reused = GlobalConfig()
        self.assertEqual(str(self.gconfig), str(reused), "these 2 objects should equal when False is set for Refresh")

        reused = GlobalConfig()
        self.assertEqual(str(self.gconfig), str(reused), "these 2 objects should equal with default of False")

    def test_get_date_time_name(self):
        self.assertEqual("eventDateTime", self._db_setting('datetime.name'))

    def test_get_db_peerAddress_nameself(self):
        self.assertEqual("peerAddress", self._db_setting('peerAddress.name'))

    def test_get_db_localAddress_name(self):
        self.assertEqual("localAddress", self._db_setting('localAddress.name'))
示例#10
0
文件: frmwork.py 项目: RECCE7/recce7
    class _Framework:
        """Core RECCE7 runtime.

        Loads configuration, drops root privileges, starts the data manager
        and one network listener per configured plugin port, and exposes the
        API that plugins call back into.
        """

        def __init__(self, plugin_cfg_path, global_cfg_path):
            self._global_config = GlobalConfig(plugin_cfg_path, global_cfg_path)
            self._plugin_imports = {}
            self._listener_list= {}
            self._running_plugins_list = []
            # Created in start(); may still be None if shutdown() runs early.
            self._data_manager = None
            self._shutting_down = False
            # Created in start_logging(); may still be None if shutdown() runs early.
            self._log = None
            self._pid = os.getpid()

        def start(self):
            """Install the SIGINT hook, drop privileges, then start logging,
            the data manager, and the plugin listeners."""
            self.set_shutdown_hook()
            print('Press Ctrl+C to exit.')
            if not self.drop_permissions():
                return

            self._global_config.read_global_config()

            self.start_logging()

            self._global_config.read_plugin_config()
            self._data_manager = DataManager()
            self._data_manager.start()

            self.start_listeners()

        def start_logging(self):
            """Create the framework logger from the [Framework] config section."""
            log_path = self._global_config['Framework']['logName']
            log_level = self._global_config['Framework']['logLevel']
            self._log = Logger(log_path, log_level).get('framework.frmwork.Framework')
            self._log.info('RECCE7 started (PID %d)' % self._pid)

        @staticmethod
        def drop_permissions():
            """Switch away from root to a low-privilege user/group if needed.

            The user/group pair is chosen by the RECCE7_OS_DIST environment
            variable ('centos' or 'debian').

            :return: True if it is safe to continue (not root, or privileges
                     were dropped); False if running as root with an unknown
                     distribution.
            """
            if os.getuid() != 0:
                return True

            dist_name = os.getenv('RECCE7_OS_DIST')
            users_dict = {
                'centos': ('nobody', 'nobody'),
                'debian': ('nobody', 'nogroup')
            }
            if dist_name not in users_dict:
                print(
                    'Unable to lower permission level - not continuing as\n'
                    'superuser. Please set the environment variable\n'
                    'RECCE7_OS_DIST to one of:\n\tcentos\n\tdebian\n'
                    'or rerun as a non-superuser.')
                return False
            lowperm_user = users_dict[dist_name]
            nobody_uid = pwd.getpwnam(lowperm_user[0]).pw_uid
            nogroup_gid = grp.getgrnam(lowperm_user[1]).gr_gid

            # Supplementary groups and gid must change before setuid gives up root.
            os.setgroups([])
            os.setgid(nogroup_gid)
            os.setuid(nobody_uid)
            # Files created from here on are accessible only to the owner.
            os.umask(0o077)

            return True

        def create_import_entry(self, port, name, clsname):
            """Import plugins.<name> and register class *clsname* for *port*."""
            imp = import_module('plugins.' + name)
            self._plugin_imports[port] = getattr(imp, clsname)

        def start_listeners(self):
            """Import each configured plugin and start a listener on its port."""
            # The listening address is the same for every plugin; read it once.
            address = self._global_config['Framework']['listeningAddress']
            for port in self._global_config.get_ports():
                plugin_config = self._global_config.get_plugin_config(port)
                module = plugin_config['module']
                clsname = plugin_config['moduleClass']
                self.create_import_entry(port, module, clsname)

                listener = NetworkListener(address, plugin_config, self)
                listener.start()
                self._listener_list[port] = listener

        def set_shutdown_hook(self):
            """Route SIGINT (Ctrl+C) to shutdown()."""
            signal.signal(signal.SIGINT, self.shutdown)

        def _debug(self, msg):
            """Log *msg* at debug level once logging has been started."""
            if self._log is not None:
                self._log.debug(msg)

        def shutdown(self, *args):
            """SIGINT handler: stop listeners, plugins, and the data manager.

            The hook is installed before logging and the data manager exist,
            so both may still be None here. They are only used when present.

            :param args: (signum, frame) from the signal module; unused
            """
            self._shutting_down = True

            self._debug('Shutting down network listeners')
            for listener in self._listener_list.values():
                listener.shutdown()

            self._debug('Shutting down plugins')
            for plugin in self._running_plugins_list:
                plugin.shutdown()

            if self._data_manager is not None:
                self._debug('Shutting down data manager')
                self._data_manager.shutdown()

            print('Goodbye!')

        #
        # Framework API
        #

        def get_config(self, port):
            """
            Returns the configuration dictionary for the plugin
            running on the specified port.

            :param port: a port number associated with a loaded plugin
            :return: a plugin configuration dictionary
            """
            return self._global_config.get_plugin_config(port)

        def spawn(self, socket, config):
            """
            Spawns the plugin configured by 'config' with the provided
            (accepted) socket.

            :param socket: an open, accepted socket returned by
                           socket.accept()
            :param config: the plugin configuration dictionary describing
                           the plugin to spawn
            :return: a reference to the plugin that was spawned
            :raises KeyError: if no plugin class was imported for config['port']
            """
            # TODO: raise a clearer exception if the plugin class is not found
            plugin_class = self._plugin_imports[config['port']]
            plugin = plugin_class(socket, config, self)
            plugin.start()
            self._running_plugins_list.append(plugin)
            return plugin

        def insert_data(self, data):
            """
            Inserts the provided data into the data queue so that it can
            be pushed to the database.

            :param data: data object to add to the database
            """
            self._data_manager.insert_data(data)

        def plugin_stopped(self, plugin):
            """
            Tells the framework that the specified plugin has stopped
            running and doesn't need to be shutdown explicitly on program
            exit.

            :param plugin: a reference to a plugin
            """
            # During shutdown the list is being iterated; leave it untouched.
            if self._shutting_down:
                return

            self._running_plugins_list.remove(plugin)
示例#11
0
class database_test(unittest.TestCase):
    """Tests for database.Database against the fixture configs in tests/database/test_config."""

    def setUp(self):
        self.test_db_dir = '/tests/database/test_database'
        self.test_db_file = '/tests/database/test_database/honeyDB.sqlite'
        # Remove the test database directory after every test, even one that
        # fails part-way. A leftover directory would break later tests, e.g.
        # the os.mkdir() in test_database_create_db_dir_already_exists.
        # Cleanups run LIFO, so this runs after any cleanup a test registers.
        self.addCleanup(shutil.rmtree, os.getcwd() + self.test_db_dir, ignore_errors=True)
        # test configuration files
        self.plugins_config_file = 'tests/database/test_config/plugins.cfg'
        self.plugins_config_diff_file = 'tests/database/test_config/plugins_diff.cfg'
        self.plugins_config_diff_table_file = 'tests/database/test_config/plugins_diff_table.cfg'
        self.global_config_file = 'tests/database/test_config/global.cfg'
        # create global config instance
        self.gci = GlobalConfig(self.plugins_config_file,self.global_config_file)
        self.gci.read_global_config()
        self.gci.read_plugin_config()

    def _load_global_config(self, plugins_config_file):
        """Recreate the GlobalConfig from the given plugin config file and read both configs."""
        # The third argument (True) is what this suite uses to replace the
        # config instance; presumably GlobalConfig is shared process-wide.
        self.gci = GlobalConfig(plugins_config_file, self.global_config_file, True)
        self.gci.read_global_config()
        self.gci.read_plugin_config()

    # patch the Logger new method so that it doesn't create the log file, we do not need to test that the logging
    # works just that the log statements are called.
    @patch.object(Logger,'__new__')
    def test_database_init(self,log):
        db = database.Database()
        self.assertIsInstance(db.global_config,GlobalConfig._GlobalConfig)
        self.assertTrue(log.called)

    @patch.object(Logger,'__new__')
    def test_database_create_default_database(self,log):
        db = database.Database()
        db.create_default_database()
        validator = datavalidator.DataValidator()
        # check that the directory and the database file exist
        self.assertTrue(os.path.isdir(os.getcwd() + self.test_db_dir))
        self.assertTrue(os.path.isfile(os.getcwd() + self.test_db_file))
        # get the table names from the database
        schema_table_list = validator.get_tables()
        # get the user defined tables from the configuration file
        config_table_list = util.get_config_table_list(self.gci.get_ports(),
                                                       self.gci.get_plugin_dictionary())
        # the non user defined tables must exist
        self.assertIn('p0f', schema_table_list)
        self.assertIn('ipInfo', schema_table_list)
        self.assertIn('sessions', schema_table_list)
        # the user defined tables are a proper subset of the tables in the database schema
        self.assertLess(set(config_table_list), set(schema_table_list))

    @patch.object(Logger,'__new__')
    def test_database_create_db_dir(self,log):
        db = database.Database()
        db.create_db_dir()
        self.assertTrue(os.path.isdir(os.getcwd() + self.test_db_dir))
        self.assertTrue(log.called)

    @patch.object(Logger,'__new__')
    def test_database_create_db_dir_already_exists(self,log):
        os.mkdir(os.getcwd() + self.test_db_dir)
        self.assertTrue(os.path.isdir(os.getcwd() + self.test_db_dir))
        db = database.Database()
        # only count log calls made by create_db_dir itself
        log.reset_mock()
        db.create_db_dir()
        self.assertFalse(log.called)

    @patch.object(Logger,'__new__')
    def test_database_create_db(self,log):
        db = database.Database()
        db.create_db_dir()
        db.create_db()
        self.assertTrue(os.path.isfile(os.getcwd() + self.test_db_file))
        self.assertTrue(log.called)

    @patch.object(Logger,'__new__')
    def test_database_update_schema(self,log):
        db = database.Database()
        db.create_db_dir()
        db.create_db()
        util.run_db_scripts(self.gci)
        db.update_schema()
        validator = datavalidator.DataValidator()
        schema = validator.get_schema()

        self.assertEqual(schema['test_http'][5][1], 'command')
        self.assertEqual(schema['test_http2'][6][1], 'path')
        self.assertEqual(len(schema['test_telnet']), 7)

        # Restore the normal plugin config even if an assertion below fails,
        # so later tests do not run against the differing-column config.
        self.addCleanup(self._load_global_config, self.plugins_config_file)
        # set global config instance to the differing column config file
        self._load_global_config(self.plugins_config_diff_file)

        db2 = database.Database()
        db2.update_schema()
        validator2 = datavalidator.DataValidator()

        schema2 = validator2.get_schema()

        self.assertEqual(schema2['test_http'][5][1], 'unit_test_data_1')
        self.assertEqual(schema2['test_http2'][6][1], 'unit_test_data_2')
        self.assertEqual(len(schema2['test_telnet']), 8)
        self.assertEqual(schema2['test_telnet'][7][1], 'unit_test_data_3')
示例#12
0
class dataqueue_test(unittest.TestCase):

    def setUp(self):
        self.test_db_dir = '/tests/database/test_database'
        self.test_db_file = '/tests/database/test_database/honeyDB.sqlite'
        # Remove the test database directory after every test, even one that
        # fails before reaching its own rmtree. ignore_errors makes this a
        # no-op when the test already deleted the directory itself.
        self.addCleanup(shutil.rmtree, os.getcwd() + self.test_db_dir, ignore_errors=True)
        # test configuration files
        self.plugins_config_file = 'tests/database/test_config/plugins.cfg'
        self.plugins_config_diff_file = 'tests/database/test_config/plugins_diff.cfg'
        self.plugins_config_diff_table_file = 'tests/database/test_config/plugins_diff_table.cfg'
        self.global_config_file = 'tests/database/test_config/global.cfg'
        # create global config instance
        self.gci = GlobalConfig(self.plugins_config_file,self.global_config_file)
        self.gci.read_global_config()
        self.gci.read_plugin_config()

    @patch.object(Logger,'__new__')
    def test_dataqueue_init(self,log):
        # Build the default database, then check that a new DataQueue is
        # backed by a standard queue.Queue.
        database.Database().create_default_database()
        data_queue = dataqueue.DataQueue()
        self.assertIsInstance(data_queue.dataQueue, queue.Queue)
        db_dir = os.getcwd() + self.test_db_dir
        shutil.rmtree(db_dir)

    @patch.object(Logger,'__new__')
    def test_insert_into_dataqueue(self,log):
        """Two inserted rows are queued and come back in insertion (FIFO) order."""
        insert_dict = {'test_telnet':{'session':'abc',
                                      'eventDateTime':'02-03-2016 04:15:37.037',
                                      'peerAddress':'41.26.134.3',
                                      'localAddress':'192.168.0.51',
                                      'input_type':'test',
                                      'user_input':'test input'}}
        insert_dict2 = {'test_telnet':{'session':'abcd',
                                       'eventDateTime':'02-04-2016 04:15:37.037',
                                       'peerAddress':'41.27.234.3',
                                       'localAddress':'192.168.0.42',
                                       'input_type':'test1',
                                       'user_input':'test input1'}}
        db = database.Database()
        db.create_default_database()
        dq = dataqueue.DataQueue()
        self.assertTrue(dq.check_empty())
        dq.insert_into_data_queue(insert_dict)
        self.assertFalse(dq.check_empty())
        dq.insert_into_data_queue(insert_dict2)
        self.assertEqual(dq.dataQueue.qsize(), 2)
        # The previous assertTrue(insert_dict, result) treated `result` as the
        # failure message and always passed, so nothing was compared.
        # NOTE(review): assumes get_next_item() returns the inserted dict
        # unchanged; confirm against DataQueue.
        result = dq.get_next_item()
        self.assertEqual(insert_dict, result)
        self.assertEqual(dq.dataQueue.qsize(), 1)
        result2 = dq.get_next_item()
        self.assertEqual(insert_dict2, result2)
        self.assertTrue(dq.check_empty())

        shutil.rmtree(os.getcwd() + self.test_db_dir)

    @patch.object(Logger,'__new__')
    def test_insert_into_dataqueue_bad_value(self,log):
        # A row carrying an extra column ('fake_column', presumably not part of
        # the test_telnet table) must be rejected by insert_into_data_queue.
        telnet_row = {
            'session': 'abc',
            'eventDateTime': '02-03-2016 04:15:37.037',
            'peerAddress': '41.26.134.3',
            'localAddress': '192.168.0.51',
            'input_type': 'test',
            'user_input': 'test input',
            'fake_column': 'bad data',
        }
        bad_insert = {'test_telnet': telnet_row}
        database.Database().create_default_database()
        data_queue = dataqueue.DataQueue()
        inserted = data_queue.insert_into_data_queue(bad_insert)
        self.assertFalse(inserted)

        db_dir = os.getcwd() + self.test_db_dir
        shutil.rmtree(db_dir)