Пример #1
0
    def test_get_ground_truth_protocol_grandtest(self):
        """Check the number of accesses per subset in the 'grandtest' ground truth."""
        expected_sizes = (('Train', 11125), ('Dev', 4215), ('Test', 12794))
        ground_truth = AggregateDatabase(self.base_paths).get_ground_truth('grandtest')

        # Checked in Train -> Dev -> Test order; the first mismatch fails the test.
        for subset_name, expected_size in expected_sizes:
            self.assertEqual(len(ground_truth[subset_name]), expected_size)
Пример #2
0
 def test_should_add_new_custom_protocol_correctly(self):
     """A protocol registered through set_new_custom_protocol must appear in database.protocols."""
     database = AggregateDatabase(self.base_paths)
     custom_protocol = {"new_protocol": {"Train": None, "Test": None, "Dev": None}}
     database.set_new_custom_protocol(custom_protocol)
     self.assertTrue("new_protocol" in database.protocols)
Пример #3
0
    def test_get_all_labels(self):
        """Check the per-subset access counts returned by get_all_labels().

        For every subset, the labels of all databases are summed and compared
        against the grandtest sizes. The test also requires exactly the
        expected subsets to be present, so it cannot pass vacuously when
        get_all_labels() returns an empty dict.
        """
        database = AggregateDatabase(self.base_paths)
        dict_all_labels = database.get_all_labels()
        grandtest_results = {'Train': 11125, 'Dev': 4215, 'Test': 12794}

        # Without this check an empty result would skip the loop below and pass.
        self.assertEqual(set(grandtest_results), set(dict_all_labels))

        for subset, subset_labels in dict_all_labels.items():
            # subset_labels maps each database to its own labels dict; the
            # subset size is the total over all of them.
            number_accesses_per_subset = sum(
                len(db_labels_dict) for db_labels_dict in subset_labels.values())
            self.assertEqual(grandtest_results[subset],
                             number_accesses_per_subset)
Пример #4
0
def main():
    """Print a table summarising every supported database for the grandtest protocol."""
    databases = {
        '3dmad              (grandtest)': ThreedmadDatabase.info(),
        'aggregate-database (grandtest)': AggregateDatabase.info(),
        'casia-fasd         (grandtest)': CasiaFasdDatabase.info(),
        'casia-surf         (grandtest)': CasiaSurfDatabase.info(),
        'csmad              (grandtest)': CsmadDatabase.info(),
        'hkbu               (grandtest)': HkbuDatabase.info(),
        'msu_mfsd           (grandtest)': MsuMfsdDatabase.info(),
        'oulu_npu           (grandtest)': OuluNpuDatabase.info(),
        'replay-attack      (grandtest)': ReplayAttackDatabase.info(),
        'replay-mobile      (grandtest)': ReplayMobileDatabase.info(),
        'rose-youtu         (grandtest)': RoseYoutuDatabase.info(),
        'siw                (grandtest)': SiwDatabase.info(),
        'uvad               (grandtest)': UvadDatabase.info()
    }

    # info() keys read for each row, in column order (after the name column).
    info_columns = ('users', "Train videos", "Dev videos", "Test videos")
    table = [[name] + [databases[name][column] for column in info_columns]
             for name in sorted(databases)]

    headers = ["Database", "Number of Users", "Train videos", "Dev videos",
               "Test videos"]
    print("bob.gradiant.face.databases:")
    print(tabulate(table, headers, tablefmt="fancy_grid"))
Пример #5
0
def main():
    """Export every AggregateDatabase protocol as text files under 'gradgpad_protocols'.

    For each protocol a sub-directory named after it is created. One file
    'protocol_<protocol>_<subset>.txt' is written per subset returned by
    get_ground_truth_list(). Prints the help text when has_args() reports
    no arguments.
    """
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--json-file', type=str, dest='json_file',
                        help='json file where the ROOT_PATH of each dataset is defined', required=True)
    args = parser.parse_args()

    if not has_args(args):
        parser.print_help()
        return

    base_output_path = 'gradgpad_protocols'
    os.makedirs(base_output_path, exist_ok=True)

    # Marked "temporary" upstream: dataset root paths come from the JSON file.
    export_database_paths_from_file(args.json_file)
    # Returns an object of the AggregateDatabase class.
    aggregate_database = get_database_from_key('aggregate-database')

    # Snapshot the protocol names before writing anything.
    protocol_names = list(AggregateDatabase.get_available_protocols())

    for protocol in protocol_names:
        protocol_output_path = '{}/{}'.format(base_output_path, protocol)
        os.makedirs(protocol_output_path, exist_ok=True)

        ground_truth = aggregate_database.get_ground_truth_list(protocol)
        for subset, accesses in ground_truth.items():
            write_dict_to_file(
                '{}/protocol_{}_{}.txt'.format(protocol_output_path, protocol, subset.lower()),
                accesses)
Пример #6
0
def main():
    """Command-line entry point to list or print AggregateDatabase protocols.

    * ``-l/--list`` prints the aggregate-database summary table, followed by
      the table of available protocols.
    * ``-p/--protocol NAME`` prints that protocol's definition as indented
      JSON, with ``null`` rendered as ``None``.
    * ``-nsd/--no-show-datasets``, together with ``-p``, omits the
      ``"datasets"`` entry of the Train/Dev/Test subsets from that output.

    If has_args() reports no arguments, the help text is printed instead.

    Raises:
        ValueError: if the requested protocol is not an available one.
    """
    # Query the protocols once and derive the names from the same dict. The
    # names are kept as a plain list so the help/error text shows "[...]"
    # instead of "dict_keys([...])".
    available_protocols = AggregateDatabase.get_available_protocols()
    available_protocols_list = list(available_protocols)
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('-l',
                        '--list',
                        dest='show_list',
                        action='store_true',
                        help='list of available protocols')
    parser.add_argument(
        '-p',
        '--protocol',
        dest='protocol',
        help='It will show you the dict of a available protocol {}'.format(
            available_protocols_list))
    parser.add_argument(
        '-nsd',
        '--no-show-datasets',
        dest='no_show_datasets',
        action='store_true',
        help='it will not show dataset info (on the protocol dict)')

    args = parser.parse_args()

    if not has_args(args):
        parser.print_help()
        return

    if args.show_list:
        show_aggregate_table()
        show_available_protocols_on_aggregate_database_table(
            available_protocols)
        return

    if args.protocol not in available_protocols:
        raise ValueError("protocol \"{}\" is not available. "
                         "Try with {}".format(
                             args.protocol, available_protocols_list))
    print("{}:".format(args.protocol))

    protocol = available_protocols[args.protocol]
    if args.no_show_datasets:
        # Work on shallow copies so the dict returned by
        # get_available_protocols() is not mutated. Tolerate subsets that
        # have no "datasets" entry instead of raising KeyError.
        protocol = dict(protocol)
        for subset in ["Train", "Dev", "Test"]:
            subset_info = dict(protocol[subset])
            subset_info.pop("datasets", None)
            protocol[subset] = subset_info
    print(json.dumps(protocol, indent=2).replace("null", "None"))
def get_available_protocols():
    """Return the sorted names of the AggregateDatabase protocols to use.

    The "cross-dataset-test-uvad" protocol is excluded unless the
    ``USE_UVAD`` environment variable is set. Only its presence is checked,
    not its value.
    """
    available_protocols = list(
        AggregateDatabase.get_available_protocols().keys())

    # Bug fix: the guard tested for the UVAD protocol but removed the
    # CASIA-SURF one, which dropped the wrong protocol and raised ValueError
    # whenever CASIA-SURF was absent.
    if ('USE_UVAD' not in os.environ
            and "cross-dataset-test-uvad" in available_protocols):
        available_protocols.remove("cross-dataset-test-uvad")

    return sorted(available_protocols)
Пример #8
0
    def test_should_run_well_with_one_pai_mask(self):
        """Accesses kept by the one-PAI 'mask' protocol must have an allowed common_pai value."""
        pai = 'mask'
        parsed_datasets = AggregateDatabase.get_parsed_databases()
        protocol = get_one_pai_protocol(parsed_datasets, pai)
        allowed_pai = [0, 7, 8, 9]

        filtered_labels = filter_labels_by_protocol(protocol, self.dict_all_labels)

        for subset, subset_labels in filtered_labels.items():
            for basename, labels in subset_labels.items():
                # assertIn reports the offending value; the msg names the access.
                self.assertIn(labels["common_pai"], allowed_pai,
                              msg="{}/{}".format(subset, basename))
Пример #9
0
    def test_should_run_well_with_unseen_attack_mask(self):
        """The unseen-attack 'mask' protocol must keep 'mask' PAI codes out of Train/Dev.

        Every access in each subset must have a common_pai value from that
        subset's allowed list: 0-6 for Train and Dev, and 0, 7, 8, 9 for Test.
        """
        pai = 'mask'
        parsed_datasets = AggregateDatabase.get_parsed_databases()
        available_pais = COMMON_PAI_CATEGORISATION.keys()
        protocol = get_unseen_attack_protocol(parsed_datasets, pai, available_pais)

        allowed_pai = {"Train": [0, 1, 2, 3, 4, 5, 6],
                       "Dev": [0, 1, 2, 3, 4, 5, 6],
                       "Test": [0, 7, 8, 9]
                       }

        filtered_labels = filter_labels_by_protocol(protocol, self.dict_all_labels)

        for subset, subset_labels in filtered_labels.items():
            for basename, labels in subset_labels.items():
                # assertIn reports the offending value; the msg names the access.
                self.assertIn(labels["common_pai"], allowed_pai[subset],
                              msg="{}/{}".format(subset, basename))
Пример #10
0
def show_aggregate_table():
    """Print a one-row table summarising the aggregate database (grandtest)."""
    info = AggregateDatabase.info()
    row = ['aggregate-database (grandtest)', info['users'],
           info["Train videos"], info["Dev videos"], info["Test videos"]]

    headers = ["Database", "Number of Users", "Train videos", "Dev videos",
               "Test videos"]
    print(tabulate([row], headers, tablefmt="fancy_grid"))
 def test_should_not_throw_any_exception_for_available_protocols(self):
     """protocol_checker must accept every protocol definition AggregateDatabase exposes."""
     protocols = AggregateDatabase.get_available_protocols()
     for protocol_name in protocols:
         protocol_checker(protocols[protocol_name])
Пример #12
0
    def test_get_all_accesses(self):
        """get_all_accesses() must list 28134 accesses under the 'All' key."""
        all_accesses = AggregateDatabase(self.base_paths).get_all_accesses()['All']
        self.assertEqual(len(all_accesses), 28134)
Пример #13
0
 def test_is_a_collection_of_databases_static_method(self):
     """AggregateDatabase.is_a_collection_of_databases() must return a truthy value."""
     is_collection = AggregateDatabase.is_a_collection_of_databases()
     self.assertTrue(is_collection)
Пример #14
0
 def test_name_static_method(self):
     """AggregateDatabase.name() must return 'aggregate-database'."""
     expected_name = 'aggregate-database'
     self.assertEqual(AggregateDatabase.name(), expected_name)
Пример #15
0
 def test_constructor_with_non_existing_path(self):
     """Constructing AggregateDatabase with a non-existing path must raise IOError."""
     with self.assertRaises(IOError):
         AggregateDatabase('wrong_path')
Пример #16
0
 def setUp(self):
     """Load the aggregated-labels fixture and the available protocols for each test."""
     labels_fixture = TestResources.get_aggregated_database_all_dict_labels()
     self.dict_all_labels = labels_fixture
     protocols = AggregateDatabase.get_available_protocols()
     self.available_protocols = protocols