def get_default_config(cls):
    """
    Build the default configuration dictionary for this service.

    The parent class's default configuration is used as the base and this
    service's own entries are layered on top of it.

    :return: Default configuration dictionary.
    :rtype: dict
    """
    config = super(SmqtkClassifierService, cls).get_default_config()

    # Classifier removal is off in the default configuration.
    config[cls.CONFIG_ENABLE_CLASSIFIER_REMOVAL] = False

    # Static classifier configurations.
    collection_default = ClassifierCollection.get_default_config()
    config[cls.CONFIG_CLASSIFIER_COLLECTION] = collection_default

    # Factory producing classification elements for new results.
    classification_default = ClassificationElementFactory.get_default_config()
    config[cls.CONFIG_CLASSIFICATION_FACTORY] = classification_default

    # Descriptor generator applied to new content.
    generator_default = smqtk.utils.plugin.make_config(
        get_descriptor_generator_impls()
    )
    config[cls.CONFIG_DESCRIPTOR_GENERATOR] = generator_default

    # Factory producing descriptor elements for new content.
    descriptor_default = DescriptorElementFactory.get_default_config()
    config[cls.CONFIG_DESCRIPTOR_FACTORY] = descriptor_default

    # Only *supervised* classifier implementations are offered for
    # classifiers trained from an IQR state.
    supervised_impls = get_classifier_impls(sub_interface=SupervisedClassifier)
    config[cls.CONFIG_IQR_CLASSIFIER] = smqtk.utils.plugin.make_config(
        supervised_impls
    )

    config[cls.CONFIG_IMMUTABLE_LABELS] = []

    return config
def default_config():
    """
    Build the default configuration dictionary for this utility.

    The result has three sections:

    - ``plugins``: default plugin configurations for the classifier, the
      classification element factory and the descriptor index.
    - ``utility``: training switch, input CSV path placeholder, optional
      output paths and confidence-interval plotting options.
    - ``parallelism``: core counts for descriptor fetching and
      classification.

    :return: Default configuration dictionary.
    :rtype: dict
    """
    return {
        'plugins': {
            'classifier': plugin.make_config(get_classifier_impls()),
            'classification_factory':
                ClassificationElementFactory.get_default_config(),
            'descriptor_index':
                plugin.make_config(get_descriptor_index_impls()),
        },
        'utility': {
            'train': False,
            # Placeholder the user is expected to replace with a real path.
            # (Previously misspelled as "CHAMGEME".)
            'csv_filepath': 'CHANGEME :: PATH :: a csv file',
            'output_plot_pr': None,
            'output_plot_roc': None,
            'output_plot_confusion_matrix': None,
            'output_uuid_confusion_matrix': None,
            'curve_confidence_interval': False,
            'curve_confidence_interval_alpha': 0.4,
        },
        "parallelism": {
            "descriptor_fetch_cores": 4,
            "classification_cores": None,
        },
    }
def get_default_config(cls):
    """
    Build the default configuration dictionary for the IQR service.

    The parent class's default configuration is extended with an
    ``iqr_service`` section holding the seed-neighbor count, human-readable
    notes about each plugin slot, and the default plugin configurations.

    :return: Default configuration dictionary.
    :rtype: dict
    """
    config = super(IqrService, cls).get_default_config()

    # Relevancy index defaults, overlaid with the IQR session's defaults.
    rel_index_config = plugin.make_config(get_relevancy_index_impls())
    merge_dict(rel_index_config, iqr_session.DFLT_REL_INDEX_CONFIG)

    # Explanatory text shipped alongside the plugin configurations.
    notes = {
        "relevancy_index_config":
            "The relevancy index config provided should not have "
            "persistent storage configured as it will be used in "
            "such a way that instances are created, built and "
            "destroyed often.",
        "descriptor_index":
            "This is the index from which given positive and "
            "negative example descriptors are retrieved from. "
            "Not used for nearest neighbor querying. "
            "This index must contain all descriptors that could "
            "possibly be used as positive/negative examples and "
            "updated accordingly.",
        "neighbor_index":
            "This is the neighbor index to pull initial near-"
            "positive descriptors from.",
        "classifier_config":
            "The configuration to use for training and using "
            "classifiers for the /classifier endpoint. "
            "When configuring a classifier for use, don't fill "
            "out model persistence values as many classifiers "
            "may be created and thrown away during this service's "
            "operation.",
        "classification_factory":
            "Selection of the backend in which classifications "
            "are stored. The in-memory version is recommended "
            "because normal caching mechanisms will not account "
            "for the variety of classifiers that can potentially "
            "be created via this utility.",
    }

    plugin_defaults = {
        "relevancy_index_config": rel_index_config,
        "descriptor_index": plugin.make_config(get_descriptor_index_impls()),
        "neighbor_index": plugin.make_config(get_nn_index_impls()),
        "classifier_config": plugin.make_config(get_classifier_impls()),
        "classification_factory":
            ClassificationElementFactory.get_default_config(),
    }

    service_section = {
        "iqr_service": {
            "positive_seed_neighbors": 500,
            "plugin_notes": notes,
            "plugins": plugin_defaults,
        }
    }
    merge_dict(config, service_section)
    return config
def get_default_config(cls):
    """
    Generate the default configuration for the IQR service.

    Starts from the parent class's defaults and merges in an
    ``iqr_service`` section: the positive seed neighbor count, a set of
    descriptive notes for each plugin slot, and the default plugin
    configurations themselves.

    :return: Default configuration dictionary.
    :rtype: dict
    """
    base = super(IqrService, cls).get_default_config()

    # Relevancy index plugin defaults with the IQR session defaults merged in.
    relevancy_defaults = plugin.make_config(get_relevancy_index_impls())
    merge_dict(relevancy_defaults, iqr_session.DFLT_REL_INDEX_CONFIG)

    iqr_section = {"positive_seed_neighbors": 500}

    # Notes describing how each configured plugin is used.
    iqr_section["plugin_notes"] = {
        "relevancy_index_config":
            "The relevancy index config provided should not have "
            "persistent storage configured as it will be used in "
            "such a way that instances are created, built and "
            "destroyed often.",
        "descriptor_index":
            "This is the index from which given positive and "
            "negative example descriptors are retrieved from. "
            "Not used for nearest neighbor querying. "
            "This index must contain all descriptors that could "
            "possibly be used as positive/negative examples and "
            "updated accordingly.",
        "neighbor_index":
            "This is the neighbor index to pull initial near-"
            "positive descriptors from.",
        "classifier_config":
            "The configuration to use for training and using "
            "classifiers for the /classifier endpoint. "
            "When configuring a classifier for use, don't fill "
            "out model persistence values as many classifiers "
            "may be created and thrown away during this service's "
            "operation.",
        "classification_factory":
            "Selection of the backend in which classifications "
            "are stored. The in-memory version is recommended "
            "because normal caching mechanisms will not account "
            "for the variety of classifiers that can potentially "
            "be created via this utility.",
    }

    # Default plugin configurations, in the same key order as the notes.
    iqr_section["plugins"] = {
        "relevancy_index_config": relevancy_defaults,
        "descriptor_index": plugin.make_config(get_descriptor_index_impls()),
        "neighbor_index": plugin.make_config(get_nn_index_impls()),
        "classifier_config": plugin.make_config(get_classifier_impls()),
        "classification_factory":
            ClassificationElementFactory.get_default_config(),
    }

    merge_dict(base, {"iqr_service": iqr_section})
    return base
def get_default_config():
    """
    Build the default configuration dictionary.

    Contains default configurations for the descriptor element factory,
    the descriptor generator, the classification element factory and the
    classifier, in that order.

    :return: Default configuration dictionary.
    :rtype: dict
    """
    config = {}
    config["descriptor_factory"] = \
        DescriptorElementFactory.get_default_config()
    config["descriptor_generator"] = \
        plugin.make_config(get_descriptor_generator_impls())
    config["classification_factory"] = \
        ClassificationElementFactory.get_default_config()
    config["classifier"] = plugin.make_config(get_classifier_impls())
    return config
def default_config():
    """
    Build the default configuration dictionary for this utility.

    :return: Dictionary with a ``utility`` section (overwrite flag and
        parallelism options) and a ``plugins`` section (classifier,
        classification element and descriptor index plugin defaults).
    :rtype: dict
    """
    parallel_options = {
        "use_multiprocessing": False,
        "index_extraction_cores": None,
        "classification_cores": None,
    }
    utility_section = {
        "classify_overwrite": False,
        "parallel": parallel_options,
    }

    plugins_section = {
        "classifier": plugin.make_config(get_classifier_impls()),
        "classification_factory":
            plugin.make_config(get_classification_element_impls()),
        "descriptor_index":
            plugin.make_config(get_descriptor_index_impls()),
    }

    return {
        "utility": utility_section,
        "plugins": plugins_section,
    }
def train_classifier_iqr(config, iqr_state_fp):
    """
    Train a supervised classifier from the examples stored in an IQR
    session state file.

    :param config: Configuration dictionary; its ``classifier`` entry is the
        plugin configuration used to construct the classifier.
    :type config: dict
    :param iqr_state_fp: Path to the IQR session state file to read.
    :type iqr_state_fp: str
    """
    classifier_config = config['classifier']
    # Only supervised implementations are eligible, as we need to train.
    supervised_impls = get_classifier_impls(sub_interface=SupervisedClassifier)
    #: :type: smqtk.algorithms.SupervisedClassifier
    classifier = from_plugin_config(classifier_config, supervised_impls)

    # Read the serialized state and load it into a fresh IqrSession.
    with open(iqr_state_fp, 'rb') as state_file:
        state_bytes = state_file.read().strip()
    descr_factory = DescriptorElementFactory(DescriptorMemoryElement, {})
    session = IqrSession()
    session.set_state_bytes(state_bytes, descr_factory)

    # Each training class is the union of the session's internal and
    # external example sets.
    examples = {
        'positive': (session.positive_descriptors |
                     session.external_positive_descriptors),
        'negative': (session.negative_descriptors |
                     session.external_negative_descriptors),
    }
    classifier.train(class_examples=examples)
def train_classifier_iqr(config, iqr_state_fp):
    """
    Train a supervised classifier from the positive/negative vectors stored
    in a zipped IQR state file.

    The state file must be a zip archive containing exactly one member,
    which is a JSON object with exactly the two keys ``pos`` and ``neg``,
    each a list of vectors.  Duplicate vectors within each list are dropped
    before training.

    :param config: Configuration dictionary; its ``classifier`` entry is the
        plugin configuration used to construct the classifier.
    :type config: dict
    :param iqr_state_fp: Path to the zipped IQR state file.
    :type iqr_state_fp: str

    :raises RuntimeError: The configured classifier is not a
        ``SupervisedClassifier``, or the state file does not have the
        expected structure.
    """
    log = logging.getLogger(__name__)

    #: :type: smqtk.algorithms.SupervisedClassifier
    classifier = from_plugin_config(config['classifier'],
                                    get_classifier_impls())
    if not isinstance(classifier, SupervisedClassifier):
        raise RuntimeError("Configured classifier must be of the "
                           "SupervisedClassifier type in order to train.")

    # Get pos/neg descriptors out of iqr state zip.
    # Let ZipFile open the path itself (binary mode) and close it for us;
    # previously a text-mode handle was opened and never closed.
    with zipfile.ZipFile(iqr_state_fp, 'r') as z:
        member_names = z.namelist()
        if len(member_names) != 1:
            raise RuntimeError("Invalid IqrState file!")
        iqrs = json.loads(z.read(member_names[0]))

    if len(iqrs) != 2:
        raise RuntimeError("Invalid IqrState file!")
    if 'pos' not in iqrs or 'neg' not in iqrs:
        raise RuntimeError("Invalid IqrState file!")

    log.info("Loading pos/neg descriptors")
    #: :type: list[smqtk.representation.DescriptorElement]
    pos = []
    #: :type: list[smqtk.representation.DescriptorElement]
    neg = []
    # A single running counter gives every element a distinct UUID across
    # both the positive and the negative set.
    i = 0
    for key, dest in (('pos', pos), ('neg', neg)):
        # Vectors are converted to tuples so duplicates can be removed.
        for v in set(map(tuple, iqrs[key])):
            d = DescriptorMemoryElement('train', i)
            d.set_vector(numpy.array(v))
            dest.append(d)
            i += 1
    log.info(' positive -> %d', len(pos))
    log.info(' negative -> %d', len(neg))

    classifier.train(positive=pos, negative=neg)
def default_config():
    """
    Generate the default configuration dictionary for this utility.

    :return: Dictionary with ``utility`` options (classification overwrite
        flag and parallelism settings) followed by default ``plugins``
        configurations.
    :rtype: dict
    """
    config = {
        "utility": {
            "classify_overwrite": False,
            "parallel": {
                "use_multiprocessing": False,
                "index_extraction_cores": None,
                "classification_cores": None,
            },
        },
    }

    # Plugin defaults are generated from the available implementations.
    classifier_default = plugin.make_config(get_classifier_impls())
    classification_default = plugin.make_config(
        get_classification_element_impls()
    )
    index_default = plugin.make_config(get_descriptor_index_impls())
    config["plugins"] = {
        "classifier": classifier_default,
        "classification_factory": classification_default,
        "descriptor_index": index_default,
    }
    return config
def get_supervised_classifier_impls():
    """
    Get classifier implementations restricted to the
    ``SupervisedClassifier`` sub-interface.

    :return: Result of ``get_classifier_impls`` called with
        ``sub_interface=SupervisedClassifier``.
    """
    supervised_impls = get_classifier_impls(
        sub_interface=SupervisedClassifier
    )
    return supervised_impls
def get_default_config():
    """
    Build the default configuration dictionary.

    :return: Dictionary with a single ``classifier`` entry holding the
        default classifier plugin configuration.
    :rtype: dict
    """
    classifier_default = make_config(get_classifier_impls())
    return {"classifier": classifier_default}
def main():
    """
    Validate, or optionally train, the configured classifier against
    labeled descriptors listed in a CSV file.

    Each CSV row provides a descriptor UUID in the first column and its
    truth label in the second.  Descriptors are fetched from the configured
    descriptor index and grouped by truth label.

    If training is enabled in the configuration, the classifier is trained
    on the grouped descriptors and the process exits.  Otherwise the
    descriptors are classified, a confusion matrix is logged, and the
    optional confusion-matrix plot, UUID confusion-matrix JSON, and PR/ROC
    curve plots are written when their output paths are configured.

    :raises ValueError: Training was requested but the configured
        classifier is not a ``SupervisedClassifier``.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(config['plugins']['classifier'],
                                           get_classifier_impls())
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls())

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of label to the DescriptorElement instances for that
    # described by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate through UUIDs in specified file """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch descriptors from configured index """
        uuid, truth_label = r
        return truth_label, descriptor_index.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr, iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train classifier if the one given has a ``train`` method and training
    # was turned enabled.
    if do_train:
        if isinstance(classifier, SupervisedClassifier):
            log.info("Training classifier model")
            classifier.train(tlabel2descriptors)
            # Training and validation are mutually exclusive runs.
            exit(0)
        else:
            # Previously this exception was constructed but never raised,
            # silently falling through to validation.
            raise ValueError("Configured classifier is not a "
                             "SupervisedClassifier type and does not "
                             "support training.")

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    for tlabel, descriptors in tlabel2descriptors.items():
        tlabel2classifications[tlabel] = \
            set(classifier.classify_async(
                descriptors, classification_factory,
                use_multiprocessing=True,
                procs=config['parallelism']['classification_cores'],
                ri=1.0,
            ).values())
    log.info("Truth label counts:")
    for l in sorted(tlabel2classifications):
        log.info(" %s :: %d", l, len(tlabel2classifications[l]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # CM of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are UUID
        # predicted labels.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[str, set | list]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists so the structure is JSON serializable
            for plabel in uuid_cm[tlabel]:
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr,
                       plot_ci, plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc,
                        plot_ci, plot_ci_alpha)
def classify_files(config, label, file_globs):
    """
    Classify files and print the paths of those whose highest-confidence
    label equals ``label``.

    When ``label`` is None, or is not one of the classifier's labels, the
    available labels are logged and nothing is classified.

    :param config: Configuration dictionary with ``classifier``,
        ``descriptor_factory``, ``descriptor_generator`` and
        ``classification_factory`` entries.
    :type config: dict
    :param label: Classification label to filter on, or None to only list
        the available labels.
    :param file_globs: File paths and/or glob patterns of the files to
        classify.

    :raises RuntimeError: No files were found for the given paths/globs.
    """
    log = logging.getLogger(__name__)

    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(config['classifier'],
                                           get_classifier_impls())

    def report_labels():
        log.info("Available classifier labels:")
        for name in classifier.get_labels():
            log.info("- %s", name)

    if label is None:
        report_labels()
        return
    if label not in classifier.get_labels():
        log.error("Invalid classification label provided to compute and "
                  "filter on: '%s'", label)
        report_labels()
        return

    log.info("Collecting files from globs")
    #: :type: list[DataFileElement]
    data_elements = []
    uuid2filepath = {}

    def register_file(path):
        # Wrap the file and remember which path produced which UUID.
        element = DataFileElement(path)
        data_elements.append(element)
        uuid2filepath[element.uuid()] = path

    for entry in file_globs:
        if os.path.isfile(entry):
            register_file(entry)
            continue
        log.debug("expanding glob: %s", entry)
        for matched_path in glob.iglob(entry):
            register_file(matched_path)
    if not data_elements:
        raise RuntimeError("No files provided for classification.")

    log.info("Computing descriptors")
    descriptor_factory = DescriptorElementFactory.from_config(
        config['descriptor_factory']
    )
    #: :type: smqtk.algorithms.DescriptorGenerator
    descriptor_generator = plugin.from_plugin_config(
        config['descriptor_generator'], get_descriptor_generator_impls()
    )
    descr_map = descriptor_generator.compute_descriptor_async(
        data_elements, descriptor_factory
    )

    log.info("Classifying descriptors")
    classification_factory = ClassificationElementFactory.from_config(
        config['classification_factory']
    )
    descriptors = list(descr_map.values())
    classification_map = classifier.classify_async(descriptors,
                                                   classification_factory)

    log.info("Printing input file paths that classified as the given label.")
    # Map of UUID to its classification element.
    uuid2c = {c.uuid: c for c in six.itervalues(classification_map)}
    for data in data_elements:
        d_uuid = data.uuid()
        filepath = uuid2filepath[d_uuid]
        classification = uuid2c[d_uuid]
        log.debug("'{}' classification map: {}".format(
            filepath, classification.get_classification()
        ))
        if classification.max_label() == label:
            print(filepath)
def main():
    """
    Classify descriptors from the configured index and write the results to
    CSV files.

    Descriptor UUIDs are read one per line from the UUID list file given on
    the command line.  Each descriptor is fetched from the configured
    descriptor index and classified.  Two files are written: a header file
    with the column labels (``uuid`` followed by the classifier's labels)
    and a data file with one row per classification, giving the UUID and
    the value for each label in header order.

    :raises ValueError: The UUID list path is missing or is not a file, or
        one of the output CSV paths was not provided.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    # - parallel_map UUIDs to load from the configured index
    # - classify iterated descriptors

    uuids_list_filepath = args.uuids_list
    output_csv_filepath = args.csv_data
    output_csv_header_filepath = args.csv_header
    classify_overwrite = config['utility']['classify_overwrite']

    p_use_multiprocessing = \
        config['utility']['parallel']['use_multiprocessing']
    p_index_extraction_cores = \
        config['utility']['parallel']['index_extraction_cores']
    p_classification_cores = \
        config['utility']['parallel']['classification_cores']

    if not uuids_list_filepath:
        raise ValueError("No uuids_list_filepath specified.")
    elif not os.path.isfile(uuids_list_filepath):
        raise ValueError("Given uuids_list_filepath did not point to a file.")
    if output_csv_header_filepath is None:
        raise ValueError("Need a path to save CSV header labels")
    if output_csv_filepath is None:
        raise ValueError("Need a path to save CSV data.")

    #
    # Initialize configured plugins
    #

    log.info("Initializing descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls())

    log.info("Initializing classification factory")
    c_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])

    log.info("Initializing classifier")
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(config['plugins']['classifier'],
                                           get_classifier_impls())

    #
    # Setup/Process
    #
    def iter_uuids():
        with open(uuids_list_filepath) as f:
            for l in f:
                yield l.strip()

    def descr_for_uuid(uuid):
        """
        :type uuid: collections.Hashable
        :rtype: smqtk.representation.DescriptorElement
        """
        return descriptor_index.get_descriptor(uuid)

    def classify_descr(d):
        """
        :type d: smqtk.representation.DescriptorElement
        :rtype: smqtk.representation.ClassificationElement
        """
        return classifier.classify(d, c_factory, classify_overwrite)

    log.info("Initializing uuid-to-descriptor parallel map")
    #: :type: collections.Iterable[smqtk.representation.DescriptorElement]
    element_iter = parallel.parallel_map(
        descr_for_uuid, iter_uuids(),
        use_multiprocessing=p_use_multiprocessing,
        cores=p_index_extraction_cores,
        name="descr_for_uuid",
    )

    log.info("Initializing descriptor-to-classification parallel map")
    #: :type: collections.Iterable[smqtk.representation.ClassificationElement]
    classification_iter = parallel.parallel_map(
        classify_descr, element_iter,
        use_multiprocessing=p_use_multiprocessing,
        cores=p_classification_cores,
        name='classify_descr',
    )

    #
    # Write/Output files
    #

    c_labels = classifier.get_labels()

    def make_row(c):
        """
        :type c: smqtk.representation.ClassificationElement
        """
        c_m = c.get_classification()
        return [c.uuid] + [c_m[l] for l in c_labels]

    # column labels file
    log.info("Writing CSV column header file: %s", output_csv_header_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_header_filepath))
    with open(output_csv_header_filepath, 'wb') as f_csv:
        w = csv.writer(f_csv)
        # Build the header from an explicit list of stringified labels so
        # this works whether or not ``get_labels`` returns a list
        # (``['uuid'] + c_labels`` raises TypeError for other sequences).
        w.writerow(['uuid'] + [str(cl) for cl in c_labels])

    # CSV file
    log.info("Writing CSV data file: %s", output_csv_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_filepath))
    r_state = [0] * 7
    with open(output_csv_filepath, 'wb') as f_csv:
        w = csv.writer(f_csv)
        for c in classification_iter:
            w.writerow(make_row(c))
            bin_utils.report_progress(log.info, r_state, 1.0)

    # Final report
    r_state[1] -= 1
    bin_utils.report_progress(log.info, r_state, 0)

    log.info("Done")
plt.title("PR - HT Positive") plt.xlabel("Recall") plt.ylabel("Precision") plt.legend(loc='best', fancybox=True, framealpha=0.5) plt.savefig(PLOT_PR_OUTPUT) else: # Using the final trained classifier with open(CLASSIFIER_TRAINING_CONFIG_JSON) as f: classifier_config = json.load(f) log.info("Loading plugins") descriptor_index = MemoryDescriptorIndex(file_cache=DESCRIPTOR_INDEX_FILE_CACHE) #: :type: smqtk.algorithms.Classifier classifier = from_plugin_config(classifier_config['plugins']['classifier'], get_classifier_impls()) c_factory = ClassificationElementFactory(MemoryClassificationElement, {}) #: :type: dict[str, list[str]] phone2shas = json.load(open(PHONE_SHA1_JSON)) #: :type: dict[str, float] phone2score = {} log.info("Classifying phone imagery descriptors") i = 0 descriptor_index_shas = set(descriptor_index.iterkeys()) for p in phone2shas: log.info('%s (%d / %d)', p, i + 1, len(phone2shas)) # Not all source "images" have descriptors since some URLs returned # non-image files. Intersect phone sha's with what was actually # computed. Warn if this reduces descriptors for classification to zero.
def main():
    """
    Validate, or optionally train, the configured classifier against
    labeled descriptors listed in a CSV file.

    Each CSV row provides a descriptor UUID in the first column and its
    truth label in the second.  Descriptors are fetched from the configured
    descriptor index and grouped by truth label.

    If training is enabled in the configuration, the classifier is trained
    on the grouped descriptors and the process exits.  Otherwise the
    descriptors are classified, a confusion matrix is logged, and the
    optional confusion-matrix plot, UUID confusion-matrix JSON, and PR/ROC
    curve plots are written when their output paths are configured.

    :raises ValueError: Training was requested but the configured
        classifier is not a ``SupervisedClassifier``.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        config['plugins']['classifier'],
        get_classifier_impls()
    )
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory']
    )
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of label to the DescriptorElement instances for that
    # described by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate through UUIDs in specified file """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch descriptors from configured index """
        uuid, truth_label = r
        return truth_label, descriptor_index.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr, iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train classifier if the one given has a ``train`` method and training
    # was turned enabled.
    if do_train:
        if isinstance(classifier, SupervisedClassifier):
            log.info("Training classifier model")
            classifier.train(tlabel2descriptors)
            # Training and validation are mutually exclusive runs.
            exit(0)
        else:
            # Previously this exception was constructed but never raised,
            # silently falling through to validation.
            raise ValueError("Configured classifier is not a "
                             "SupervisedClassifier type and does not "
                             "support training.")

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    for tlabel, descriptors in six.iteritems(tlabel2descriptors):
        tlabel2classifications[tlabel] = \
            set(classifier.classify_async(
                descriptors, classification_factory,
                use_multiprocessing=True,
                procs=config['parallelism']['classification_cores'],
                ri=1.0,
            ).values())
    log.info("Truth label counts:")
    for l in sorted(tlabel2classifications):
        log.info(" %s :: %d", l, len(tlabel2classifications[l]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # Confusion Matrix of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are UUID
        # predicted labels.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[collections.Hashable, set]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists so the structure is JSON serializable
            for plabel in uuid_cm[tlabel]:
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr,
                       plot_ci, plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc,
                        plot_ci, plot_ci_alpha)
def add_iqr_state_classifier(self):
    """
    Train a classifier based on the user-provided IQR state file bytes in
    a base64 encoding, matched with a descriptive label of that
    classifier's topic.

    Since all IQR session classifiers end up only having two result
    classes (positive and negative), the topic of the classifier is
    encoded in the descriptive label the user applies to the classifier.

    Below is an example call to this endpoint via the ``requests`` python
    module, showing how base64 data is sent::

        import base64
        import requests
        data_bytes = "Load some content bytes here."
        requests.post('http://localhost:5000/iqr_classifier',
                      data={'bytes_b64': base64.b64encode(data_bytes),
                            'label': 'some_label'})

    With curl on the command line::

        $ curl -X POST localhost:5000/iqr_classifier \
            -d "label=some_label" \
            --data-urlencode "bytes_b64=$(base64 -w0 /path/to/file)"

        # If this fails, you may wish to encode the file separately and
        # use the file reference syntax instead:

        $ base64 -w0 /path/to/file > /path/to/file.b64
        $ curl -X POST localhost:5000/iqr_classifier -d label=some_label \
            --data-urlencode bytes_b64@/path/to/file.b64

    To lock this classifier and guard it against deletion, add
    "lock_label=true"::

        $ curl -X POST localhost:5000/iqr_classifier \
            -d "label=some_label" \
            -d "lock_label=true" \
            --data-urlencode "bytes_b64=$(base64 -w0 /path/to/file)"

    Form arguments:
        bytes_b64
            base64 encoding of the bytes of the IQR session state save
            file.
        label
            Descriptive label to apply to this classifier. This should not
            conflict with existing classifier labels.
        lock_label
            If 'true', disallow deletion of this label. If 'false', allow
            deletion of this label. Only has an effect if deletion is
            enabled for this service. (Default: 'false')

    Returns 201 on success, or 400 when an argument is missing or invalid
    or the label already exists.

    """
    data_b64 = flask.request.values.get('bytes_b64', default=None)
    label = flask.request.values.get('label', default=None)
    lock_clfr_str = flask.request.values.get('lock_label',
                                             default='false')

    if data_b64 is None or len(data_b64) == 0:
        return make_response_json("No state base64 data provided.", 400)
    elif label is None or len(label) == 0:
        return make_response_json("No descriptive label provided.", 400)
    try:
        # The flag arrives as text; parse it as JSON to get a boolean.
        lock_clfr = bool(flask.json.loads(lock_clfr_str))
    except JSON_DECODE_EXCEPTION:
        return make_response_json("Invalid boolean value for"
                                  " 'lock_label'. Was given: '%s'"
                                  % lock_clfr_str,
                                  400)
    try:
        # Using urlsafe version because it handles both regular and urlsafe
        # alphabets.
        data_bytes = base64.urlsafe_b64decode(data_b64.encode('utf-8'))
    except (TypeError, binascii.Error) as ex:
        # Pass the status code to the helper like every other error path
        # here, instead of appending it to the helper's return value.
        return make_response_json("Invalid base64 input: %s" % str(ex),
                                  400)

    # If the given label conflicts with one already in the collection,
    # fail.
    if label in self.classifier_collection.labels():
        return make_response_json(
            "Label already exists in classifier collection.", 400)

    # Create dummy IqrSession to extract pos/neg descriptors.
    iqrs = IqrSession()
    iqrs.set_state_bytes(data_bytes, self.descriptor_factory)
    pos = iqrs.positive_descriptors | iqrs.external_positive_descriptors
    neg = iqrs.negative_descriptors | iqrs.external_negative_descriptors
    del iqrs

    # Make a classifier instance from the stored config for IQR
    # session-based classifiers.
    #: :type: SupervisedClassifier
    classifier = smqtk.utils.plugin.from_plugin_config(
        self.iqr_state_classifier_config,
        get_classifier_impls(sub_interface=SupervisedClassifier)
    )
    classifier.train(class_examples={'positive': pos, 'negative': neg})

    try:
        self.classifier_collection.add_classifier(label, classifier)

        # If we're allowing deletions, get the lock flag from the form and
        # set it for this classifier
        if self.enable_classifier_removal and lock_clfr:
            self.immutable_labels.add(label)

    except ValueError as e:
        if e.args[0].find('JSON') > -1:
            return make_response_json("Tried to parse malformed JSON in "
                                      "form argument.", 400)
        # Otherwise the label was added by someone else while we were
        # training.
        return make_response_json("Duplicate label ('%s') added during "
                                  "classifier training of provided IQR "
                                  "session state." % label, 400,
                                  label=label)

    return make_response_json("Finished training IQR-session-based "
                              "classifier for label '%s'." % label,
                              201,
                              label=label)
def main():
    """
    Classify descriptors from the configured index and write the results to
    CSV files.

    Descriptor UUIDs are read one per line from the UUID list file given on
    the command line.  Each descriptor is fetched from the configured
    descriptor index and classified.  Two files are written: a header file
    with the column labels (``uuid`` followed by the classifier's labels)
    and a data file with one row per classification, giving the UUID and
    the value for each label in header order.

    :raises ValueError: The UUID list path is missing or is not a file, or
        one of the output CSV paths was not provided.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    # Command line inputs.
    uuids_list_filepath = args.uuids_list
    output_csv_filepath = args.csv_data
    output_csv_header_filepath = args.csv_header

    # Configured behavior and parallelism options.
    classify_overwrite = config['utility']['classify_overwrite']
    parallel_options = config['utility']['parallel']
    p_use_multiprocessing = parallel_options['use_multiprocessing']
    p_index_extraction_cores = parallel_options['index_extraction_cores']
    p_classification_cores = parallel_options['classification_cores']

    if not uuids_list_filepath:
        raise ValueError("No uuids_list_filepath specified.")
    if not os.path.isfile(uuids_list_filepath):
        raise ValueError("Given uuids_list_filepath did not point to a file.")
    if output_csv_header_filepath is None:
        raise ValueError("Need a path to save CSV header labels")
    if output_csv_filepath is None:
        raise ValueError("Need a path to save CSV data.")

    #
    # Initialize configured plugins
    #
    plugin_configs = config['plugins']

    log.info("Initializing descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        plugin_configs['descriptor_index'], get_descriptor_index_impls()
    )

    log.info("Initializing classification factory")
    c_factory = ClassificationElementFactory.from_config(
        plugin_configs['classification_factory']
    )

    log.info("Initializing classifier")
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        plugin_configs['classifier'], get_classifier_impls()
    )

    #
    # Setup/Process
    #
    def iter_uuids():
        # One UUID per line, surrounding whitespace removed.
        with open(uuids_list_filepath) as uuid_file:
            for line in uuid_file:
                yield line.strip()

    def descr_for_uuid(uuid):
        """
        :type uuid: collections.Hashable
        :rtype: smqtk.representation.DescriptorElement
        """
        return descriptor_index.get_descriptor(uuid)

    def classify_descr(d):
        """
        :type d: smqtk.representation.DescriptorElement
        :rtype: smqtk.representation.ClassificationElement
        """
        return classifier.classify(d, c_factory, classify_overwrite)

    log.info("Initializing uuid-to-descriptor parallel map")
    #: :type: collections.Iterable[smqtk.representation.DescriptorElement]
    element_iter = parallel.parallel_map(
        descr_for_uuid, iter_uuids(),
        use_multiprocessing=p_use_multiprocessing,
        cores=p_index_extraction_cores,
        name="descr_for_uuid",
    )

    log.info("Initializing descriptor-to-classification parallel map")
    #: :type: collections.Iterable[smqtk.representation.ClassificationElement]
    classification_iter = parallel.parallel_map(
        classify_descr, element_iter,
        use_multiprocessing=p_use_multiprocessing,
        cores=p_classification_cores,
        name='classify_descr',
    )

    #
    # Write/Output files
    #
    c_labels = classifier.get_labels()

    def make_row(e):
        """
        :type e: smqtk.representation.ClassificationElement
        """
        c_m = e.get_classification()
        row = [e.uuid]
        row.extend([c_m[l] for l in c_labels])
        return row

    # Column labels file.
    log.info("Writing CSV column header file: %s", output_csv_header_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_header_filepath))
    with open(output_csv_header_filepath, 'wb') as header_file:
        header_writer = csv.writer(header_file)
        header_writer.writerow(['uuid'] + [str(cl) for cl in c_labels])

    # Data file, one row per classification, with progress reporting.
    log.info("Writing CSV data file: %s", output_csv_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_filepath))
    r_state = [0] * 7
    with open(output_csv_filepath, 'wb') as data_file:
        data_writer = csv.writer(data_file)
        for classification in classification_iter:
            data_writer.writerow(make_row(classification))
            bin_utils.report_progress(log.info, r_state, 1.0)

    # Final report
    r_state[1] -= 1
    bin_utils.report_progress(log.info, r_state, 0)

    log.info("Done")
def main():
    """
    Validate the configured classifier against labeled descriptors, or train
    it when the ``train`` utility option is set.

    Reads ``<UUID>,<label>`` rows from the configured CSV file, fetches the
    matching descriptors from the configured descriptor index and groups them
    by truth label. When training is requested the classifier is trained on
    that grouping and the process exits. Otherwise the descriptors are
    classified and, depending on which output paths are configured, a
    confusion matrix plot, a UUID confusion matrix JSON file and PR / ROC
    curve plots are produced.

    :raises ValueError: Training was requested but the configured classifier
        is not a ``SupervisedClassifier``.
    """
    description = """
    Utility for validating a given classifier implementation's model against
    some labeled testing data, outputting PR and ROC curve plots with
    area-under-curve score values.

    This utility can optionally be used train a supervised classifier model if
    the given classifier model configuration does not exist and a second CSV
    file listing labeled training data is provided. Training will be attempted
    if ``train`` is set to true. If training is performed, we exit after
    training completes. A ``SupervisedClassifier`` sub-classing implementation
    must be configured

    We expect the test and train CSV files in the column format:

        ...
        <UUID>,<label>
        ...

    The UUID is of the descriptor to which the label applies. The label may be
    any arbitrary string value, but all labels must be consistent in
    application.

    Some metrics presented assume the highest confidence class as the single
    predicted class for an element:

        - confusion matrix

    The output UUID confusion matrix is a JSON dictionary where the top-level
    keys are the true labels, and the inner dictionary is the mapping of
    predicted labels to the UUIDs of the classifications/descriptors that
    yielded the prediction. Again, this is based on the maximum probability
    label for a classification result (T=0.5).
    """
    args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        config['plugins']['classifier'],
        get_classifier_impls()
    )
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory']
    )
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of label to the DescriptorElement instances for that
    # described by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate through (UUID, label) pairs in the specified CSV file """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch the descriptor for a (UUID, label) pair from the index """
        uuid, truth_label = r
        return truth_label, descriptor_index.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr, iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train the classifier if training was enabled and the configured
    # classifier supports it.
    if do_train:
        if isinstance(classifier, SupervisedClassifier):
            log.info("Training classifier model")
            classifier.train(tlabel2descriptors)
            exit(0)
        else:
            # The exception used to be constructed without being raised,
            # which let execution fall through into validation.
            raise ValueError("Configured classifier is not a "
                             "SupervisedClassifier type and does not support "
                             "training.")

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    for tlabel, descriptors in tlabel2descriptors.iteritems():
        tlabel2classifications[tlabel] = \
            set(classifier.classify_async(
                descriptors, classification_factory,
                use_multiprocessing=True,
                procs=config['parallelism']['classification_cores'],
                ri=1.0,
            ).values())
    log.info("Truth label counts:")
    for l in sorted(tlabel2classifications):
        log.info("  %s :: %d", l, len(tlabel2classifications[l]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # CM of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are UUID
        # predicted labels.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[str, set | list]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists; sets are not JSON serializable
            for plabel in uuid_cm[tlabel]:
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr,
                       plot_ci, plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc,
                        plot_ci, plot_ci_alpha)
def classify(self):
    """
    Given a refined session ID and some number of descriptor UUIDs, create a
    classifier according to the current state and classify the given
    descriptors adjudicated.

    This will fail if the session has not been given adjudications (refined)
    yet.

    URI Args:
        sid
            UUID of the session to utilize
        uuids
            JSON-encoded list of descriptor UUIDs to classify. Return list of
            results will be in the same order as this list.

    Responses (JSON body, HTTP status):
        200
            Classification finished; carries the parallel lists ``uuids``
            and ``proba`` (positive-class value per UUID).
        400
            ``uuids`` missing or not decodable as JSON, ``sid`` missing, no
            UUIDs given, or the session has no positive or no negative
            descriptors.
        404
            A descriptor UUID was not found in the descriptor index, or a
            ``KeyError`` was raised while working with the session.

    """
    # Record clean/dirty status after making classifier/refining so we
    # don't train a new classifier when we don't have to.

    sid = flask.request.args.get('sid', None)
    uuids = flask.request.args.get('uuids', None)

    try:
        uuids = json.loads(uuids)
    except (TypeError, ValueError):
        # TypeError: the ``uuids`` argument was not supplied (None).
        # ValueError: the supplied value is not valid JSON.
        return make_response_json(
            "Failed to decode uuids as json. Given '%s'" % uuids), 400

    if sid is None:
        return make_response_json("No session id (sid) provided"), 400
    if not uuids:
        return make_response_json(
            "No descriptor UUIDs provided",
            sid=sid,
        ), 400

    # NOTE(review): this ``try`` spans the whole body, so any KeyError
    # raised below that is not handled more locally is reported as an
    # unknown session id.
    try:
        with self.controller.get_session(sid) as iqrs:
            if not iqrs.positive_descriptors:
                return make_response_json(
                    "No positive labels in current session", sid=sid), 400
            if not iqrs.negative_descriptors:
                return make_response_json(
                    "No negative labels in current session", sid=sid), 400

            # Get descriptor elements for classification
            try:
                descriptors = list(
                    self.descriptor_index.get_many_descriptors(uuids))
            except KeyError as ex:
                err_uuid = str(ex)
                self._log.warn(traceback.format_exc())
                return make_response_json(
                    "Descriptor UUID '%s' cannot be found in the "
                    "configured descriptor index." % err_uuid,
                    sid=sid,
                    uuid=err_uuid,
                ), 404

            classifier = self.session_classifiers.get(sid, None)

            pos_label = "positive"
            neg_label = "negative"
            # (Re)train when the session state changed since the last
            # classifier was built, or when none has been built yet.
            if self.session_classifier_dirty[sid] or classifier is None:
                self._log.debug("Training new classifier for current "
                                "refine state")

                #: :type: SupervisedClassifier
                classifier = plugin.from_plugin_config(
                    self.classifier_config,
                    get_classifier_impls(
                        sub_interface=SupervisedClassifier))
                classifier.train({
                    pos_label: iqrs.positive_descriptors,
                    neg_label: iqrs.negative_descriptors
                })

                self.session_classifiers[sid] = classifier
                self.session_classifier_dirty[sid] = False

            classifications = classifier.classify_async(
                descriptors,
                self.classification_factory,
                use_multiprocessing=True,
                ri=1.0)

            # Format output to be parallel lists of UUIDs input and
            # positive class classification scores.
            o_uuids = []
            o_proba = []
            for d in descriptors:
                o_uuids.append(d.uuid())
                o_proba.append(classifications[d][pos_label])

            assert uuids == o_uuids, \
                "Output UUID list is not congruent with INPUT list."

            return make_response_json(
                "Finished classification",
                sid=sid,
                uuids=o_uuids,
                proba=o_proba,
            ), 200

    except KeyError:
        return make_response_json("session id '%s' not found" % sid,
                                  sid=sid), 404
plt.xlabel("Recall") plt.ylabel("Precision") plt.legend(loc='best', fancybox=True, framealpha=0.5) plt.savefig(PLOT_PR_OUTPUT) else: # Using the final trained classifier with open(CLASSIFIER_TRAINING_CONFIG_JSON) as f: classifier_config = json.load(f) log.info("Loading plugins") descriptor_index = MemoryDescriptorIndex( file_cache=DESCRIPTOR_INDEX_FILE_CACHE) #: :type: smqtk.algorithms.Classifier classifier = from_plugin_config(classifier_config['plugins']['classifier'], get_classifier_impls()) c_factory = ClassificationElementFactory(MemoryClassificationElement, {}) #: :type: dict[str, list[str]] phone2shas = json.load(open(PHONE_SHA1_JSON)) #: :type: dict[str, float] phone2score = {} log.info("Classifying phone imagery descriptors") i = 0 descriptor_index_shas = set(descriptor_index.iterkeys()) for p in phone2shas: log.info('%s (%d / %d)', p, i + 1, len(phone2shas)) # Not all source "images" have descriptors since some URLs returned # non-image files. Intersect phone sha's with what was actually # computed. Warn if this reduces descriptors for classification to zero.
sid=sid, uuid=err_uuid, ), 404 classifier = self.session_classifiers.get(sid, None) pos_label = "positive" neg_label = "negative" if self.session_classifier_dirty[sid] or classifier is None: self._log.debug("Training new classifier for current " "refine state") #: :type: SupervisedClassifier classifier = plugin.from_plugin_config( self.classifier_config, get_classifier_impls(sub_interface=SupervisedClassifier)) classifier.train({ pos_label: iqrs.positive_descriptors, neg_label: iqrs.negative_descriptors }) self.session_classifiers[sid] = classifier self.session_classifier_dirty[sid] = False classifications = classifier.classify_async( descriptors, self.classification_factory, use_multiprocessing=True, ri=1.0) # Format output to be parallel lists of UUIDs input and
def main():
    """
    Classify the descriptors named in a UUID list file and write the results
    as CSV.

    Each (stripped) line of the UUID list file is looked up in the configured
    descriptor index, the resulting descriptors are classified by the
    configured classifier, and a header file plus a data file are written.
    Data rows hold the element UUID followed by its value for each classifier
    label, in the order given by ``get_labels``.

    :raises ValueError: The UUID list path is empty or is not a file, or
        either output CSV path is ``None``.
    """
    description = """
    Script for asynchronously computing classifications for DescriptorElements
    in a DescriptorIndex specified via a list of UUIDs. Results are output to a
    CSV file in the format:

        uuid, label1_confidence, label2_confidence, ...

    CSV columns labels are output to the given CSV header file path. Label
    columns will be in the order as reported by the classifier implementations
    ``get_labels`` method.

    Due to using an input file-list of UUIDs, we require that the UUIDs of
    indexed descriptors be strings, or equality comparable to the UUIDs' string
    representation.
    """

    args, config = bin_utils.utility_main_helper(
        default_config,
        description,
        extend_parser,
    )
    log = logging.getLogger(__name__)

    # - parallel_map UUIDs to load from the configured index
    # - classify iterated descriptors

    uuids_list_filepath = args.uuids_list
    output_csv_filepath = args.csv_data
    output_csv_header_filepath = args.csv_header
    classify_overwrite = config['utility']['classify_overwrite']

    p_use_multiprocessing = \
        config['utility']['parallel']['use_multiprocessing']
    p_index_extraction_cores = \
        config['utility']['parallel']['index_extraction_cores']
    p_classification_cores = \
        config['utility']['parallel']['classification_cores']

    if not uuids_list_filepath:
        raise ValueError("No uuids_list_filepath specified.")
    elif not os.path.isfile(uuids_list_filepath):
        raise ValueError("Given uuids_list_filepath did not point to a file.")
    if output_csv_header_filepath is None:
        raise ValueError("Need a path to save CSV header labels")
    if output_csv_filepath is None:
        raise ValueError("Need a path to save CSV data.")

    #
    # Initialize configured plugins
    #

    log.info("Initializing descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    log.info("Initializing classification factory")
    c_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory']
    )

    log.info("Initializing classifier")
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        config['plugins']['classifier'],
        get_classifier_impls()
    )

    #
    # Setup/Process
    #
    def iter_uuids():
        """ Yield each line of the UUID list file, whitespace-stripped. """
        with open(uuids_list_filepath) as f:
            for l in f:
                yield l.strip()

    def descr_for_uuid(uuid):
        """
        :type uuid: collections.Hashable
        :rtype: smqtk.representation.DescriptorElement
        """
        return descriptor_index.get_descriptor(uuid)

    def classify_descr(d):
        """
        :type d: smqtk.representation.DescriptorElement
        :rtype: smqtk.representation.ClassificationElement
        """
        return classifier.classify(d, c_factory, classify_overwrite)

    log.info("Initializing uuid-to-descriptor parallel map")
    #: :type: collections.Iterable[smqtk.representation.DescriptorElement]
    element_iter = parallel.parallel_map(
        descr_for_uuid, iter_uuids(),
        use_multiprocessing=p_use_multiprocessing,
        cores=p_index_extraction_cores,
        name="descr_for_uuid",
    )

    log.info("Initializing descriptor-to-classification parallel map")
    #: :type: collections.Iterable[smqtk.representation.ClassificationElement]
    classification_iter = parallel.parallel_map(
        classify_descr, element_iter,
        use_multiprocessing=p_use_multiprocessing,
        cores=p_classification_cores,
        name='classify_descr',
    )

    #
    # Write/Output files
    #

    c_labels = classifier.get_labels()

    def make_row(e):
        """
        Build one CSV data row: the UUID, then one value per label.

        :type e: smqtk.representation.ClassificationElement
        """
        c_m = e.get_classification()
        return [e.uuid] + [c_m[l] for l in c_labels]

    # column labels file
    log.info("Writing CSV column header file: %s", output_csv_header_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_header_filepath))
    with open(output_csv_header_filepath, 'wb') as f_csv:
        w = csv.writer(f_csv)
        # Build the label columns as a fresh list: ``c_labels`` is whatever
        # sequence the classifier returns, and list + non-list concatenation
        # raises TypeError.
        w.writerow(['uuid'] + [str(cl) for cl in c_labels])

    # CSV file
    log.info("Writing CSV data file: %s", output_csv_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_filepath))
    r_state = [0] * 7
    with open(output_csv_filepath, 'wb') as f_csv:
        w = csv.writer(f_csv)
        for c in classification_iter:
            w.writerow(make_row(c))
            bin_utils.report_progress(log.info, r_state, 1.0)

    # Final report
    # NOTE(review): the decrement presumably cancels the increment done by
    # the extra ``report_progress`` call below -- confirm against bin_utils.
    r_state[1] -= 1
    bin_utils.report_progress(log.info, r_state, 0)

    log.info("Done")
def classify_files(config, label, file_globs):
    """
    Print the paths of the input files whose top classification is ``label``.

    When ``label`` is ``None`` the classifier's labels are logged and nothing
    else happens. When ``label`` is not one of the classifier's labels an
    error is logged, followed by the available labels, and nothing else
    happens.

    :param config: Configuration dictionary; the ``classifier``,
        ``descriptor_factory``, ``descriptor_generator`` and
        ``classification_factory`` entries are read.
    :param label: Classification label to filter on, or ``None``.
    :param file_globs: Iterable of file paths and/or glob patterns. Entries
        that are not existing files are expanded with ``glob.iglob``.

    :raises RuntimeError: No files resulted from the given paths/globs.
    """
    log = logging.getLogger(__name__)

    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        config['classifier'], get_classifier_impls()
    )

    def report_labels():
        """ Log every label the classifier knows about. """
        log.info("Available classifier labels:")
        for known_label in classifier.get_labels():
            log.info("- %s", known_label)

    # Guard clauses: nothing to filter on, or an unknown label.
    if label is None:
        report_labels()
        return
    if label not in classifier.get_labels():
        log.error(
            "Invalid classification label provided to compute and filter "
            "on: '%s'", label)
        report_labels()
        return

    def iter_filepaths():
        """ Lazily yield concrete file paths from the given paths/globs. """
        for pattern in file_globs:
            if os.path.isfile(pattern):
                yield pattern
            else:
                log.debug("expanding glob: %s", pattern)
                for match in glob.iglob(pattern):
                    yield match

    log.info("Collecting files from globs")
    #: :type: list[DataFileElement]
    data_elements = []
    uuid2filepath = {}
    for filepath in iter_filepaths():
        element = DataFileElement(filepath)
        data_elements.append(element)
        uuid2filepath[element.uuid()] = filepath
    if not data_elements:
        raise RuntimeError("No files provided for classification.")

    log.info("Computing descriptors")
    descriptor_factory = DescriptorElementFactory.from_config(
        config['descriptor_factory']
    )
    #: :type: smqtk.algorithms.DescriptorGenerator
    descriptor_generator = plugin.from_plugin_config(
        config['descriptor_generator'], get_descriptor_generator_impls()
    )
    descr_map = descriptor_generator.compute_descriptor_async(
        data_elements, descriptor_factory
    )

    log.info("Classifying descriptors")
    classification_factory = ClassificationElementFactory.from_config(
        config['classification_factory']
    )
    classification_map = classifier.classify_async(
        descr_map.values(), classification_factory
    )

    log.info("Printing input file paths that classified as the given label.")
    # Map of UUID to its classification element.
    uuid2c = {}
    for classification in classification_map.itervalues():
        uuid2c[classification.uuid] = classification
    for data in data_elements:
        if uuid2c[data.uuid()].max_label() != label:
            continue
        print(uuid2filepath[data.uuid()])
def classify(self):
    """
    Given a refined session ID and some number of descriptor UUIDs, create a
    classifier according to the current state and classify the given
    descriptors adjudicated.

    This will fail if the session has not been given adjudications (refined)
    yet.

    URI Args:
        sid
            UUID of the session to utilize
        uuids
            JSON-encoded list of descriptor UUIDs to classify. Return list of
            results will be in the same order as this list.

    Responses (JSON body, HTTP status):
        200
            Classification finished; carries the parallel lists ``uuids``
            and ``proba`` (positive-class value per UUID).
        400
            ``uuids`` missing or not decodable as JSON, ``sid`` missing, no
            UUIDs given, or the session has no positive or no negative
            descriptors.
        404
            A descriptor UUID was not found in the descriptor index, or a
            ``KeyError`` was raised while working with the session.

    """
    # Record clean/dirty status after making classifier/refining so we
    # don't train a new classifier when we don't have to.

    sid = flask.request.args.get('sid', None)
    uuids = flask.request.args.get('uuids', None)

    try:
        uuids = json.loads(uuids)
    except (TypeError, ValueError):
        # TypeError: the ``uuids`` argument was not supplied (None).
        # ValueError: the supplied value is not valid JSON.
        return make_response_json(
            "Failed to decode uuids as json. Given '%s'"
            % uuids
        ), 400

    if sid is None:
        return make_response_json("No session id (sid) provided"), 400
    if not uuids:
        return make_response_json(
            "No descriptor UUIDs provided",
            sid=sid,
        ), 400

    # NOTE(review): this ``try`` spans the whole body, so any KeyError
    # raised below that is not handled more locally is reported as an
    # unknown session id.
    try:
        with self.controller.get_session(sid) as iqrs:
            if not iqrs.positive_descriptors:
                return make_response_json(
                    "No positive labels in current session", sid=sid
                ), 400
            if not iqrs.negative_descriptors:
                return make_response_json(
                    "No negative labels in current session", sid=sid
                ), 400

            # Get descriptor elements for classification
            try:
                descriptors = list(self.descriptor_index
                                   .get_many_descriptors(uuids))
            except KeyError as ex:
                err_uuid = str(ex)
                self._log.warn(traceback.format_exc())
                return make_response_json(
                    "Descriptor UUID '%s' cannot be found in the "
                    "configured descriptor index."
                    % err_uuid,
                    sid=sid,
                    uuid=err_uuid,
                ), 404

            classifier = self.session_classifiers.get(sid, None)

            pos_label = "positive"
            neg_label = "negative"
            # (Re)train when the session state changed since the last
            # classifier was built, or when none has been built yet.
            if self.session_classifier_dirty[sid] or classifier is None:
                self._log.debug("Training new classifier for current "
                                "refine state")

                #: :type: SupervisedClassifier
                classifier = plugin.from_plugin_config(
                    self.classifier_config,
                    get_classifier_impls(sub_interface=SupervisedClassifier)
                )
                classifier.train(
                    {pos_label: iqrs.positive_descriptors,
                     neg_label: iqrs.negative_descriptors}
                )

                self.session_classifiers[sid] = classifier
                self.session_classifier_dirty[sid] = False

            classifications = classifier.classify_async(
                descriptors, self.classification_factory,
                use_multiprocessing=True, ri=1.0
            )

            # Format output to be parallel lists of UUIDs input and
            # positive class classification scores.
            o_uuids = []
            o_proba = []
            for d in descriptors:
                o_uuids.append(d.uuid())
                o_proba.append(classifications[d][pos_label])

            assert uuids == o_uuids, \
                "Output UUID list is not congruent with INPUT list."

            return make_response_json(
                "Finished classification",
                sid=sid,
                uuids=o_uuids,
                proba=o_proba,
            ), 200

    except KeyError:
        return make_response_json("session id '%s' not found" % sid,
                                  sid=sid), 404
def main():
    """
    Validate the configured classifier against labeled descriptors, or train
    it when the ``train`` utility option is set.

    Reads ``<UUID>,<label>`` rows from the configured CSV file, fetches the
    matching descriptors from the configured descriptor index and groups them
    by truth label. When training is requested the classifier is trained on
    that grouping and the process exits. Otherwise the descriptors are
    classified and, depending on which output paths are configured, a
    confusion matrix plot, a UUID confusion matrix JSON file and PR / ROC
    curve plots are produced.

    :raises ValueError: Training was requested but the configured classifier
        is not a ``SupervisedClassifier``.
    """
    description = """
    Utility for validating a given classifier implementation's model against
    some labeled testing data, outputting PR and ROC curve plots with
    area-under-curve score values.

    This utility can optionally be used train a supervised classifier model if
    the given classifier model configuration does not exist and a second CSV
    file listing labeled training data is provided. Training will be attempted
    if ``train`` is set to true. If training is performed, we exit after
    training completes. A ``SupervisedClassifier`` sub-classing implementation
    must be configured

    We expect the test and train CSV files in the column format:

        ...
        <UUID>,<label>
        ...

    The UUID is of the descriptor to which the label applies. The label may be
    any arbitrary string value, but all labels must be consistent in
    application.

    Some metrics presented assume the highest confidence class as the single
    predicted class for an element:

        - confusion matrix

    The output UUID confusion matrix is a JSON dictionary where the top-level
    keys are the true labels, and the inner dictionary is the mapping of
    predicted labels to the UUIDs of the classifications/descriptors that
    yielded the prediction. Again, this is based on the maximum probability
    label for a classification result (T=0.5).
    """
    args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(config['plugins']['classifier'],
                                           get_classifier_impls())
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'], get_descriptor_index_impls())

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of label to the DescriptorElement instances for that
    # described by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate through (UUID, label) pairs in the specified CSV file """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch the descriptor for a (UUID, label) pair from the index """
        uuid, truth_label = r
        return truth_label, descriptor_index.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr, iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train the classifier if training was enabled and the configured
    # classifier supports it.
    if do_train:
        if isinstance(classifier, SupervisedClassifier):
            log.info("Training classifier model")
            classifier.train(tlabel2descriptors)
            exit(0)
        else:
            # The exception used to be constructed without being raised,
            # which let execution fall through into validation.
            raise ValueError("Configured classifier is not a "
                             "SupervisedClassifier type and does not support "
                             "training.")

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    for tlabel, descriptors in tlabel2descriptors.iteritems():
        tlabel2classifications[tlabel] = \
            set(classifier.classify_async(
                descriptors, classification_factory,
                use_multiprocessing=True,
                procs=config['parallelism']['classification_cores'],
                ri=1.0,
            ).values())
    log.info("Truth label counts:")
    for l in sorted(tlabel2classifications):
        log.info("  %s :: %d", l, len(tlabel2classifications[l]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # CM of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are UUID
        # predicted labels.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[str, set | list]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists; sets are not JSON serializable
            for plabel in uuid_cm[tlabel]:
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr,
                       plot_ci, plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc,
                        plot_ci, plot_ci_alpha)