コード例 #1
0
def check_enviorment_variable(database_key):
    """Raise if an environment variable required by ``database_key`` is undefined.

    For a single database, the environment variable that
    ``face_databases_path_correspondences`` associates with ``database_key`` is
    checked. For ``'aggregate-database'`` the variable of every sub-database is
    checked, in the fixed order listed below.

    Args:
        database_key: database name, e.g. ``'oulu-npu'`` or ``'aggregate-database'``.

    Raises:
        EnvironmentError: for the first required variable that
            ``DatabasesPathChecker.check_if_environment_is_defined_for`` reports
            as not defined. The message names the variable and ``database_key``.
    """
    # NOTE: the misspelled function name is kept so existing callers still work.
    if database_key == 'aggregate-database':
        # The aggregate database is composed of all of these sub-databases.
        required_keys = ('casia-fasd', 'casia-surf', 'csmad', 'hkbu', 'msu-mfsd', 'oulu-npu',
                         'replay-attack', 'replay-mobile', 'rose-youtu', 'siw', '3dmad', 'uvad')
    else:
        required_keys = (database_key,)

    for key in required_keys:
        # Lookup is unguarded: an unknown key propagates whatever the mapping
        # raises (KeyError if it is a plain dict).
        environment_variable = face_databases_path_correspondences[key]
        if not DatabasesPathChecker.check_if_environment_is_defined_for(environment_variable):
            raise EnvironmentError(
                "{} must be set in order to run an experiment on the {} database".format(
                    environment_variable, database_key))
コード例 #2
0
class UnitTest3dmadDatabase(unittest.TestCase):
    """Unit tests for ``ThreedmadDatabase``.

    Tests marked with ``skipIf`` are skipped when the THREEDMAD_PATH environment
    variable is not defined. Every other test runs against a database built on
    a placeholder path in that case.
    """
    skip = not DatabasesPathChecker.check_if_environment_is_defined_for(
        "THREEDMAD_PATH")
    reason = "THREEDMAD_PATH has not been found. Impossible to run these tests"

    def setUp(self):
        # '/home' is only a placeholder root, used when THREEDMAD_PATH is unset.
        database_root = '/home' if self.skip else os.environ['THREEDMAD_PATH']
        self.database = ThreedmadDatabase(database_root)

        # Values each common label is allowed to take in this database.
        self.available_common_pais = [0, 8]
        self.available_common_capture_devices = [0]
        self.available_common_lightning = [0]
        self.available_common_face_resolution = [0]

    def _assert_grandtest_sizes(self, dict_by_subset):
        """Check the Train/Dev/Test entry counts of a per-subset dictionary."""
        for subset_name, expected_size in (('Train', 105), ('Dev', 75), ('Test', 75)):
            self.assertEqual(len(dict_by_subset[subset_name]), expected_size)

    def test_constructor_with_non_existing_path(self):
        self.assertRaises(IOError, lambda: ThreedmadDatabase('wrong_path'))

    def test_name_static_method(self):
        self.assertEqual(ThreedmadDatabase.name(), '3dmad')

    def test_is_a_collection_of_databases_static_method(self):
        self.assertFalse(ThreedmadDatabase.is_a_collection_of_databases())

    @unittest.skipIf(skip, reason)
    def test_get_all_accesses(self):
        self._assert_grandtest_sizes(self.database.get_all_accesses())

    def test_get_all_labels(self):
        self._assert_grandtest_sizes(self.database.get_all_labels())

    def test_get_ground_truth_protocol_grandtest(self):
        self._assert_grandtest_sizes(self.database.get_ground_truth('grandtest'))

    def test_common_labels_are_ok(self):
        allowed_by_label = {
            'common_pai': self.available_common_pais,
            'common_capture_device': self.available_common_capture_devices,
            'common_lightning': self.available_common_lightning,
            'common_face_resolution': self.available_common_face_resolution,
        }
        for labels_per_access in self.database.get_all_labels().values():
            for access_labels in labels_per_access.values():
                for label_name, allowed_values in allowed_by_label.items():
                    self.assertIn(access_labels[label_name], allowed_values)
コード例 #3
0
class UnitTestOuluNpuDatabase(unittest.TestCase):
    """Unit tests for ``OuluNpuDatabase``.

    ``setUp`` only builds the database when OULU_NPU_PATH is defined, so every
    test that reads ``self.database`` is marked with ``skipIf``.
    """
    skip = not DatabasesPathChecker.check_if_environment_is_defined_for(
        "OULU_NPU_PATH")
    reason = "OULU_NPU_PATH has not been found. Impossible to run these tests"

    def setUp(self):
        # Without OULU_NPU_PATH neither the database nor the allowed label sets
        # are created; only tests that do not use them run in that case.
        if not self.skip:
            self.database = OuluNpuDatabase(os.environ['OULU_NPU_PATH'])
            self.available_common_pais = [0, 3, 5, 6]
            self.available_common_capture_devices = [2, 3]
            self.available_common_lightning = [0, 1]
            self.available_common_face_resolution = [0, 1, 2]

    def test_constructor_with_non_existing_path(self):
        self.assertRaises(IOError, lambda: OuluNpuDatabase('wrong_path'))

    def test_name_static_method(self):
        self.assertEqual(OuluNpuDatabase.name(), 'oulu-npu')

    def test_is_a_collection_of_databases_static_method(self):
        self.assertFalse(OuluNpuDatabase.is_a_collection_of_databases())

    @unittest.skipIf(skip, reason)
    def test_get_all_accesses(self):
        dict_all_accesses = self.database.get_all_accesses()

        self.assertEqual(len(dict_all_accesses['Train']), 1800)
        self.assertEqual(len(dict_all_accesses['Dev']), 1350)
        self.assertEqual(len(dict_all_accesses['Test']), 1800)

    @unittest.skipIf(skip, reason)
    def test_get_all_labels(self):
        dict_labels = self.database.get_all_labels()

        self.assertEqual(len(dict_labels['Train']), 1800)
        self.assertEqual(len(dict_labels['Dev']), 1350)
        self.assertEqual(len(dict_labels['Test']), 1800)

    @unittest.skipIf(skip, reason)
    def test_get_ground_truth_protocol_grandtest(self):
        dict_ground_truth = self.database.get_ground_truth('grandtest')

        self.assertEqual(len(dict_ground_truth['Train']), 1800)
        self.assertEqual(len(dict_ground_truth['Dev']), 1350)
        self.assertEqual(len(dict_ground_truth['Test']), 1800)

    @unittest.skipIf(skip, reason)
    def test_get_ground_truth_protocol_1(self):
        dict_ground_truth = self.database.get_ground_truth('Protocol_1')

        self.assertEqual(len(dict_ground_truth['Train']), 1200)
        self.assertEqual(len(dict_ground_truth['Dev']), 900)
        self.assertEqual(len(dict_ground_truth['Test']), 600)

    @unittest.skipIf(skip, reason)
    def test_get_ground_truth_protocol_2(self):
        dict_ground_truth = self.database.get_ground_truth('Protocol_2')

        self.assertEqual(len(dict_ground_truth['Train']), 1080)
        self.assertEqual(len(dict_ground_truth['Dev']), 810)
        self.assertEqual(len(dict_ground_truth['Test']), 1080)

    @unittest.skipIf(skip, reason)
    def test_get_ground_truth_protocol_3(self):
        # Protocol_3_1 ... Protocol_3_6 share the same subset sizes. subTest
        # reports each variant separately instead of stopping at the first
        # failure without saying which protocol broke.
        for i in range(6):
            protocol_tag = 'Protocol_3_' + str(i + 1)
            with self.subTest(protocol=protocol_tag):
                dict_ground_truth = self.database.get_ground_truth(protocol_tag)

                self.assertEqual(len(dict_ground_truth['Train']), 1500)
                self.assertEqual(len(dict_ground_truth['Dev']), 1125)
                self.assertEqual(len(dict_ground_truth['Test']), 300)

    @unittest.skipIf(skip, reason)
    def test_get_ground_truth_protocol_4(self):
        # Same structure as protocol 3: six variants, each checked in its own subTest.
        for i in range(6):
            protocol_tag = 'Protocol_4_' + str(i + 1)
            with self.subTest(protocol=protocol_tag):
                dict_ground_truth = self.database.get_ground_truth(protocol_tag)

                self.assertEqual(len(dict_ground_truth['Train']), 600)
                self.assertEqual(len(dict_ground_truth['Dev']), 450)
                self.assertEqual(len(dict_ground_truth['Test']), 60)

    @unittest.skipIf(skip, reason)
    def test_get_ground_truth_protocol_4_no_loco(self):
        dict_ground_truth = self.database.get_ground_truth(
            'Protocol_4_no_loco')

        self.assertEqual(len(dict_ground_truth['Train']), 720)
        self.assertEqual(len(dict_ground_truth['Dev']), 540)
        self.assertEqual(len(dict_ground_truth['Test']), 360)

    @unittest.skipIf(skip, reason)
    def test_common_labels_are_ok(self):
        dict_all_labels = self.database.get_all_labels()

        # Subset and access names are not needed, only the label dictionaries.
        for subset_dict in dict_all_labels.values():
            for labels_dict in subset_dict.values():
                self.assertIn(labels_dict['common_pai'],
                              self.available_common_pais)
                self.assertIn(labels_dict['common_capture_device'],
                              self.available_common_capture_devices)
                self.assertIn(labels_dict['common_lightning'],
                              self.available_common_lightning)
                self.assertIn(labels_dict['common_face_resolution'],
                              self.available_common_face_resolution)
class UnitTestReplayMobileDatabase(unittest.TestCase):
    """Unit tests for ``ReplayMobileDatabase``.

    Tests marked with ``skipIf`` are skipped when REPLAY_MOBILE_PATH is not
    defined. The remaining tests use a database built on a placeholder path
    in that case.
    """
    skip = not DatabasesPathChecker.check_if_environment_is_defined_for(
        "REPLAY_MOBILE_PATH")
    reason = "REPLAY_MOBILE_PATH has not been found. Impossible to run these tests"

    def setUp(self):
        # '/home' is only a placeholder root, used when REPLAY_MOBILE_PATH is unset.
        database_root = '/home' if self.skip else os.environ['REPLAY_MOBILE_PATH']
        self.database = ReplayMobileDatabase(database_root)

        # Values each common label is allowed to take in this database.
        self.available_common_pais = [0, 1, 6]
        self.available_common_capture_devices = [2, 3]
        self.available_common_lightning = [0, 1]
        self.available_common_face_resolution = [0, 1, 2]

    def _assert_grandtest_sizes(self, dict_by_subset):
        """Check the Train/Dev/Test entry counts of a per-subset dictionary."""
        for subset_name, expected_size in (('Train', 312), ('Dev', 416), ('Test', 302)):
            self.assertEqual(len(dict_by_subset[subset_name]), expected_size)

    def test_constructor_with_non_existing_path(self):
        self.assertRaises(IOError, lambda: ReplayMobileDatabase('wrong_path'))

    def test_name_static_method(self):
        self.assertEqual(ReplayMobileDatabase.name(), 'replay-mobile')

    def test_is_a_collection_of_databases_static_method(self):
        self.assertFalse(ReplayMobileDatabase.is_a_collection_of_databases())

    @unittest.skipIf(skip, reason)
    def test_get_all_accesses(self):
        self._assert_grandtest_sizes(self.database.get_all_accesses())

    @unittest.skipIf(skip, reason)
    def test_get_enrolment_accesses(self):
        enrolment_accesses = self.database.get_enrolment_access()
        self.assertEqual(len(enrolment_accesses), 160)

    def test_get_all_labels(self):
        self._assert_grandtest_sizes(self.database.get_all_labels())

    def test_get_ground_truth_protocol_grandtest(self):
        self._assert_grandtest_sizes(self.database.get_ground_truth('grandtest'))

    def test_common_labels_are_ok(self):
        allowed_by_label = {
            'common_pai': self.available_common_pais,
            'common_capture_device': self.available_common_capture_devices,
            'common_lightning': self.available_common_lightning,
            'common_face_resolution': self.available_common_face_resolution,
        }
        for labels_per_access in self.database.get_all_labels().values():
            for access_labels in labels_per_access.values():
                for label_name, allowed_values in allowed_by_label.items():
                    self.assertIn(access_labels[label_name], allowed_values)
コード例 #5
0
class UnitTestAggregateDatabase(unittest.TestCase):
    """Unit tests for ``AggregateDatabase``.

    The aggregate database is built from the root paths of all twelve
    sub-databases. The data-dependent tests therefore run only when every
    corresponding environment variable is defined.
    """
    # Environment variable holding the root path of each sub-database. Its
    # insertion order is the order in which ``base_paths`` is built.
    environment_variable_by_database = {
        '3dmad': 'THREEDMAD_PATH',
        'casia-fasd': 'CASIA_FASD_PATH',
        'casia-surf': 'CASIA_SURF_PATH',
        'csmad': 'CSMAD_PATH',
        'hkbu': 'HKBU_PATH',
        'msu-mfsd': 'MSU_MFSD_PATH',
        'oulu-npu': 'OULU_NPU_PATH',
        'replay-attack': 'REPLAY_ATTACK_PATH',
        'replay-mobile': 'REPLAY_MOBILE_PATH',
        'rose-youtu': 'ROSE_YOUTU_PATH',
        'siw': 'SIW_PATH',
        'uvad': 'UVAD_PATH',
    }

    # Skip when ANY variable is missing. setUp reads every one of them from
    # os.environ, so a single missing variable would raise KeyError and error
    # every test, including those that do not need the data.
    skip = not all(
        DatabasesPathChecker.check_if_environment_is_defined_for(environment_variable)
        for environment_variable in environment_variable_by_database.values())

    reason = "Some of {} have not been found. Impossible to run these tests".format(
        ", ".join(environment_variable_by_database.values()))

    def setUp(self):
        # base_paths is only needed (and only built) by the skipIf-marked tests.
        if not self.skip:
            self.base_paths = {
                database: os.environ[environment_variable]
                for database, environment_variable
                in self.environment_variable_by_database.items()
            }

    def test_constructor_with_non_existing_path(self):
        self.assertRaises(IOError, lambda: AggregateDatabase('wrong_path'))

    def test_name_static_method(self):
        self.assertEqual(AggregateDatabase.name(), 'aggregate-database')

    def test_is_a_collection_of_databases_static_method(self):
        self.assertTrue(AggregateDatabase.is_a_collection_of_databases())

    @unittest.skipIf(skip, reason)
    def test_should_add_new_custom_protocol_correctly(self):
        database = AggregateDatabase(self.base_paths)
        database.set_new_custom_protocol(
            {"new_protocol": {
                "Train": None,
                "Test": None,
                "Dev": None
            }})
        self.assertIn("new_protocol", database.protocols)

    @unittest.skipIf(skip, reason)
    def test_get_all_accesses(self):
        database = AggregateDatabase(self.base_paths)
        dict_all_accesses = database.get_all_accesses()

        self.assertEqual(len(dict_all_accesses['All']), 28134)

    @unittest.skipIf(skip, reason)
    def test_get_all_labels(self):
        database = AggregateDatabase(self.base_paths)
        dict_all_labels = database.get_all_labels()
        grandtest_results = {'Train': 11125, 'Dev': 4215, 'Test': 12794}

        # Labels are grouped per subset and then per sub-database; the total
        # over all sub-databases must match the grandtest size of the subset.
        for subset, subset_labels in dict_all_labels.items():
            number_accesses_per_subset = sum(
                len(db_labels_dict) for db_labels_dict in subset_labels.values())
            self.assertEqual(grandtest_results[subset],
                             number_accesses_per_subset)

    @unittest.skipIf(skip, reason)
    def test_get_ground_truth_protocol_grandtest(self):
        database = AggregateDatabase(self.base_paths)
        dict_ground_truth = database.get_ground_truth('grandtest')

        self.assertEqual(len(dict_ground_truth['Train']), 11125)
        self.assertEqual(len(dict_ground_truth['Dev']), 4215)
        self.assertEqual(len(dict_ground_truth['Test']), 12794)