import argparse
import logging
import os
import unittest

# NOTE: the module paths of these project-internal imports are assumptions
# inferred from the names used below; adjust them to the actual package
# layout.
from data.inventory import DatasetInventoryMaster
from data.writer import DatasetWriter
from data.tfr_encoder_decoder import DefaultTFRecordEncoderDecoder
from data.image import read_resize_convert_to_jpeg
from data.utils import setup_logging


def csv(args):
    """ Import From CSV """
    params = {'path': args['path'],
              'image_path_col_list': args['image_fields'],
              'capture_id_col': args['capture_id_field'],
              'attributes_col_list': args['label_fields'],
              'meta_col_list': args['meta_data_fields']}
    dinv = DatasetInventoryMaster()
    dinv.create_from_source('csv', params)
    return dinv


def panthera(args):
    """ Import From Panthera """
    params = {'path': args['path']}
    dinv = DatasetInventoryMaster()
    dinv.create_from_source('panthera', params)
    return dinv


def class_dir(args):
    """ Import From Class Directories """
    params = {'path': args['path']}
    dinv = DatasetInventoryMaster()
    dinv.create_from_source('image_dir', params)
    return dinv


def json(args):
    """ Import From JSON """
    # NOTE: this function name shadows the built-in 'json' module within
    # this file
    params = {'path': args['path']}
    dinv = DatasetInventoryMaster()
    dinv.create_from_source('json', params)
    return dinv
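# The importers above share one pattern: build a 'params' dict for the
# source and delegate to DatasetInventoryMaster.create_from_source(). A
# minimal sketch of how a command-line front end might dispatch among them
# - the 'create_inventory' helper and its 'source_type' argument are
# hypothetical, not part of the original code:
def create_inventory(source_type, args):
    """ Dispatch to the importer matching source_type (hypothetical) """
    importers = {
        'csv': csv,
        'panthera': panthera,
        'image_dir': class_dir,
        'json': json,
    }
    if source_type not in importers:
        raise ValueError("unknown source_type: %s" % source_type)
    return importers[source_type](args)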
class DataInventoryTests(unittest.TestCase):
    """ Test Creation of Dataset Inventory """

    def setUp(self):
        path = './test/test_files/json_data_file.json'
        source_type = 'json'
        params = {'path': path}
        self.dinv = DatasetInventoryMaster()
        self.dinv.create_from_source(source_type, params)
        self.inventory = self.dinv.data_inventory

    def testRemoveRecord(self):
        self.assertIn("single_species_standard", self.inventory)
        self.dinv.remove_record("single_species_standard")
        self.assertNotIn("single_species_standard", self.inventory)

    def testRemoveRecordsWithLabel(self):
        label_names = ['class', 'counts']
        label_values = ['elephant', '12']
        self.assertIn("is_elephant", self.inventory)
        self.assertIn("counts_is_12", self.inventory)
        self.dinv.remove_records_with_label(label_names, label_values)
        self.assertNotIn("is_elephant", self.inventory)
        self.assertNotIn("counts_is_12", self.inventory)

    def testKeepOnlyRecordsWithLabel(self):
        label_names = ['class', 'counts']
        label_values = ['elephant', '12']
        self.assertIn("is_elephant", self.inventory)
        self.assertIn("single_species_standard", self.inventory)
        self.assertIn("counts_is_12", self.inventory)
        self.dinv.keep_only_records_with_label(label_names, label_values)
        self.assertNotIn("single_species_standard", self.inventory)
        self.assertIn("is_elephant", self.inventory)
        self.assertIn("counts_is_12", self.inventory)

    def testConvertToTFRecordFormat(self):
        # 'record_id' avoids shadowing the built-in id()
        record_id = 'single_species_standard'
        self.dinv._map_labels_to_numeric()
        record = self.inventory[record_id]
        tfr_dict = self.dinv._convert_record_to_tfr_format(record_id, record)
        self.assertEqual(tfr_dict['id'], 'single_species_standard')
        self.assertEqual(tfr_dict['n_images'], 3)
        self.assertEqual(tfr_dict["image_paths"], [
            "\\images\\4715\\all\\cat\\10296725_0.jpeg",
            "\\images\\4715\\all\\cat\\10296726_0.jpeg",
            "\\images\\4715\\all\\cat\\10296727_0.jpeg"
        ])
        self.assertIsInstance(tfr_dict["label_num/class"][0], int)
        self.assertEqual(tfr_dict["label_num/color_brown"], [1])
        self.assertEqual(tfr_dict["label_num/color_white"], [0])
        self.assertIsInstance(tfr_dict["label_num/counts"][0], int)
        self.assertEqual(tfr_dict["label/class"], ['cat'])
        self.assertEqual(tfr_dict["label/color_brown"], ['1'])
        self.assertEqual(tfr_dict["label/color_white"], ['0'])
        self.assertEqual(tfr_dict["label/counts"], ['1'])

    def testRemoveMissingLabelRecords(self):
        self.assertIn("missing_counts_label", self.inventory)
        self.assertIn("counts_is_12", self.inventory)
        self.dinv._remove_records_with_any_missing_label()
        self.assertNotIn("missing_counts_label", self.inventory)
        self.assertIn("counts_is_12", self.inventory)
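# Standard unittest entry point (assumed; not shown in the excerpt) so the
# test case above can be run directly, e.g. by executing this file or via
# `python -m unittest`:
if __name__ == '__main__':
    unittest.main()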
def main():
    # Parse command line arguments
    parser = argparse.ArgumentParser(prog='CREATE DATASET')
    parser.add_argument("-inventory", type=str, required=True,
                        help="path to inventory json file")
    parser.add_argument("-output_dir", type=str, required=True,
                        help="directory to which TFRecord files are written")
    parser.add_argument("-log_outdir", type=str, required=False, default=None,
                        help="directory to write logfiles to "
                             "(defaults to output_dir)")
    parser.add_argument("-split_names", nargs='+', type=str,
                        default=['train', 'val', 'test'], required=False,
                        help="split dataset into these named splits")
    parser.add_argument("-split_percent", nargs='+', type=float,
                        default=[0.9, 0.05, 0.05], required=False,
                        help="split dataset into these proportions")
    parser.add_argument("-split_by_meta", type=str, default=None,
                        required=False,
                        help="split dataset by meta data field in inventory")
    parser.add_argument("-balanced_sampling_min", default=False,
                        action='store_true', required=False,
                        help="sample labels balanced to the least frequent "
                             "value")
    parser.add_argument("-balanced_sampling_label", default=None, type=str,
                        help="label used for balanced sampling")
    parser.add_argument("-remove_label_name", nargs='+', type=str,
                        default=None, required=False,
                        help="remove records with label names (a list) and "
                             "corresponding remove_label_value")
    parser.add_argument("-remove_label_value", nargs='+', type=str,
                        default=None, required=False,
                        help="remove records with label values (a list) and "
                             "corresponding remove_label_name")
    parser.add_argument("-keep_label_name", nargs='+', type=str,
                        default=None, required=False,
                        help="keep only records with at least one of the "
                             "label names (a list) and corresponding "
                             "keep_label_value")
    parser.add_argument("-keep_label_value", nargs='+', type=str,
                        default=None, required=False,
                        help="keep only records with label values (a list) "
                             "and corresponding keep_label_name")
    # NOTE: with default=True and action='store_true' this flag is
    # effectively always on and cannot be disabled from the command line
    parser.add_argument("-remove_multi_label_records", default=True,
                        action='store_true', required=False,
                        help="whether to remove records with more than one "
                             "observation (multi-label), which is not "
                             "currently supported in model training")
    parser.add_argument("-image_root_path", type=str, default=None,
                        required=False,
                        help="root path of all images - will be prepended to "
                             "the image paths stored in the dataset "
                             "inventory")
    parser.add_argument("-image_save_side_max", type=int, default=500,
                        required=False,
                        help="aspect-preserving resizing of images such that "
                             "the larger side of each image has that many "
                             "pixels, typically at least 330 (depending on "
                             "the model architecture)")
    parser.add_argument("-overwrite", default=False, action='store_true',
                        required=False,
                        help="whether to overwrite existing tfr files")
    parser.add_argument("-write_tfr_in_parallel", default=False,
                        action='store_true', required=False,
                        help="whether to write tfrecords in parallel if more "
                             "than one is created (preferably use "
                             "'process_images_in_parallel')")
    parser.add_argument("-process_images_in_parallel", default=False,
                        action='store_true', required=False,
                        help="whether to process images in parallel "
                             "(only if 'write_tfr_in_parallel' is false)")
    parser.add_argument("-process_images_in_parallel_size", type=int,
                        default=320, required=False,
                        help="if processing images in parallel - how many "
                             "per process; this can influence memory "
                             "requirements")
    parser.add_argument("-processes_images_in_parallel_n_processes",
                        type=int, default=4, required=False,
                        help="if processing images in parallel - how many "
                             "processes to use (default 4)")
    parser.add_argument("-max_records_per_file", type=int, default=5000,
                        required=False,
                        help="the max number of records per TFRecord file; "
                             "multiple files are generated if the size of "
                             "the dataset exceeds this value - it is "
                             "recommended to use large values (default 5000)")

    args = vars(parser.parse_args())

    # Configure logging
    if args['log_outdir'] is None:
        args['log_outdir'] = args['output_dir']
    setup_logging(log_output_path=args['log_outdir'])
    logger = logging.getLogger(__name__)

    print("Using arguments:")
    for k, v in args.items():
        print("Arg: %s: %s" % (k, v))

    # Create dataset inventory
    params = {'path': args['inventory']}
    dinv = DatasetInventoryMaster()
    dinv.create_from_source('json', params)

    # Remove multi-label subjects
    if args['remove_multi_label_records']:
        dinv.remove_multi_label_records()

    # Remove records with specific labels
    if args['remove_label_name'] is not None:
        if args['remove_label_value'] is None:
            raise ValueError("if remove_label_name is specified, "
                             "remove_label_value needs to be specified")
        dinv.remove_records_with_label(
            label_name_list=args['remove_label_name'],
            label_value_list=args['remove_label_value'])

    # Keep only records with specific labels
    if args['keep_label_name'] is not None:
        if args['keep_label_value'] is None:
            raise ValueError("if keep_label_name is specified, "
                             "keep_label_value needs to be specified")
        dinv.keep_only_records_with_label(
            label_name_list=args['keep_label_name'],
            label_value_list=args['keep_label_value'])

    # Log statistics
    dinv.log_stats()

    # Split by a meta-data column if one has been specified
    # (NOTE: the 'meta_colum' spelling is kept to match the keyword used by
    # the inventory API calls in the original code)
    if args['split_by_meta'] is not None:
        logger.debug("Splitting by metadata %s" % args['split_by_meta'])
        if args['balanced_sampling_min']:
            logger.debug("Balanced sampling using %s" %
                         args['split_by_meta'])
            splitted = dinv.split_inventory_by_meta_data_column_and_balanced_sampling(
                meta_colum=args['split_by_meta'],
                split_label_min=args['balanced_sampling_label'])
        else:
            splitted = dinv.split_inventory_by_meta_data_column(
                meta_colum=args['split_by_meta'])
    # Otherwise split randomly, with balanced sampling if requested
    elif args['balanced_sampling_min']:
        if args['balanced_sampling_label'] is None:
            raise ValueError("balanced_sampling_label must be specified if "
                             "balanced_sampling_min is set to true")
        logger.debug("Splitting by random balanced sampling")
        splitted = dinv.split_inventory_by_random_splits_with_balanced_sample(
            split_label_min=args['balanced_sampling_label'],
            split_names=args['split_names'],
            split_percent=args['split_percent'])
    # Split randomly without balanced sampling
    else:
        logger.debug("Splitting randomly")
        splitted = dinv.split_inventory_by_random_splits(
            split_names=args['split_names'],
            split_percent=args['split_percent'])

    # Log all the splits created
    for i, split_name in enumerate(splitted.keys()):
        logger.info("Created split %s - %s" % (i, split_name))

    # Log statistics for the different splits
    for split_name, split_data in splitted.items():
        logger.debug("Stats for split %s" % split_name)
        split_data.log_stats(debug_only=True)

    # Write label mappings (os.path.join avoids a missing separator when
    # output_dir has no trailing slash)
    out_label_mapping = os.path.join(args['output_dir'], 'label_mapping.json')
    dinv.export_label_mapping(out_label_mapping)

    # Write TFRecord files
    tfr_encoder_decoder = DefaultTFRecordEncoderDecoder()
    tfr_writer = DatasetWriter(tfr_encoder_decoder.encode_record)

    counter = 0
    n_splits = len(splitted.keys())
    for split_name, split_data in splitted.items():
        counter += 1
        logger.info("Starting to process %s (%s / %s)" %
                    (split_name, counter, n_splits))
        split_data.export_to_tfrecord(
            tfr_writer,
            args['output_dir'],
            file_prefix=split_name,
            image_root_path=args['image_root_path'],
            image_pre_processing_fun=read_resize_convert_to_jpeg,
            image_pre_processing_args={
                "max_side": args['image_save_side_max']},
            random_shuffle_before_save=True,
            overwrite_existing_files=args['overwrite'],
            max_records_per_file=args['max_records_per_file'],
            write_tfr_in_parallel=args['write_tfr_in_parallel'],
            process_images_in_parallel=args['process_images_in_parallel'],
            process_images_in_parallel_size=args[
                'process_images_in_parallel_size'],
            processes_images_in_parallel_n_processes=args[
                'processes_images_in_parallel_n_processes'])

    logger.info("Finished writing TFRecords")
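# Standard entry-point guard (assumed; not shown in the excerpt), plus a
# purely illustrative invocation - the script name, paths and flag values
# below are hypothetical:
#
#   python create_dataset.py \
#       -inventory ./inventory.json \
#       -output_dir ./tfr_files/ \
#       -split_names train val test \
#       -split_percent 0.9 0.05 0.05 \
#       -image_save_side_max 500
if __name__ == '__main__':
    main()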