Ejemplo n.º 1
0
 def setUp(self):
     """Build a fresh inventory from the JSON test fixture.

     Sets ``self.dinv`` to a ``DatasetInventoryMaster`` loaded from
     ``./test/test_files/json_data_file.json`` and ``self.inventory`` to
     its ``data_inventory`` attribute.
     """
     fixture_params = {'path': './test/test_files/json_data_file.json'}
     self.dinv = DatasetInventoryMaster()
     self.dinv.create_from_source('json', fixture_params)
     self.inventory = self.dinv.data_inventory
Ejemplo n.º 2
0
def csv(args):
    """ Import From CSV

    Build a ``DatasetInventoryMaster`` from a CSV source. The CLI option
    names in ``args`` are translated to the parameter names expected by
    ``create_from_source('csv', ...)``.
    """
    # source parameter name -> key in ``args``
    option_for_param = {
        'path': 'path',
        'image_path_col_list': 'image_fields',
        'capture_id_col': 'capture_id_field',
        'attributes_col_list': 'label_fields',
        'meta_col_list': 'meta_data_fields',
    }
    source_params = {param: args[option]
                     for param, option in option_for_param.items()}
    inventory_master = DatasetInventoryMaster()
    inventory_master.create_from_source('csv', source_params)
    return inventory_master
Ejemplo n.º 3
0
def panthera(args):
    """ Import From panthera

    Only ``args['path']`` is read; it is forwarded to
    ``DatasetInventoryMaster.create_from_source`` with source type
    ``'panthera'`` and the populated master is returned.
    """
    source_params = dict(path=args['path'])
    inventory_master = DatasetInventoryMaster()
    inventory_master.create_from_source('panthera', source_params)
    return inventory_master
Ejemplo n.º 4
0
def class_dir(args):
    """ Import From Class Dirs

    Only ``args['path']`` is read; it is forwarded to
    ``DatasetInventoryMaster.create_from_source`` with source type
    ``'image_dir'`` and the populated master is returned.
    """
    source_params = dict(path=args['path'])
    inventory_master = DatasetInventoryMaster()
    inventory_master.create_from_source('image_dir', source_params)
    return inventory_master
Ejemplo n.º 5
0
def json(args):
    """ Import From Json

    Only ``args['path']`` is read; it is forwarded to
    ``DatasetInventoryMaster.create_from_source`` with source type
    ``'json'`` and the populated master is returned.
    """
    source_params = dict(path=args['path'])
    inventory_master = DatasetInventoryMaster()
    inventory_master.create_from_source('json', source_params)
    return inventory_master
Ejemplo n.º 6
0
class DataInventoryTests(unittest.TestCase):
    """ Test Creation of Dataset Inventory """

    def setUp(self):
        """Load the JSON test fixture into a fresh DatasetInventoryMaster.

        ``self.inventory`` aliases ``self.dinv.data_inventory`` as it is
        right after loading. Tests that mutate the inventory re-read
        ``self.dinv.data_inventory`` afterwards, so they stay correct whether
        the mutating method edits the dict in place or rebinds it.
        """
        path = './test/test_files/json_data_file.json'
        source_type = 'json'
        params = {'path': path}
        self.dinv = DatasetInventoryMaster()
        self.dinv.create_from_source(source_type, params)
        self.inventory = self.dinv.data_inventory

    def testRemoveRecord(self):
        """remove_record removes the requested record id."""
        self.assertIn("single_species_standard", self.inventory)
        self.dinv.remove_record("single_species_standard")
        self.assertNotIn("single_species_standard",
                         self.dinv.data_inventory)

    def testRemoveRecordsWithLabel(self):
        """Records matching any of the given label name/value pairs go away.

        The fixture records 'is_elephant' and 'counts_is_12' are presumably
        the ones carrying class=elephant and counts=12 respectively.
        """
        label_names = ['class', 'counts']
        label_values = ['elephant', '12']
        self.assertIn("is_elephant", self.inventory)
        self.assertIn("counts_is_12", self.inventory)
        self.dinv.remove_records_with_label(label_names, label_values)
        inventory = self.dinv.data_inventory
        self.assertNotIn("is_elephant", inventory)
        self.assertNotIn("counts_is_12", inventory)

    def testKeepOnlyRecordsWithLabel(self):
        """Only records matching one of the label name/value pairs are kept."""
        label_names = ['class', 'counts']
        label_values = ['elephant', '12']
        self.assertIn("is_elephant", self.inventory)
        self.assertIn("single_species_standard", self.inventory)
        self.assertIn("counts_is_12", self.inventory)
        self.dinv.keep_only_records_with_label(label_names, label_values)
        inventory = self.dinv.data_inventory
        self.assertNotIn("single_species_standard", inventory)
        self.assertIn("is_elephant", inventory)
        self.assertIn("counts_is_12", inventory)

    def testConvertToTFRecordFormat(self):
        """A record is converted to the flat dict used for TFRecord encoding."""
        record_id = 'single_species_standard'
        # numeric label mapping must exist before conversion
        self.dinv._map_labels_to_numeric()
        record = self.dinv.data_inventory[record_id]
        tfr_dict = self.dinv._convert_record_to_tfr_format(record_id, record)
        self.assertEqual(tfr_dict['id'], 'single_species_standard')
        self.assertEqual(tfr_dict['n_images'], 3)
        self.assertEqual(tfr_dict["image_paths"], [
            "\\images\\4715\\all\\cat\\10296725_0.jpeg",
            "\\images\\4715\\all\\cat\\10296726_0.jpeg",
            "\\images\\4715\\all\\cat\\10296727_0.jpeg"
        ])
        self.assertIsInstance(tfr_dict["label_num/class"][0], int)
        self.assertEqual(tfr_dict["label_num/color_brown"], [1])
        self.assertEqual(tfr_dict["label_num/color_white"], [0])
        self.assertIsInstance(tfr_dict["label_num/counts"][0], int)
        self.assertEqual(tfr_dict["label/class"], ['cat'])
        self.assertEqual(tfr_dict["label/color_brown"], ['1'])
        self.assertEqual(tfr_dict["label/color_white"], ['0'])
        self.assertEqual(tfr_dict["label/counts"], ['1'])

    def testRemoveMissingLabelRecords(self):
        """Records lacking any label are dropped; complete ones survive."""
        self.assertIn("missing_counts_label", self.inventory)
        self.assertIn("counts_is_12", self.inventory)
        self.dinv._remove_records_with_any_missing_label()
        inventory = self.dinv.data_inventory
        self.assertNotIn("missing_counts_label", inventory)
        self.assertIn("counts_is_12", inventory)
def _build_arg_parser():
    """Return the argparse parser for the CREATE DATASET command line."""
    parser = argparse.ArgumentParser(prog='CREATE DATASET')
    parser.add_argument("-inventory",
                        type=str,
                        required=True,
                        help="path to inventory json file")
    parser.add_argument("-output_dir",
                        type=str,
                        required=True,
                        help="Directory to which TFRecord files are written")
    parser.add_argument(
        "-log_outdir",
        type=str,
        required=False,
        default=None,
        help="Directory to write logfiles to (defaults to output_dir)")
    parser.add_argument("-split_names",
                        nargs='+',
                        type=str,
                        help='split dataset into these named splits',
                        default=['train', 'val', 'test'],
                        required=False)
    parser.add_argument("-split_percent",
                        nargs='+',
                        type=float,
                        help='split dataset into these proportions',
                        default=[0.9, 0.05, 0.05],
                        required=False)
    parser.add_argument("-split_by_meta",
                        type=str,
                        help='split dataset by meta data field in inventory',
                        default=None,
                        required=False)
    parser.add_argument("-balanced_sampling_min",
                        default=False,
                        action='store_true',
                        required=False,
                        help="sample labels balanced to the least frequent "
                             "value")
    parser.add_argument("-balanced_sampling_label",
                        default=None,
                        type=str,
                        help='label used for balanced sampling')
    parser.add_argument("-remove_label_name",
                        nargs='+',
                        type=str,
                        default=None,
                        help='remove records with label names (a list) and '
                             'corresponding remove_label_value',
                        required=False)
    parser.add_argument("-remove_label_value",
                        nargs='+',
                        type=str,
                        default=None,
                        help='remove records with label value (a list) and '
                             'corresponding remove_label_name',
                        required=False)
    parser.add_argument("-keep_label_name",
                        nargs='+',
                        type=str,
                        default=None,
                        help='keep only records with at least one of the '
                             'label names (a list) and '
                             'corresponding keep_label_value',
                        required=False)
    parser.add_argument("-keep_label_value",
                        nargs='+',
                        type=str,
                        default=None,
                        help='keep only records with label value (a list) and '
                             'corresponding keep_label_name',
                        required=False)
    # Multi-label records are removed by default. A store_true flag whose
    # default is already True can never switch this off, so the original
    # flag is kept only for backward compatibility and
    # -keep_multi_label_records (same dest) is the way to disable removal.
    parser.add_argument("-remove_multi_label_records",
                        default=True,
                        action='store_true',
                        required=False,
                        help="whether to remove records with more than one "
                             "observation (multi-label) which is not currently "
                             "supported in model training (default: on)")
    parser.add_argument("-keep_multi_label_records",
                        dest='remove_multi_label_records',
                        default=True,
                        action='store_false',
                        required=False,
                        help="keep records with more than one observation "
                             "(disables -remove_multi_label_records)")
    parser.add_argument("-image_root_path",
                        type=str,
                        default=None,
                        help='Root path of all images - will be appended to '
                             'the image paths stored in the dataset inventory',
                        required=False)
    parser.add_argument("-image_save_side_max",
                        type=int,
                        default=500,
                        required=False,
                        help="aspect preserving resizing of images such that "
                             "the larger side of each image has that "
                             "many pixels, typically at least 330 "
                             "(depending on the model architecture)")
    parser.add_argument("-overwrite",
                        default=False,
                        action='store_true',
                        required=False,
                        help="whether to overwrite existing tfr files")
    parser.add_argument("-write_tfr_in_parallel",
                        default=False,
                        action='store_true',
                        required=False,
                        help="whether to write tfrecords in parallel if more "
                             "than one is created (preferably use "
                             "'process_images_in_parallel')")
    parser.add_argument("-process_images_in_parallel",
                        default=False,
                        action='store_true',
                        required=False,
                        help="whether to process images in parallel "
                             "(only if 'write_tfr_in_parallel' is false)")
    parser.add_argument("-process_images_in_parallel_size",
                        type=int,
                        default=320,
                        required=False,
                        help="if processing images in parallel - how many per "
                             "process, this can influence memory requirements")
    parser.add_argument("-processes_images_in_parallel_n_processes",
                        type=int,
                        default=4,
                        required=False,
                        help="if processing images in parallel - how many "
                             "processes to use (default 4)")
    parser.add_argument("-max_records_per_file",
                        type=int,
                        default=5000,
                        required=False,
                        help="The max number of records per TFRecord file. "
                             "Multiple files are generated if the size of "
                             "the dataset exceeds this value. It is "
                             "recommended to use large values (default 5000)")
    return parser


def main():
    """Create TFRecord dataset splits from a JSON dataset inventory.

    Parses the command line and configures logging. It then loads the
    inventory and applies the requested record filters before splitting it,
    either by a meta-data column or randomly, with optional balanced
    sampling. Finally it writes the label mapping to
    ``<output_dir>/label_mapping.json`` and the TFRecord files for each
    split.

    Raises:
        ValueError: if remove_label_name / keep_label_name is given without
            its matching value list, if balanced_sampling_min is set without
            balanced_sampling_label, or if split_names and split_percent
            differ in length when splitting randomly.
    """
    import os  # local import: used only to build the label-mapping path

    # Parse command line arguments
    args = vars(_build_arg_parser().parse_args())

    # Configure Logging
    if args['log_outdir'] is None:
        args['log_outdir'] = args['output_dir']

    setup_logging(log_output_path=args['log_outdir'])

    logger = logging.getLogger(__name__)

    print("Using arguments:")
    for k, v in args.items():
        print("Arg: %s: %s" % (k, v))

    # Validate option combinations before the inventory is loaded so a
    # misconfigured run fails fast.
    if (args['remove_label_name'] is not None
            and args['remove_label_value'] is None):
        raise ValueError('if remove_label_name is specified, '
                         'remove_label_value needs to be specified')
    if (args['keep_label_name'] is not None
            and args['keep_label_value'] is None):
        raise ValueError('if keep_label_name is specified, '
                         'keep_label_value needs to be specified')
    # Required in both the meta-data and the random split paths, since both
    # pass the label on as split_label_min.
    if (args['balanced_sampling_min']
            and args['balanced_sampling_label'] is None):
        raise ValueError('balanced_sampling_label must be specified if '
                         'balanced_sampling_min is set to true')
    # split_names / split_percent are only consumed by the random splits
    if (args['split_by_meta'] is None
            and len(args['split_names']) != len(args['split_percent'])):
        raise ValueError('split_names and split_percent must have the same '
                         'number of entries')

    # Create Dataset Inventory
    params = {'path': args['inventory']}
    dinv = DatasetInventoryMaster()
    dinv.create_from_source('json', params)

    # Remove multi-label subjects
    if args['remove_multi_label_records']:
        dinv.remove_multi_label_records()

    # Remove specific labels
    if args['remove_label_name'] is not None:
        dinv.remove_records_with_label(
            label_name_list=args['remove_label_name'],
            label_value_list=args['remove_label_value'])

    # Keep only specific labels
    if args['keep_label_name'] is not None:
        dinv.keep_only_records_with_label(
            label_name_list=args['keep_label_name'],
            label_value_list=args['keep_label_value'])

    # Log Statistics
    dinv.log_stats()

    # Split by meta-data column (optionally with balanced sampling)
    if args['split_by_meta'] is not None:
        logger.debug("Splitting by metadata %s", args['split_by_meta'])
        if args['balanced_sampling_min']:
            splitted = dinv.split_inventory_by_meta_data_column_and_balanced_sampling(
                meta_colum=args['split_by_meta'],
                split_label_min=args['balanced_sampling_label'])
            logger.debug("Balanced sampling using %s",
                         args['balanced_sampling_label'])
        else:
            splitted = dinv.split_inventory_by_meta_data_column(
                meta_colum=args['split_by_meta'])

    # Random split with balanced sampling
    elif args['balanced_sampling_min']:
        logger.debug("Splitting by random balanced sampling")
        splitted = dinv.split_inventory_by_random_splits_with_balanced_sample(
            split_label_min=args['balanced_sampling_label'],
            split_names=args['split_names'],
            split_percent=args['split_percent'])

    # Split without balanced sampling
    else:
        logger.debug("Splitting randomly")
        splitted = dinv.split_inventory_by_random_splits(
            split_names=args['split_names'],
            split_percent=args['split_percent'])

    # Log all the splits to create
    for i, split_name in enumerate(splitted.keys()):
        logger.info("Created split %s - %s", i, split_name)

    # Log Statistics for different splits
    for split_name, split_data in splitted.items():
        logger.debug("Stats for Split %s", split_name)
        split_data.log_stats(debug_only=True)

    # Write Label Mappings. os.path.join works whether or not output_dir
    # ends with a path separator (plain concatenation did not).
    out_label_mapping = os.path.join(args['output_dir'], 'label_mapping.json')
    dinv.export_label_mapping(out_label_mapping)

    # Write TFrecord files
    tfr_encoder_decoder = DefaultTFRecordEncoderDecoder()
    tfr_writer = DatasetWriter(tfr_encoder_decoder.encode_record)

    n_splits = len(splitted.keys())
    for counter, (split_name, split_data) in enumerate(splitted.items(),
                                                       start=1):
        logger.info("Starting to process %s (%s / %s)",
                    split_name, counter, n_splits)
        split_data.export_to_tfrecord(
            tfr_writer,
            args['output_dir'],
            file_prefix=split_name,
            image_root_path=args['image_root_path'],
            image_pre_processing_fun=read_resize_convert_to_jpeg,
            image_pre_processing_args={
                "max_side": args['image_save_side_max']
            },
            random_shuffle_before_save=True,
            overwrite_existing_files=args['overwrite'],
            max_records_per_file=args['max_records_per_file'],
            write_tfr_in_parallel=args['write_tfr_in_parallel'],
            process_images_in_parallel=args['process_images_in_parallel'],
            process_images_in_parallel_size=args[
                'process_images_in_parallel_size'],
            processes_images_in_parallel_n_processes=args[
                'processes_images_in_parallel_n_processes'])
    logger.info("Finished writing TFRecords")