def _load_taxon_hierarchy(self):
    """Load the taxon hierarchy into ``self.taxon_hr``.

    Must be separate from the constructor because
    :meth:`set_photo_count_min` influences the taxon hierarchy.
    Does nothing when the hierarchy was already loaded.
    """
    session, metadata = db.get_session_or_error()
    # Already loaded: keep the cached hierarchy.
    if self.taxon_hr:
        return
    self.taxon_hr = db.get_taxon_hierarchy(session, metadata)
def get_taxon_hierarchy(self):
    """Return the taxon hierarchy.

    First tries to get the taxon hierarchy from the metadata database.
    If that fails, it will try to get it from the configuration file.
    """
    try:
        # Preferred source: the metadata database.
        session, metadata = db.get_session_or_error()
        hierarchy = db.get_taxon_hierarchy(session, metadata)
    except DatabaseSessionError:
        # Fallback: the taxa section of the configuration file.
        hierarchy = self.config.classification.taxa.as_dict()
    return hierarchy
def batch_train(self, data_dir, output_dir):
    """Batch train neural networks.

    Training data is obtained from the directory `data_dir` and the
    neural networks are saved to the directory `output_dir`. Which
    training data to train on is set in the classification hierarchy
    of the configurations.
    """
    session, metadata = db.get_session_or_error()

    # Must not be loaded in the constructor, in case set_photo_count_min()
    # is used.
    self._load_taxon_hierarchy()

    # Name of each level in the classification hierarchy.
    levels = [l.name for l in self.class_hr]

    # Train an ANN for each path in the classification hierarchy.
    for filter_ in classification_hierarchy_filters(levels, self.taxon_hr):
        level_idx = levels.index(filter_.get('class'))
        hr_level = self.class_hr[level_idx]

        train_file = os.path.join(data_dir, hr_level.train_file)
        ann_file = os.path.join(output_dir, hr_level.ann_file)
        config = hr_level.ann if 'ann' in hr_level else None

        # Replace any placeholders in the paths.
        for key, val in filter_.get('where', {}).items():
            placeholder = "__%s__" % key
            val = '_' if val is None else val
            train_file = train_file.replace(placeholder, val)
            ann_file = ann_file.replace(placeholder, val)

        # Get the classification categories from the database.
        classes = db.get_classes_from_filter(session, metadata, filter_)
        assert len(classes) > 0, \
            "No classes found for filter `%s`" % filter_

        # Skip train data export if there is only one class for this filter.
        if len(classes) < 2:
            logging.debug("Only one class for this filter. Skipping " \
                "training of %s" % ann_file)
            continue

        # Train the ANN.
        logging.info("Training network `%s` with training data " \
            "from `%s` ..." % (ann_file, train_file))
        try:
            self.train(train_file, ann_file, config)
        except FileExistsError as e:
            # Don't train if the file already exists.
            logging.warning("Skipping: %s" % e)
def make(self, image_dir, cache_dir, config, update=False):
    """Cache features for an image directory to disk.

    One cache file is created for each feature configuration set in the
    configurations. The caches are saved in the target directory
    `cache_dir`. Each cache file is a Python shelve, a persistent,
    dictionary-like object. If `update` is set to True, existing
    features are updated. Method :meth:`get_phenotype` can then be used
    to retrieve these features and combined them to phenotypes.
    """
    session, metadata = db.get_session_or_error()
    phenotyper = Phenotyper()

    # Get a list of all the photos in the database.
    photos = db.get_photos(session, metadata)

    # The cache directory is the same for every feature configuration,
    # so create it once up front instead of testing inside the loop.
    if not os.path.isdir(cache_dir):
        os.makedirs(cache_dir)

    # Cache each feature for each photo separately. One cache per
    # feature type is created, and each cache contains the features
    # for all images.
    for hash_, c in self.get_single_feature_configurations(config):
        cache_path = os.path.join(cache_dir, str(hash_))

        sys.stderr.write("Caching features in `%s`...\n" % cache_path)

        # Open a shelve for storing features. Existing entries are kept
        # (default flag 'c'); the skip logic below relies on them being
        # present so unchanged features are not recomputed.
        cache = shelve.open(cache_path)
        try:
            # Set the new config.
            phenotyper.set_config(c)

            # Cache the feature for each photo.
            for photo in photos:
                # Skip feature extraction if the feature already exists,
                # unless update is set to True.
                if not update and str(photo.md5sum) in cache:
                    continue

                logging.info("Processing photo %s...", photo.id)

                # Extract a feature and cache it.
                im_path = os.path.join(image_dir, photo.path)
                phenotyper.set_image(im_path)
                cache[str(photo.md5sum)] = phenotyper.make()
        finally:
            # Always close the shelve, even when feature extraction
            # fails, so already-computed features are flushed to disk.
            cache.close()
def k_fold_xval_stratified(self, k=3, autoskip=False):
    """Perform stratified K-folds cross validation.

    The number of folds `k` must be at least 2. The minimum number of
    members for any class cannot be less than `k`, or an AssertionError
    is raised. If `autoskip` is set to True, only the members for
    classes with at least `k` members are used for the cross validation.
    """
    session, metadata = db.get_session_or_error()

    # Scores per classification filter, accumulated over all folds.
    scores = {}

    # All photos with their taxa; the classes drive the stratification.
    samples = db.get_photos_with_taxa(session, metadata)
    photo_ids, classes, photo_count_min = self.__get_photo_ids(
        samples, k, autoskip)
    train_data, trainer, tester = self.__set_trainer_and_tester(
        photo_count_min)

    # Obtain cross validation folds.
    folds = cross_validation.StratifiedKFold(classes, k)
    result_dir = os.path.join(self.temp_dir, 'results')
    for fold_n, (train_idx, test_idx) in enumerate(folds):
        self.__make_data_directories(result_dir, fold_n)

        # Export the train and test subsets for this fold.
        self.__make_data_subset(photo_ids, train_idx, train_data, 'train')
        self.__make_data_subset(photo_ids, test_idx, train_data, 'test')

        # Train on the training subset, then score this fold.
        trainer.batch_train(data_dir=self.train_dir, output_dir=self.ann_dir)
        scores = self.__get_scores(tester, scores)

    return scores
def test_with_hierarchy(self, test_data_dir, ann_dir, max_error=0.001):
    """Test each ANN in a classification hierarchy and export results.

    Populates ``self.classifications`` and
    ``self.classifications_expected`` (photo ID -> level name -> class
    list) and returns a 2-tuple ``(correct,total)``.
    """
    session, metadata = db.get_session_or_error()

    logging.info("Testing the neural networks hierarchy...")

    self.classifications = {}
    self.classifications_expected = {}

    # Get the taxonomic hierarchy from the database.
    self.taxon_hr = db.get_taxon_hierarchy(session, metadata)

    # Get the classification hierarchy from the configurations.
    try:
        self.class_hr = self.config.classification.hierarchy
    except:
        raise ConfigurationError("classification hierarchy not set")

    # Get the name of each level in the classification hierarchy.
    levels = [l.name for l in self.class_hr]

    # Get the prefix for the classification columns.
    try:
        dependent_prefix = self.config.data.dependent_prefix
    except:
        dependent_prefix = OUTPUT_PREFIX

    # Get the expected and recognized classification for each sample in
    # the test data.
    for filter_ in classification_hierarchy_filters(levels, self.taxon_hr):
        logging.info("Classifying on %s" % readable_filter(filter_))

        level_name = filter_.get('class')
        level_n = levels.index(level_name)
        level = self.class_hr[level_n]
        test_file = os.path.join(test_data_dir, level.test_file)
        ann_file = os.path.join(ann_dir, level.ann_file)

        # Set the maximum error for classification.
        # NOTE(review): a level-specific max_error persists for later
        # levels that do not define their own — confirm intentional.
        try:
            max_error = level.max_error
        except:
            pass

        # Replace any placeholders in the paths.
        where = filter_.get('where', {})
        for key, val in where.items():
            val = val if val is not None else '_'
            test_file = test_file.replace("__%s__" % key, val)
            ann_file = ann_file.replace("__%s__" % key, val)

        # Get the class names for this filter.
        classes = db.get_classes_from_filter(session, metadata, filter_)
        assert len(classes) > 0, \
            "No classes found for filter `%s`" % filter_

        # Get the codeword for each class.
        codewords = get_codewords(classes)

        # Load the ANN. The network object is created unconditionally so
        # that ann.destroy() below is safe even when this filter has only
        # one class and no network file is loaded.
        ann = libfann.neural_net()
        if len(classes) > 1:
            ann.create_from_file(str(ann_file))

        # Load the test data.
        test_data = TrainData()
        test_data.read_from_file(test_file, dependent_prefix)

        # Test each sample in the test data.
        for label, input_, output in test_data:
            assert len(codewords) == len(output), \
                "Codeword size mismatch. Codeword has {0} bits, but the " \
                "training data has {1} output bits.".\
                format(len(codewords), len(output))

            # Obtain the photo ID from the label.
            if not label:
                raise ValueError("Test sample is missing a label with " \
                    "photo ID")

            try:
                photo_id = self.re_photo_id.search(label).group(1)
                photo_id = int(photo_id)
            except:
                raise RuntimeError("Failed to obtain the photo ID from " \
                    "the sample label")

            # Initialize the per-photo dicts at the top level. This must
            # happen before the single-class skip below, otherwise the
            # skip branch would raise a KeyError at level 0.
            if level_n == 0:
                self.classifications[photo_id] = {}
                self.classifications_expected[photo_id] = {}

            # Skip classification if there is only one class for this
            # filter.
            if not len(classes) > 1:
                logging.debug("Not enough classes for filter. Skipping " \
                    "testing of %s" % ann_file)
                self.classifications[photo_id][level_name] = ['']
                self.classifications_expected[photo_id][level_name] = ['']
                continue

            # Set the expected class.
            class_expected = get_classification(codewords, output,
                max_error)
            class_expected = [class_ for mse,class_ in class_expected]
            assert len(class_expected) == 1, \
                "Class codewords must have one positive bit, found {0}".\
                format(len(class_expected))

            # Get the recognized class.
            codeword = ann.run(input_)
            class_ann = get_classification(codewords, codeword, max_error)
            class_ann = [class_ for mse,class_ in class_ann]

            # Save the classification at each level.
            self.classifications[photo_id][level_name] = class_ann
            self.classifications_expected[photo_id][level_name] = \
                class_expected

        ann.destroy()

    return self.get_correct_count()
def export_results(self, filename, filter_, error=0.01):
    """Export the classification results to a TSV file.

    Export the test results to a tab separated file `filename`. The
    class name for a codeword is obtained from the database `db_path`,
    using the classification filter `filter_`. A bit in a codeword is
    considered on if the mean square error for a bit is less or equal to
    `error`.

    Raises a RuntimeError if the test data is not set, and a ValueError
    if the test data is empty or the codeword length does not match the
    output length.
    """
    session, metadata = db.get_session_or_error()

    if self.test_data is None:
        raise RuntimeError("Test data is not set")

    # Get the classification categories from the database.
    classes = db.get_classes_from_filter(session, metadata, filter_)
    assert len(classes) > 0, \
        "No classes found for filter `%s`" % filter_

    # Get the codeword for each class.
    codewords = get_codewords(classes)

    # Write results to file.
    with open(filename, 'w') as fh:
        # Write the header.
        fh.write( "%s\n" % "\t".join(['ID','Class','Classification','Match']) )

        total = 0
        correct = 0
        # `input_` avoids shadowing the builtin `input`.
        for label, input_, output in self.test_data:
            total += 1
            row = []

            if label:
                row.append(label)
            else:
                row.append("")

            if len(codewords) != len(output):
                raise ValueError("Codeword length ({0}) does not " \
                    "match output length ({1}). Is the classification " \
                    "filter correct?".format(len(codewords), len(output))
                )

            class_expected = get_classification(codewords, output, error)
            class_expected = [class_ for mse,class_ in class_expected]
            assert len(class_expected) == 1, \
                "The codeword for a class can only have one positive value"
            row.append(class_expected[0])

            codeword = self.ann.run(input_)
            class_ann = get_classification(codewords, codeword, error)
            class_ann = [class_ for mse,class_ in class_ann]
            row.append(", ".join(class_ann))

            # Assume a match if the first items of the classifications match.
            if len(class_ann) > 0 and class_ann[0] == class_expected[0]:
                row.append("+")
                correct += 1
            else:
                row.append("-")

            fh.write( "%s\n" % "\t".join(row) )

        # Guard against division by zero on empty test data.
        if total == 0:
            raise ValueError("Test data is empty")

        # Calculate fraction correctly classified.
        fraction = float(correct) / total

        # Write correctly classified fraction.
        fh.write( "%s\n" % "\t".join(['','','',"%.3f" % fraction]) )

    print("Correctly classified: %.1f%%" % (fraction*100))
    print("Testing results written to %s" % filename)
def test_with_hierarchy(self, test_data_dir, ann_dir, max_error=0.001):
    """Test each ANN in a classification hierarchy and export results.

    Walks every filter produced by
    ``classification_hierarchy_filters``, loads the corresponding ANN
    from `ann_dir` and test data from `test_data_dir`, and records both
    the expected and the recognized class per photo in
    ``self.classifications`` / ``self.classifications_expected``
    (photo ID -> level name -> list of classes).

    Returns a 2-tuple ``(correct,total)``.
    """
    session, metadata = db.get_session_or_error()

    logging.info("Testing the neural networks hierarchy...")

    self.classifications = {}
    self.classifications_expected = {}

    # Get the taxonomic hierarchy from the database.
    self.taxon_hr = db.get_taxon_hierarchy(session, metadata)

    # Get the classification hierarchy from the configurations.
    try:
        self.class_hr = self.config.classification.hierarchy
    except:
        raise ConfigurationError("classification hierarchy not set")

    # Get the name of each level in the classification hierarchy.
    levels = [l.name for l in self.class_hr]

    # Get the prefix for the classification columns.
    try:
        dependent_prefix = self.config.data.dependent_prefix
    except:
        dependent_prefix = OUTPUT_PREFIX

    # Get the expected and recognized classification for each sample in
    # the test data.
    for filter_ in classification_hierarchy_filters(levels, self.taxon_hr):
        logging.info("Classifying on %s" % readable_filter(filter_))

        level_name = filter_.get('class')
        level_n = levels.index(level_name)
        level = self.class_hr[level_n]
        test_file = os.path.join(test_data_dir, level.test_file)
        ann_file = os.path.join(ann_dir, level.ann_file)

        # Set the maximum error for classification. NOTE(review): a
        # level-specific max_error persists for subsequent levels that
        # do not define their own — confirm this is intentional.
        try:
            max_error = level.max_error
        except:
            pass

        # Replace any placeholders in the paths.
        where = filter_.get('where', {})
        for key, val in where.items():
            val = val if val is not None else '_'
            test_file = test_file.replace("__%s__" % key, val)
            ann_file = ann_file.replace("__%s__" % key, val)

        # Get the class names for this filter.
        classes = db.get_classes_from_filter(session, metadata, filter_)
        assert len(classes) > 0, \
            "No classes found for filter `%s`" % filter_

        # Get the codeword for each class.
        codewords = get_codewords(classes)

        # Load the ANN. The network object is created unconditionally so
        # that ann.destroy() below is always safe; the network file is
        # only loaded when there is more than one class.
        ann = libfann.neural_net()
        if len(classes) > 1:
            ann.create_from_file(str(ann_file))

        # Load the test data.
        test_data = TrainData()
        test_data.read_from_file(test_file, dependent_prefix)

        # Test each sample in the test data.
        for label, input_, output in test_data:
            assert len(codewords) == len(output), \
                "Codeword size mismatch. Codeword has {0} bits, but the " \
                "training data has {1} output bits.".\
                format(len(codewords), len(output))

            # Obtain the photo ID from the label.
            if not label:
                raise ValueError("Test sample is missing a label with " \
                    "photo ID")

            try:
                photo_id = self.re_photo_id.search(label).group(1)
                photo_id = int(photo_id)
            except:
                raise RuntimeError("Failed to obtain the photo ID from " \
                    "the sample label")

            # Save the classification at each level. The per-photo dicts
            # are created at the top level (level_n == 0), before the
            # single-class skip below assigns into them.
            if level_n == 0:
                self.classifications[photo_id] = {}
                self.classifications_expected[photo_id] = {}

            # Skip classification if there is only one class for this
            # filter.
            if not len(classes) > 1:
                logging.debug("Not enough classes for filter. Skipping " \
                    "testing of %s" % ann_file)
                self.classifications[photo_id][level_name] = ['']
                self.classifications_expected[photo_id][level_name] = ['']
                continue

            # Set the expected class.
            class_expected = get_classification(codewords, output,
                max_error)
            class_expected = [class_ for mse,class_ in class_expected]
            assert len(class_expected) == 1, \
                "Class codewords must have one positive bit, found {0}".\
                format(len(class_expected))

            # Get the recognized class.
            codeword = ann.run(input_)
            class_ann = get_classification(codewords, codeword, max_error)
            class_ann = [class_ for mse,class_ in class_ann]

            # Save the classification at each level.
            self.classifications[photo_id][level_name] = class_ann
            self.classifications_expected[photo_id][level_name] = class_expected

        ann.destroy()

    return self.get_correct_count()
def k_fold_xval_stratified(self, k=3, autoskip=False):
    """Perform stratified K-folds cross validation.

    The number of folds `k` must be at least 2. The minimum number of
    members for any class cannot be less than `k`, or an AssertionError
    is raised. If `autoskip` is set to True, only the members for
    classes with at least `k` members are used for the cross validation.

    Returns a dict mapping a slash-joined level filter (e.g.
    "genus/section/species") to a list of scores, one per fold.
    """
    session, metadata = db.get_session_or_error()

    # Will hold the score of each folds.
    scores = {}

    # Get a list of all the photo IDs in the database.
    samples = db.get_photos_with_taxa(session, metadata)

    # Get a list of the photo IDs and a list of the classes. The classes
    # are needed for the stratified cross validation.
    photo_ids = []
    classes = []
    for x in samples:
        photo_ids.append(x[0].id)
        tmp = np.array(x[1:]).astype(str)
        classes.append('_'.join(tmp))

    # Numpy features are needed for these.
    photo_ids = np.array(photo_ids)
    classes = np.array(classes)

    # Count the number of each class.
    class_counts = Counter(classes)

    if autoskip:
        # Create a mask for the classes that have enough members and
        # remove the photo IDs that don't have enough members.
        mask = []
        for i, c in enumerate(classes):
            if class_counts[c] >= k:
                mask.append(i)
        photo_ids = photo_ids[mask]
        classes = classes[mask]
    else:
        for label, count in class_counts.items():
            assert count >= k, "Class {0} has only {1} members, which " \
                "is too few. The minimum number of labels for any " \
                "class cannot be less than k={2}. Use --autoskip to skip " \
                "classes with too few members.".format(label, count, k)

    # With autoskip, each remaining class is guaranteed to have at
    # least k members.
    photo_count_min = k if autoskip else 0

    # Train data exporter.
    train_data = BatchMakeTrainData(self.config, self.cache_dir)
    train_data.set_photo_count_min(photo_count_min)

    # Set the trainer.
    trainer = BatchMakeAnn(self.config)
    trainer.set_photo_count_min(photo_count_min)
    if self.aivolver_config_path:
        trainer.set_training_method('aivolver', self.aivolver_config_path)

    # Set the ANN tester.
    tester = TestAnn(self.config)
    tester.set_photo_count_min(photo_count_min)

    # Obtain cross validation folds.
    folds = cross_validation.StratifiedKFold(classes, k)
    result_dir = os.path.join(self.temp_dir, 'results')
    for i, (train_idx, test_idx) in enumerate(folds):
        # Make data directories.
        train_dir = os.path.join(self.temp_dir, 'train', str(i))
        test_dir = os.path.join(self.temp_dir, 'test', str(i))
        ann_dir = os.path.join(self.temp_dir, 'ann', str(i))
        test_result = os.path.join(result_dir, '{0}.tsv'.format(i))
        for path in (train_dir, test_dir, ann_dir, result_dir):
            if not os.path.isdir(path):
                os.makedirs(path)

        # Make train data for this fold.
        train_samples = photo_ids[train_idx]
        train_data.set_subset(train_samples)
        train_data.batch_export(train_dir)

        # Make test data for this fold.
        test_samples = photo_ids[test_idx]
        train_data.set_subset(test_samples)
        train_data.batch_export(test_dir, train_dir)

        # Train neural networks on training data.
        trainer.batch_train(data_dir=train_dir, output_dir=ann_dir)

        # Calculate the score for this fold.
        tester.test_with_hierarchy(test_dir, ann_dir)
        tester.export_hierarchy_results(test_result)

        # List all level combinations.
        try:
            class_hr = self.config.classification.hierarchy
            hr = [level.name for level in class_hr]
        except:
            raise ConfigurationError("classification hierarchy not set")

        # Build one cumulative filter per hierarchy level. Each filter
        # must be a fresh list: the previous code appended the same
        # `ranks` list object on every iteration, which made all the
        # filters aliases of one full-hierarchy list.
        level_filters = tuple(hr[:n + 1] for n in range(len(hr)))

        for filter_ in level_filters:
            correct, total = tester.get_correct_count(filter_)
            score = float(correct) / total
            filter_s = "/".join(filter_)
            if filter_s not in scores:
                scores[filter_s] = []
            scores[filter_s].append(score)

    return scores
def k_fold_xval_stratified(self, k=3, autoskip=False):
    """Perform stratified K-folds cross validation.

    The number of folds `k` must be at least 2. The minimum number of
    members for any class cannot be less than `k`, or an AssertionError
    is raised. If `autoskip` is set to True, only the members for
    classes with at least `k` members are used for the cross validation.

    Returns a dict mapping a slash-joined level filter (e.g.
    "genus/section/species") to a list of scores, one per fold.
    """
    session, metadata = db.get_session_or_error()

    # Will hold the score of each folds.
    scores = {}

    # Get a list of all the photo IDs in the database.
    samples = db.get_photos_with_taxa(session, metadata)

    # Get a list of the photo IDs and a list of the classes. The classes
    # are needed for the stratified cross validation.
    photo_ids = []
    classes = []
    for x in samples:
        photo_ids.append(x[0].id)
        tmp = np.array(x[1:]).astype(str)
        classes.append('_'.join(tmp))

    # Numpy features are needed for these.
    photo_ids = np.array(photo_ids)
    classes = np.array(classes)

    # Count the number of each class.
    class_counts = Counter(classes)

    if autoskip:
        # Create a mask for the classes that have enough members and
        # remove the photo IDs that don't have enough members.
        mask = []
        for i, c in enumerate(classes):
            if class_counts[c] >= k:
                mask.append(i)
        photo_ids = photo_ids[mask]
        classes = classes[mask]
    else:
        for label, count in class_counts.items():
            assert count >= k, "Class {0} has only {1} members, which " \
                "is too few. The minimum number of labels for any " \
                "class cannot be less than k={2}. Use --autoskip to skip " \
                "classes with too few members.".format(label, count, k)

    # With autoskip, each remaining class is guaranteed to have at
    # least k members.
    photo_count_min = k if autoskip else 0

    # Train data exporter.
    train_data = BatchMakeTrainData(self.config, self.cache_dir)
    train_data.set_photo_count_min(photo_count_min)

    # Set the trainer.
    trainer = BatchMakeAnn(self.config)
    trainer.set_photo_count_min(photo_count_min)
    if self.aivolver_config_path:
        trainer.set_training_method('aivolver', self.aivolver_config_path)

    # Set the ANN tester.
    tester = TestAnn(self.config)
    tester.set_photo_count_min(photo_count_min)

    # Obtain cross validation folds.
    folds = cross_validation.StratifiedKFold(classes, k)
    result_dir = os.path.join(self.temp_dir, 'results')
    for i, (train_idx, test_idx) in enumerate(folds):
        # Make data directories.
        train_dir = os.path.join(self.temp_dir, 'train', str(i))
        test_dir = os.path.join(self.temp_dir, 'test', str(i))
        ann_dir = os.path.join(self.temp_dir, 'ann', str(i))
        test_result = os.path.join(result_dir, '{0}.tsv'.format(i))
        for path in (train_dir, test_dir, ann_dir, result_dir):
            if not os.path.isdir(path):
                os.makedirs(path)

        # Make train data for this fold.
        train_samples = photo_ids[train_idx]
        train_data.set_subset(train_samples)
        train_data.batch_export(train_dir)

        # Make test data for this fold.
        test_samples = photo_ids[test_idx]
        train_data.set_subset(test_samples)
        train_data.batch_export(test_dir, train_dir)

        # Train neural networks on training data.
        trainer.batch_train(data_dir=train_dir, output_dir=ann_dir)

        # Calculate the score for this fold.
        tester.test_with_hierarchy(test_dir, ann_dir)
        tester.export_hierarchy_results(test_result)

        # List all level combinations.
        try:
            class_hr = self.config.classification.hierarchy
            hr = [level.name for level in class_hr]
        except:
            raise ConfigurationError("classification hierarchy not set")

        # Build one cumulative filter per hierarchy level. Each filter
        # must be a fresh list: the previous code appended the same
        # `ranks` list object on every iteration, which made all the
        # filters aliases of one full-hierarchy list.
        level_filters = tuple(hr[:n + 1] for n in range(len(hr)))

        for filter_ in level_filters:
            correct, total = tester.get_correct_count(filter_)
            score = float(correct) / total
            filter_s = "/".join(filter_)
            if filter_s not in scores:
                scores[filter_s] = []
            scores[filter_s].append(score)

    return scores
def export(self, filename, filter_, config=None):
    """Write the training data to `filename`.

    Images to be processed are obtained from the database. Which images
    are obtained and with which classes is set by the filter `filter_`.
    Image fingerprints are obtained from cache, which must have been
    created for configuration `config` or `self.config`.

    Raises FileExistsError when `filename` exists and overwriting is
    not forced, and ValueError when the resulting training data would
    be empty.
    """
    session, metadata = db.get_session_or_error()

    # Refuse to clobber an existing export unless forced.
    if not conf.force_overwrite and os.path.isfile(filename):
        raise FileExistsError(filename)

    # Get the classification categories from the database.
    classes = db.get_classes_from_filter(session, metadata, filter_)
    assert len(classes) > 0, \
        "No classes found for filter `%s`" % filter_

    # Get the photos and corresponding classification using the filter.
    images = db.get_filtered_photos_with_taxon(session, metadata, filter_)
    images = images.all()

    if not images:
        logging.info("No images found for the filter `%s`", filter_)
        return

    if self.get_photo_count_min():
        assert len(images) >= self.get_photo_count_min(), \
            "Expected to find at least photo_count_min={0} photos, found " \
            "{1}".format(self.get_photo_count_min(), len(images))

    # Calculate the number of images that will be processed, taking into
    # account the subset.
    photo_ids = np.array([photo.id for photo, _ in images])
    if self.subset:
        n_images = len(np.intersect1d(list(photo_ids), list(self.subset)))
    else:
        n_images = len(images)

    logging.info("Going to process %d photos...", n_images)

    # Make a codeword for each class.
    codewords = get_codewords(classes)

    # Construct the header.
    header_data, header_out = self.__make_header(len(classes))
    header = ["ID"] + header_data + header_out

    # Get the configurations.
    if not config:
        config = self.config

    # Load the fingerprint cache.
    self.cache.load_cache(self.cache_path, config)

    # Generate the training data.
    with open(filename, 'w') as fh:
        # Write the header.
        fh.write( "%s\n" % "\t".join(header) )

        # Set the training data.
        training_data = TrainData(len(header_data), len(classes))
        for photo, class_ in images:
            # Only export the subset if an export subset is set.
            if self.subset and photo.id not in self.subset:
                continue

            logging.info("Processing `%s` of class `%s`...",
                photo.path, class_)

            # Get phenotype for this image from the cache.
            phenotype = self.cache.get_phenotype(photo.md5sum)

            assert len(phenotype) == len(header_data), \
                "Fingerprint size mismatch. According to the header " \
                "there are {0} data columns, but the fingerprint has " \
                "{1}".format(len(header_data), len(phenotype))

            training_data.append(phenotype, codewords[class_],
                label=photo.id)

        training_data.finalize()
        if not training_data:
            raise ValueError("Training data cannot be empty")

        # Round feature data.
        training_data.round_input(6)

        # Write data rows.
        for photo_id, input_, output in training_data:
            row = [str(photo_id)]
            row.extend(input_.astype(str))
            row.extend(output.astype(str))
            fh.write("%s\n" % "\t".join(row))

    logging.info("Training data written to %s", filename)
def export(self, filename, filter_, config=None, codebook_file=None):
    """Write the training data to `filename`.

    Images to be processed are obtained from the database. Which images
    are obtained and with which classes is set by the filter `filter_`.
    Image fingerprints are obtained from cache, which must have been
    created for configuration `config` or `self.config`.

    When the SURF feature configuration defines `bow_clusters`, the
    fingerprints are converted to Bag-of-Words codes; the codebook is
    either built from the images or loaded (unpickled) from
    `codebook_file` when given.

    Raises FileExistsError when `filename` exists and overwriting is
    not forced, and ValueError when the resulting training data would
    be empty.
    """
    session, metadata = db.get_session_or_error()

    # Refuse to clobber an existing export unless forced.
    if not conf.force_overwrite and os.path.isfile(filename):
        raise FileExistsError(filename)

    # Get the classification categories from the database.
    classes = db.get_classes_from_filter(session, metadata, filter_)
    assert len(classes) > 0, \
        "No classes found for filter `%s`" % filter_

    # Get the photos and corresponding classification using the filter.
    images = db.get_filtered_photos_with_taxon(session, metadata, filter_)
    images = images.all()

    if not images:
        logging.info("No images found for the filter `%s`", filter_)
        return

    if self.get_photo_count_min():
        assert len(images) >= self.get_photo_count_min(), \
            "Expected to find at least photo_count_min={0} photos, found " \
            "{1}".format(self.get_photo_count_min(), len(images))

    # Calculate the number of images that will be processed, taking into
    # account the subset.
    photo_ids = np.array([photo.id for photo, _ in images])
    if self.subset:
        n_images = len(np.intersect1d(list(photo_ids), list(self.subset)))
    else:
        n_images = len(images)

    logging.info("Going to process %d photos...", n_images)

    # Make a codeword for each class.
    codewords = get_codewords(classes)

    # Construct the header.
    header_data, header_out = self.__make_header(len(classes))
    header = ["ID"] + header_data + header_out

    # Get the configurations.
    if not config:
        config = self.config

    # Load the fingerprint cache.
    self.cache.load_cache(self.cache_path, config)

    # Check if the BagOfWords algorithm needs to be applied.
    # NOTE(review): this raises KeyError when no 'surf' feature is
    # configured — confirm 'surf' is always present at this point.
    use_bow = getattr(self.config.features['surf'], 'bow_clusters', False)
    if use_bow and codebook_file == None:
        # No codebook supplied: build one from the images themselves.
        codebook = self.__make_codebook(images, filename)
    elif use_bow:
        # Load a previously pickled codebook from disk.
        with open(codebook_file, "rb") as cb:
            codebook = load(cb)

    # Generate the training data.
    with open(filename, 'w') as fh:
        # Write the header.
        fh.write("%s\n" % "\t".join(header))

        # Set the training data.
        training_data = TrainData(len(header_data), len(classes))
        for photo, class_ in images:
            # Only export the subset if an export subset is set.
            if self.subset and photo.id not in self.subset:
                continue

            logging.info("Processing `%s` of class `%s`...",
                photo.path, class_)

            # Get phenotype for this image from the cache.
            phenotype = self.cache.get_phenotype(photo.md5sum)

            # If the BagOfWords algorithm is applied,
            # convert phenotype to BOW-code.
            if use_bow:
                phenotype = get_bowcode_from_surf_features(
                    phenotype, codebook)

            assert len(phenotype) == len(header_data), \
                "Fingerprint size mismatch. According to the header " \
                "there are {0} data columns, but the fingerprint has " \
                "{1}".format(len(header_data), len(phenotype))

            training_data.append(phenotype, codewords[class_],
                label=photo.id)

        training_data.finalize()
        if not training_data:
            raise ValueError("Training data cannot be empty")

        # Round feature data only if BOW is not applied.
        if not use_bow:
            training_data.round_input(6)

        # Write data rows.
        for photo_id, input_, output in training_data:
            row = [str(photo_id)]
            row.extend(input_.astype(str))
            row.extend(output.astype(str))
            fh.write("%s\n" % "\t".join(row))

    logging.info("Training data written to %s", filename)