def get_default_config(cls):
    """
    Return the default JSON-compliant configuration for this service.

    Extends the parent application's default configuration with the
    sections this service reads at construction time: the static
    classifier collection, classification/descriptor element factories,
    descriptor generator, optional descriptor set and the from-IQR-state
    supervised classifier configuration.
    """
    cfg = super(SmqtkClassifierService, cls).get_default_config()
    cfg.update({
        cls.CONFIG_ENABLE_CLASSIFIER_REMOVAL: False,
        # Static classifier configurations.
        cls.CONFIG_CLASSIFIER_COLLECTION:
            ClassifierCollection.get_default_config(),
        # Classification element factory for new classification results.
        cls.CONFIG_CLASSIFICATION_FACTORY:
            ClassificationElementFactory.get_default_config(),
        # Descriptor generator for new content.
        cls.CONFIG_DESCRIPTOR_GENERATOR:
            make_default_config(DescriptorGenerator.get_impls()),
        # Descriptor factory for new content descriptors.
        cls.CONFIG_DESCRIPTOR_FACTORY:
            DescriptorElementFactory.get_default_config(),
        # Optional descriptor set for "included" descriptors referenceable
        # by UID.
        cls.CONFIG_DESCRIPTOR_SET:
            make_default_config(DescriptorSet.get_impls()),
        # from-IQR-state *supervised* classifier configuration.
        cls.CONFIG_IQR_CLASSIFIER:
            make_default_config(SupervisedClassifier.get_impls()),
        cls.CONFIG_IMMUTABLE_LABELS: [],
    })
    return cfg
def get_default_config(cls):
    """
    Return the default JSON-compliant configuration for this service
    (legacy ``smqtk.utils.plugin`` API variant).

    Extends the parent application's default configuration with the
    static classifier collection, element factories, descriptor generator
    and from-IQR-state supervised classifier sections.
    """
    cfg = super(SmqtkClassifierService, cls).get_default_config()
    cfg.update({
        cls.CONFIG_ENABLE_CLASSIFIER_REMOVAL: False,
        # Static classifier configurations.
        cls.CONFIG_CLASSIFIER_COLLECTION:
            ClassifierCollection.get_default_config(),
        # Classification element factory for new classification results.
        cls.CONFIG_CLASSIFICATION_FACTORY:
            ClassificationElementFactory.get_default_config(),
        # Descriptor generator for new content.
        cls.CONFIG_DESCRIPTOR_GENERATOR:
            smqtk.utils.plugin.make_config(
                get_descriptor_generator_impls()
            ),
        # Descriptor factory for new content descriptors.
        cls.CONFIG_DESCRIPTOR_FACTORY:
            DescriptorElementFactory.get_default_config(),
        # from-IQR-state *supervised* classifier configuration.
        cls.CONFIG_IQR_CLASSIFIER:
            smqtk.utils.plugin.make_config(
                get_classifier_impls(sub_interface=SupervisedClassifier)
            ),
        cls.CONFIG_IMMUTABLE_LABELS: [],
    })
    return cfg
def default_config():
    """
    Return the default configuration dictionary for the classifier model
    validation utility.

    Sections:
        - ``plugins``: classifier, classification factory and descriptor
          index plugin configurations.
        - ``utility``: run options (training toggle, input CSV path and
          output plot/JSON destinations).
        - ``parallelism``: core counts for descriptor fetching and
          classification.
    """
    return {
        'plugins': {
            'classifier':
                plugin.make_config(get_classifier_impls()),
            'classification_factory':
                ClassificationElementFactory.get_default_config(),
            'descriptor_index':
                plugin.make_config(get_descriptor_index_impls())
        },
        'utility': {
            'train': False,
            # Placeholder value the user is expected to replace.
            # FIX: corrected "CHAMGEME" typo to "CHANGEME".
            'csv_filepath': 'CHANGEME :: PATH :: a csv file',
            'output_plot_pr': None,
            'output_plot_roc': None,
            'output_plot_confusion_matrix': None,
            'output_uuid_confusion_matrix': None,
            'curve_confidence_interval': False,
            'curve_confidence_interval_alpha': 0.4,
        },
        "parallelism": {
            "descriptor_fetch_cores": 4,
            # None means "let the implementation choose".
            "classification_cores": None,
        },
    }
def test_classify_elements_none_preexisting(self):
    """
    Test generating classification elements where none generated by the
    factory have existing vectors. i.e. all descriptor elements passed to
    underlying classification method.
    """
    vectors = ([1, 2, 3], [4, 5, 6], [7, 8, 9])
    descriptors = []
    for idx, vec in enumerate(vectors):
        descriptors.append(DescriptorMemoryElement('', idx).set_vector(vec))

    # Factory producing mocked elements; the shared mock instance reports
    # that no classifications are set yet.
    mock_elem_type = mock.MagicMock(name="MockedClassificationElementType")
    factory = ClassificationElementFactory(mock_elem_type, {})
    mock_elem = mock_elem_type.from_config()
    mock_elem.has_classifications.return_value = False

    list(self.inst.classify_elements(descriptors, factory=factory,
                                     overwrite=False))

    assert mock_elem.has_classifications.call_count == 3
    assert mock_elem.set_classification.call_count == 3
    # Expected classification returns from the dummy generator were set on
    # the factory-created elements.
    for expected in ({'test': 1}, {'test': 4}, {'test': 7}):
        mock_elem.set_classification.assert_any_call(expected)
    # Dummy classifier iterator completed to the end.
    self.inst._post_iterator_check.assert_called_once()
def test_classify_elements_all_preexisting(self):
    """
    Test generating classification elements where all elements generated
    by the factory claim to already have classifications and overwrite is
    False.
    """
    vectors = ([1, 2, 3], [4, 5, 6], [7, 8, 9])
    descriptors = []
    for idx, vec in enumerate(vectors):
        descriptors.append(DescriptorMemoryElement('', idx).set_vector(vec))

    # Factory producing mocked elements; the shared mock instance reports
    # that classifications are already present.
    mock_elem_type = mock.MagicMock(name="MockedClassificationElementType")
    factory = ClassificationElementFactory(mock_elem_type, {})
    mock_elem = mock_elem_type.from_config()
    mock_elem.has_classifications.return_value = True

    list(self.inst.classify_elements(descriptors, factory=factory,
                                     overwrite=False))

    # Each element was checked, and nothing was overwritten.
    assert mock_elem.has_classifications.call_count == 3
    mock_elem.set_classification.assert_not_called()
    # Dummy classifier iterator completed to the end.
    self.inst._post_iterator_check.assert_called_once()
def test_classify_elements_all_preexisting_overwrite(self):
    """
    Test generating classification elements where all elements generated
    by the factory claim to already have classifications but overwrite is
    True this time.
    """
    d_elems = [
        DescriptorMemoryElement('', i).set_vector(v)
        for i, v in enumerate([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    ]
    # Mock a factory to produce elements whose ``has_classifications``
    # method returns True (pre-existing classifications).
    m_ce_type = mock.MagicMock(name="MockedClassificationElementType")
    c_factory = ClassificationElementFactory(m_ce_type, {})
    # All factory-produced elements share this one mock instance.
    m_ce_inst = m_ce_type.from_config()
    m_ce_inst.has_classifications.return_value = True
    list(
        self.inst.classify_elements(d_elems, factory=c_factory,
                                    overwrite=True))
    # Method not called because of overwrite short-circuit
    assert m_ce_inst.has_classifications.call_count == 0
    assert m_ce_inst.set_classification.call_count == 3
    # Check that expected classification returns from dummy generator were
    # set to factory-created elements.
    m_ce_inst.set_classification.assert_any_call({'test': 1})
    m_ce_inst.set_classification.assert_any_call({'test': 4})
    m_ce_inst.set_classification.assert_any_call({'test': 7})
    # Dummy classifier iterator completed to the end.
    self.inst._post_iterator_check.assert_called_once()
def default_config():
    """
    Return the default configuration dictionary for the classifier k-fold
    cross-validation utility: plugin configurations, cross-validation
    parameters and PR/ROC curve output options.
    """
    plugin_section = {
        "supervised_classifier":
            plugin.make_config(get_supervised_classifier_impls()),
        "descriptor_index":
            plugin.make_config(get_descriptor_index_impls()),
        "classification_factory":
            ClassificationElementFactory.get_default_config(),
    }
    cross_validation_section = {
        "truth_labels": None,
        "num_folds": 6,
        "random_seed": None,
        "classification_use_multiprocessing": True,
    }
    # PR and ROC curve sections share the same default shape.
    curve_defaults = {
        "enabled": True,
        "show": False,
        "output_directory": None,
        "file_prefix": None,
    }
    return {
        "plugins": plugin_section,
        "cross_validation": cross_validation_section,
        "pr_curves": dict(curve_defaults),
        "roc_curves": dict(curve_defaults),
    }
def get_default_config(cls):
    """
    Return the default JSON-compliant configuration for the IQR service.

    Extends the parent application's default configuration with an
    ``iqr_service`` section containing plugin configurations (relevancy
    index, descriptor index, neighbor index, classifier and
    classification factory) plus human-readable notes for each plugin
    slot.
    """
    c = super(IqrService, cls).get_default_config()

    # Relevancy index default config, merged with the IQR session's
    # default relevancy-index configuration.
    c_rel_index = plugin.make_config(
        get_relevancy_index_impls()
    )
    merge_dict(c_rel_index, iqr_session.DFLT_REL_INDEX_CONFIG)

    merge_dict(c, {
        "iqr_service": {
            "positive_seed_neighbors": 500,

            # Descriptive notes rendered alongside the plugin config keys.
            "plugin_notes": {
                "relevancy_index_config":
                    "The relevancy index config provided should not have "
                    "persistent storage configured as it will be used in "
                    "such a way that instances are created, built and "
                    "destroyed often.",
                "descriptor_index":
                    "This is the index from which given positive and "
                    "negative example descriptors are retrieved from. "
                    "Not used for nearest neighbor querying. "
                    "This index must contain all descriptors that could "
                    "possibly be used as positive/negative examples and "
                    "updated accordingly.",
                "neighbor_index":
                    "This is the neighbor index to pull initial near-"
                    "positive descriptors from.",
                "classifier_config":
                    "The configuration to use for training and using "
                    "classifiers for the /classifier endpoint. "
                    "When configuring a classifier for use, don't fill "
                    "out model persistence values as many classifiers "
                    "may be created and thrown away during this service's "
                    "operation.",
                "classification_factory":
                    "Selection of the backend in which classifications "
                    "are stored. The in-memory version is recommended "
                    "because normal caching mechanisms will not account "
                    "for the variety of classifiers that can potentially "
                    "be created via this utility.",
            },
            "plugins": {
                "relevancy_index_config": c_rel_index,
                "descriptor_index": plugin.make_config(
                    get_descriptor_index_impls()
                ),
                "neighbor_index":
                    plugin.make_config(get_nn_index_impls()),
                "classifier_config":
                    plugin.make_config(get_classifier_impls()),
                "classification_factory":
                    ClassificationElementFactory.get_default_config(),
            }
        }
    })
    return c
def get_default_config():
    """
    Return the default configuration dictionary: descriptor factory and
    generator plus classification factory and classifier plugin
    configurations.
    """
    return {
        "descriptor_factory":
            DescriptorElementFactory.get_default_config(),
        # FIX: the impl-getter functions must be *called* so that
        # ``plugin.make_config`` receives the implementation map rather
        # than the function object itself (previously passed uncalled).
        "descriptor_generator":
            plugin.make_config(get_descriptor_generator_impls()),
        "classification_factory":
            ClassificationElementFactory.get_default_config(),
        "classifier":
            plugin.make_config(get_classifier_impls()),
    }
def get_default_config(cls):
    """
    Return the default JSON-compliant configuration for the IQR service.

    Extends the parent application's default configuration with an
    ``iqr_service`` section containing plugin configurations (relevancy
    index, descriptor index, neighbor index, classifier and
    classification factory) plus human-readable notes for each plugin
    slot.
    """
    c = super(IqrService, cls).get_default_config()

    # Relevancy index default config, merged with the IQR session's
    # default relevancy-index configuration.
    c_rel_index = plugin.make_config(get_relevancy_index_impls())
    merge_dict(c_rel_index, iqr_session.DFLT_REL_INDEX_CONFIG)

    merge_dict(
        c,
        {
            "iqr_service": {
                "positive_seed_neighbors": 500,

                # Descriptive notes rendered alongside the plugin config
                # keys.
                "plugin_notes": {
                    "relevancy_index_config":
                        "The relevancy index config provided should not have "
                        "persistent storage configured as it will be used in "
                        "such a way that instances are created, built and "
                        "destroyed often.",
                    "descriptor_index":
                        "This is the index from which given positive and "
                        "negative example descriptors are retrieved from. "
                        "Not used for nearest neighbor querying. "
                        "This index must contain all descriptors that could "
                        "possibly be used as positive/negative examples and "
                        "updated accordingly.",
                    "neighbor_index":
                        "This is the neighbor index to pull initial near-"
                        "positive descriptors from.",
                    "classifier_config":
                        "The configuration to use for training and using "
                        "classifiers for the /classifier endpoint. "
                        "When configuring a classifier for use, don't fill "
                        "out model persistence values as many classifiers "
                        "may be created and thrown away during this service's "
                        "operation.",
                    "classification_factory":
                        "Selection of the backend in which classifications "
                        "are stored. The in-memory version is recommended "
                        "because normal caching mechanisms will not account "
                        "for the variety of classifiers that can potentially "
                        "be created via this utility.",
                },
                "plugins": {
                    "relevancy_index_config": c_rel_index,
                    "descriptor_index":
                        plugin.make_config(get_descriptor_index_impls()),
                    "neighbor_index":
                        plugin.make_config(get_nn_index_impls()),
                    "classifier_config":
                        plugin.make_config(get_classifier_impls()),
                    "classification_factory":
                        ClassificationElementFactory.get_default_config(),
                }
            }
        })
    return c
def get_default_config():
    """
    Return the default configuration dictionary: descriptor factory and
    generator plus classification factory and classifier plugin
    configurations (new-style configuration API).
    """
    cfg = {
        "descriptor_factory":
            DescriptorElementFactory.get_default_config(),
        "classification_factory":
            ClassificationElementFactory.get_default_config(),
    }
    cfg["descriptor_generator"] = make_default_config(
        DescriptorGenerator.get_impls()
    )
    cfg["classifier"] = make_default_config(Classifier.get_impls())
    return cfg
def get_default_config():
    """
    Return the default configuration dictionary: descriptor factory and
    generator plus classification factory and classifier plugin
    configurations (legacy ``plugin`` module API).
    """
    cfg = {
        "descriptor_factory":
            DescriptorElementFactory.get_default_config(),
        "classification_factory":
            ClassificationElementFactory.get_default_config(),
    }
    cfg["descriptor_generator"] = plugin.make_config(
        get_descriptor_generator_impls()
    )
    cfg["classifier"] = plugin.make_config(get_classifier_impls())
    return cfg
def __init__(self, json_config):
    """
    Initialize the classifier service from a JSON-compliant configuration
    dictionary, instantiating all configured SMQTK plugins.

    :param dict json_config: Configuration dictionary, expected to follow
        the structure produced by ``get_default_config``.
    """
    super(SmqtkClassifierService, self).__init__(json_config)

    self.enable_classifier_removal = \
        bool(json_config[self.CONFIG_ENABLE_CLASSIFIER_REMOVAL])

    # Labels that may not be removed/replaced via the service API.
    self.immutable_labels = set(
        json_config[self.CONFIG_IMMUTABLE_LABELS]
    )

    # Convert configuration into SMQTK plugin instances.
    #   - Static classifier configurations.
    #       - Skip the example config key
    #   - Classification element factory
    #   - Descriptor generator
    #   - Descriptor element factory
    #   - from-IQR-state classifier configuration
    #       - There must at least be the default key defined for when no
    #         specific classifier type is specified at state POST.

    # Classifier collection + factory
    self.classification_factory = \
        ClassificationElementFactory.from_config(
            json_config[self.CONFIG_CLASSIFICATION_FACTORY]
        )
    #: :type: ClassifierCollection
    self.classifier_collection = ClassifierCollection.from_config(
        json_config[self.CONFIG_CLASSIFIER_COLLECTION]
    )

    # Descriptor generator + factory
    self.descriptor_factory = DescriptorElementFactory.from_config(
        json_config[self.CONFIG_DESCRIPTOR_FACTORY]
    )
    #: :type: smqtk.algorithms.DescriptorGenerator
    self.descriptor_gen = from_config_dict(
        json_config[self.CONFIG_DESCRIPTOR_GENERATOR],
        smqtk.algorithms.DescriptorGenerator.get_impls()
    )

    # Descriptor set bundled for classification-by-UID.  Configuration for
    # this key is optional: fall back to an in-memory set when absent or
    # invalid (``from_config_dict`` raises ValueError in that case).
    try:
        self.descriptor_set = from_config_dict(
            json_config.get(self.CONFIG_DESCRIPTOR_SET, {}),
            DescriptorSet.get_impls()
        )
    except ValueError:
        # Default empty set.
        self.descriptor_set = MemoryDescriptorSet()

    # Classifier config for uploaded IQR states.  Kept as a raw dict;
    # instances are constructed per uploaded state.
    self.iqr_state_classifier_config = \
        json_config[self.CONFIG_IQR_CLASSIFIER]

    self.add_routes()
def __init__(self, json_config):
    """
    Initialize the IQR service from a JSON-compliant configuration
    dictionary, instantiating configured plugins and the session
    controller with session-expiration support.

    :param dict json_config: Configuration dictionary; plugin
        configurations live under ``json_config['iqr_service']['plugins']``
        and session-control options under
        ``json_config['iqr_service']['session_control']``.
    """
    super(IqrService, self).__init__(json_config)

    sc_config = json_config['iqr_service']['session_control']

    # Initialize from config
    self.positive_seed_neighbors = sc_config['positive_seed_neighbors']

    # Raw classifier configuration; instances are created per session.
    self.classifier_config = \
        json_config['iqr_service']['plugins']['classifier_config']
    self.classification_factory = \
        ClassificationElementFactory.from_config(
            json_config['iqr_service']['plugins']['classification_factory']
        )

    #: :type: smqtk.representation.DescriptorIndex
    self.descriptor_index = plugin.from_plugin_config(
        json_config['iqr_service']['plugins']['descriptor_index'],
        get_descriptor_index_impls(),
    )
    #: :type: smqtk.algorithms.NearestNeighborsIndex
    self.neighbor_index = plugin.from_plugin_config(
        json_config['iqr_service']['plugins']['neighbor_index'],
        get_nn_index_impls(),
    )

    # Raw relevancy-index configuration; instances are created, built and
    # destroyed per refinement.
    self.rel_index_config = \
        json_config['iqr_service']['plugins']['relevancy_index_config']

    # Record of trained classifiers for a session. Session classifier
    # modifications locked under the parent session's global lock.
    #: :type: dict[collections.Hashable, smqtk.algorithms.SupervisedClassifier | None]
    self.session_classifiers = {}
    # Control for knowing when a new classifier should be trained for a
    # session (True == train new classifier). Modification for specific
    # sessions under parent session's lock.
    #: :type: dict[collections.Hashable, bool]
    self.session_classifier_dirty = {}

    def session_expire_callback(session):
        """
        Drop per-session classifier state when a session expires.

        :type session: smqtk.iqr.IqrSession
        """
        with session:
            self._log.debug("Removing session %s classifier", session.uuid)
            del self.session_classifiers[session.uuid]
            del self.session_classifier_dirty[session.uuid]

    self.controller = iqr_controller.IqrController(
        sc_config['session_expiration']['enabled'],
        sc_config['session_expiration']['check_interval_seconds'],
        session_expire_callback
    )
    self.session_timeout = \
        sc_config['session_expiration']['session_timeout']
def __init__(self, json_config):
    """
    Initialize the classifier service from a JSON-compliant configuration
    dictionary, instantiating all configured SMQTK plugins (legacy
    ``smqtk.utils.plugin`` API variant).

    :param dict json_config: Configuration dictionary, expected to follow
        the structure produced by ``get_default_config``.
    """
    super(SmqtkClassifierService, self).__init__(json_config)

    self.enable_classifier_removal = \
        bool(json_config[self.CONFIG_ENABLE_CLASSIFIER_REMOVAL])

    # Labels that may not be removed/replaced via the service API.
    self.immutable_labels = set(
        json_config[self.CONFIG_IMMUTABLE_LABELS]
    )

    # Convert configuration into SMQTK plugin instances.
    #   - Static classifier configurations.
    #       - Skip the example config key
    #   - Classification element factory
    #   - Descriptor generator
    #   - Descriptor element factory
    #   - from-IQR-state classifier configuration
    #       - There must at least be the default key defined for when no
    #         specific classifier type is specified at state POST.

    # Classifier collection + factory
    self.classification_factory = \
        ClassificationElementFactory.from_config(
            json_config[self.CONFIG_CLASSIFICATION_FACTORY]
        )
    self.classifier_collection = ClassifierCollection.from_config(
        json_config[self.CONFIG_CLASSIFIER_COLLECTION]
    )

    # Descriptor generator + factory
    self.descriptor_factory = DescriptorElementFactory.from_config(
        json_config[self.CONFIG_DESCRIPTOR_FACTORY]
    )
    #: :type: smqtk.algorithms.DescriptorGenerator
    self.descriptor_gen = smqtk.utils.plugin.from_plugin_config(
        json_config[self.CONFIG_DESCRIPTOR_GENERATOR],
        smqtk.algorithms.get_descriptor_generator_impls()
    )

    # Classifier config for uploaded IQR states.  Kept as a raw dict;
    # instances are constructed per uploaded state.
    self.iqr_state_classifier_config = \
        json_config[self.CONFIG_IQR_CLASSIFIER]

    self.add_routes()
def __init__(self, json_config): super(IqrService, self).__init__(json_config) # Initialize from config self.positive_seed_neighbors = \ json_config['iqr_service']['positive_seed_neighbors'] self.classifier_config = \ json_config['iqr_service']['plugins']['classifier_config'] self.classification_factory = \ ClassificationElementFactory.from_config( json_config['iqr_service']['plugins']['classification_factory'] ) #: :type: smqtk.representation.DescriptorIndex self.descriptor_index = plugin.from_plugin_config( json_config['iqr_service']['plugins']['descriptor_index'], get_descriptor_index_impls(), ) #: :type: smqtk.algorithms.NearestNeighborsIndex self.neighbor_index = plugin.from_plugin_config( json_config['iqr_service']['plugins']['neighbor_index'], get_nn_index_impls(), ) self.rel_index_config = \ json_config['iqr_service']['plugins']['relevancy_index_config'] self.controller = iqr_controller.IqrController() # Record of trained classifiers for a session. Session classifier # modifications locked under the parent session's global lock. #: :type: dict[collections.Hashable, smqtk.algorithms.SupervisedClassifier | None] self.session_classifiers = {} # Control for knowing when a new classifier should be trained for a # session (True == train new classifier). Modification for specific # sessions under parent session's lock. #: :type: dict[collections.Hashable, bool] self.session_classifier_dirty = {}
from smqtk.representation import ClassificationElementFactory
from smqtk.representation.classification_element.memory import \
    MemoryClassificationElement

# Default classification element factory for interfaces: produces
# in-memory classification elements, so no persistent-storage
# configuration is required by default.
DFLT_CLASSIFIER_FACTORY = ClassificationElementFactory(
    MemoryClassificationElement, {}
)
def main():
    """
    CLI entry point: validate a configured classifier model against labeled
    testing data, optionally training it first, and emit PR/ROC/confusion
    matrix outputs.
    """
    description = """
    Utility for validating a given classifier implementation's model against
    some labeled testing data, outputting PR and ROC curve plots with
    area-under-curve score values.

    This utility can optionally be used train a supervised classifier model
    if the given classifier model configuration does not exist and a second
    CSV file listing labeled training data is provided. Training will be
    attempted if ``train`` is set to true. If training is performed, we exit
    after training completes. A ``SupervisedClassifier`` sub-classing
    implementation must be configured

    We expect the test and train CSV files in the column format:

        ...
        <UUID>,<label>
        ...

    The UUID is of the descriptor to which the label applies. The label may
    be any arbitrary string value, but all labels must be consistent in
    application.

    Some metrics presented assume the highest confidence class as the single
    predicted class for an element:

        - confusion matrix

    The output UUID confusion matrix is a JSON dictionary where the top-level
    keys are the true labels, and the inner dictionary is the mapping of
    predicted labels to the UUIDs of the classifications/descriptors that
    yielded the prediction. Again, this is based on the maximum probability
    label for a classification result (T=0.5).
    """
    args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        config['plugins']['classifier'],
        get_classifier_impls()
    )
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory']
    )
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of label to the DescriptorElement instances for that
    # described by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate through UUIDs in specified file """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple
                #       labels per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch descriptors from configured index """
        uuid, truth_label = r
        return truth_label, descriptor_index.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr, iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train classifier if the one given has a ``train`` method and training
    # was turned enabled.
    if do_train:
        if isinstance(classifier, SupervisedClassifier):
            log.info("Training classifier model")
            classifier.train(tlabel2descriptors)
            exit(0)
        else:
            # FIX: the ValueError was previously constructed but never
            # raised, so misconfiguration was silently ignored.
            raise ValueError(
                "Configured classifier is not a SupervisedClassifier "
                "type and does not support training."
            )

    #
    # Apply classifier to descriptors for predictions
    #
    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    # FIX: ``dict.iteritems`` is Python-2-only; use ``items`` (works on
    # both Python 2 and 3).
    for tlabel, descriptors in tlabel2descriptors.items():
        tlabel2classifications[tlabel] = \
            set(classifier.classify_async(
                descriptors, classification_factory,
                use_multiprocessing=True,
                procs=config['parallelism']['classification_cores'],
                ri=1.0,
            ).values())
    log.info("Truth label counts:")
    for l in sorted(tlabel2classifications):
        log.info(" %s :: %d", l, len(tlabel2classifications[l]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # CM of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are
        # UUID predicted labels.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[str, set | list]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists (sets are not JSON-serializable)
            for plabel in uuid_cm[tlabel]:
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr,
                       plot_ci, plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc,
                        plot_ci, plot_ci_alpha)
def classifier_kfold_validation():
    """
    CLI entry point: stratified k-fold cross-validation of a supervised
    classifier configuration, producing per-fold truth/probability data
    and optional PR/ROC curve outputs.
    """
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Load configurations / Setup data
    #
    pr_enabled = config['pr_curves']['enabled']
    pr_output_dir = config['pr_curves']['output_directory']
    pr_file_prefix = config['pr_curves']['file_prefix'] or ''
    pr_show = config['pr_curves']['show']

    roc_enabled = config['roc_curves']['enabled']
    roc_output_dir = config['roc_curves']['output_directory']
    roc_file_prefix = config['roc_curves']['file_prefix'] or ''
    roc_show = config['roc_curves']['show']

    log.info("Initializing DescriptorSet (%s)",
             config['plugins']['descriptor_set']['type'])
    #: :type: smqtk.representation.DescriptorSet
    descriptor_set = from_config_dict(config['plugins']['descriptor_set'],
                                      DescriptorSet.get_impls())

    log.info("Loading classifier configuration")
    #: :type: dict
    classifier_config = config['plugins']['supervised_classifier']

    # Always use in-memory ClassificationElement since we are retraining
    # the classifier and don't want possible element caching
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory(
        MemoryClassificationElement, {})

    log.info("Loading truth data")
    # Parallel lists: descriptor UUID and its truth label, read from the
    # configured CSV file.
    #: :type: list[str]
    uuids = []
    #: :type: list[str]
    truth_labels = []
    with open(config['cross_validation']['truth_labels']) as f:
        f_csv = csv.reader(f)
        for row in f_csv:
            uuids.append(row[0])
            truth_labels.append(row[1])
    #: :type: numpy.ndarray[str]
    uuids = numpy.array(uuids)
    #: :type: numpy.ndarray[str]
    truth_labels = numpy.array(truth_labels)

    #
    # Cross validation
    #
    kfolds = sklearn.model_selection.StratifiedKFold(
        n_splits=config['cross_validation']['num_folds'],
        shuffle=True,
        random_state=config['cross_validation']['random_seed'],
    ).split(numpy.zeros(len(truth_labels)), truth_labels)

    """
    Truth and classification probability results for test data per fold.
    Format:
        {
            0: {
                '<label>': {
                    "truth": [...],   # Parallel truth and classification
                    "proba": [...],   # probability values
                },
                ...
            },
            ...
        }
    """
    fold_data: Dict[int, Any] = {}

    i = 0
    for train, test in kfolds:
        log.info("Fold %d", i)
        log.info("-- %d training examples", len(train))
        log.info("-- %d test examples", len(test))
        fold_data[i] = {}

        log.info("-- creating classifier")
        # A fresh classifier instance per fold — must not persist a model.
        classifier = cast(
            SupervisedClassifier,
            from_config_dict(classifier_config,
                             SupervisedClassifier.get_impls()))

        log.info("-- gathering descriptors")
        pos_map: Dict[str, List[DescriptorElement]] = {}
        for idx in train:
            if truth_labels[idx] not in pos_map:
                pos_map[truth_labels[idx]] = []
            pos_map[truth_labels[idx]].append(
                descriptor_set.get_descriptor(uuids[idx]))

        log.info("-- Training classifier")
        classifier.train(pos_map)

        log.info("-- Classifying test set")
        c_iter = classifier.classify_elements(
            (descriptor_set.get_descriptor(uuids[idx]) for idx in test),
            classification_factory,
        )
        # Map classification-element UUID -> label/probability dict.
        uuid2c = dict((c.uuid, c.get_classification())
                      for c in c_iter)

        log.info("-- Pairing truth and computed probabilities")
        # Only considering positive labels
        for t_label in pos_map:
            fold_data[i][t_label] = {
                "truth": [L == t_label for L in truth_labels[test]],
                "proba": [uuid2c[uuid][t_label] for uuid in uuids[test]]
            }

        i += 1

    #
    # Curve generation
    #
    if pr_enabled:
        make_pr_curves(fold_data, pr_output_dir, pr_file_prefix, pr_show)
    if roc_enabled:
        make_roc_curves(fold_data, roc_output_dir, roc_file_prefix,
                        roc_show)
def classifier_kfold_validation():
    """
    CLI entry point (legacy variant): stratified k-fold cross-validation
    of a supervised classifier configuration, producing per-fold
    truth/probability data and optional PR/ROC curve outputs.

    NOTE(review): this variant uses Python-2-only ``dict.iteritems`` and
    the removed ``sklearn.cross_validation`` module — it will not run on
    Python 3 / modern scikit-learn; see the sibling implementation using
    ``sklearn.model_selection`` for the ported version.
    """
    description = """
    Helper utility for cross validating a supervised classifier
    configuration. The classifier used should NOT be configured to save its
    model since this process requires us to train the classifier multiple
    times.

    Configuration
    -------------
    - plugins
        - supervised_classifier
            Supervised Classifier implementation configuration to use. This
            should not be set to use a persistent model if able.
        - descriptor_index
            Index to draw descriptors to classify from.
    - cross_validation
        - truth_labels
            Path to a CSV file containing descriptor UUID the truth label
            associations. This defines what descriptors are used from the
            given index. We error if any descriptor UUIDs listed here are
            not available in the given descriptor index. This file should
            be in [uuid, label] column format.
        - num_folds
            Number of folds to make for cross validation.
        - random_seed
            Optional fixed seed for the
        - classification_use_multiprocessing
            If we should use multiprocessing (vs threading) when
            classifying elements.
    - pr_curves
        - enabled
            If Precision/Recall plots should be generated.
        - show
            If we should attempt to show the graph after it has been
            generated (matplotlib).
        - output_directory
            Directory to save generated plots to. If None, we will not save
            plots. Otherwise we will create the directory (and required
            parent directories) if it does not exist.
        - file_prefix
            String prefix to prepend to standard plot file names.
    - roc_curves
        - enabled
            If ROC curves should be generated
        - show
            If we should attempt to show the plot after it has been
            generated (matplotlib).
        - output_directory
            Directory to save generated plots to. If None, we will not save
            plots. Otherwise we will create the directory (and required
            parent directories) if it does not exist.
        - file_prefix
            String prefix to prepend to standard plot file names.
    """
    args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    #
    # Load configurations / Setup data
    #
    use_mp = config['cross_validation']['classification_use_multiprocessing']

    pr_enabled = config['pr_curves']['enabled']
    pr_output_dir = config['pr_curves']['output_directory']
    pr_file_prefix = config['pr_curves']['file_prefix'] or ''
    pr_show = config['pr_curves']['show']

    roc_enabled = config['roc_curves']['enabled']
    roc_output_dir = config['roc_curves']['output_directory']
    roc_file_prefix = config['roc_curves']['file_prefix'] or ''
    roc_show = config['roc_curves']['show']

    log.info("Initializing DescriptorIndex (%s)",
             config['plugins']['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls())

    log.info("Loading classifier configuration")
    #: :type: dict
    classifier_config = config['plugins']['supervised_classifier']

    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])

    log.info("Loading truth data")
    # Parallel lists: descriptor UUID and its truth label, read from the
    # configured CSV file.
    #: :type: list[str]
    uuids = []
    #: :type: list[str]
    truth_labels = []
    with open(config['cross_validation']['truth_labels']) as f:
        f_csv = csv.reader(f)
        for row in f_csv:
            uuids.append(row[0])
            truth_labels.append(row[1])
    #: :type: numpy.ndarray[str]
    uuids = numpy.array(uuids)
    #: :type: numpy.ndarray[str]
    truth_labels = numpy.array(truth_labels)

    #
    # Cross validation
    #
    # NOTE(review): ``sklearn.cross_validation`` was removed in
    # scikit-learn 0.20 in favor of ``sklearn.model_selection``.
    kfolds = sklearn.cross_validation.StratifiedKFold(
        truth_labels, config['cross_validation']['num_folds'],
        random_state=config['cross_validation']['random_seed'])

    """
    Truth and classification probability results for test data per fold.
    Format:
        {
            0: {
                '<label>': {
                    "truth": [...],   # Parallel truth and classification
                    "proba": [...],   # probability values
                },
                ...
            },
            ...
        }
    """
    fold_data = {}

    i = 0
    for train, test in kfolds:
        log.info("Fold %d", i)
        log.info("-- %d training examples", len(train))
        log.info("-- %d test examples", len(test))
        fold_data[i] = {}

        log.info("-- creating classifier")
        # A fresh classifier instance per fold — must not persist a model.
        #: :type: SupervisedClassifier
        classifier = plugin.from_plugin_config(
            classifier_config,
            get_supervised_classifier_impls())

        log.info("-- gathering descriptors")
        #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
        pos_map = {}
        for idx in train:
            if truth_labels[idx] not in pos_map:
                pos_map[truth_labels[idx]] = []
            pos_map[truth_labels[idx]].append(
                descriptor_index.get_descriptor(uuids[idx]))

        log.info("-- Training classifier")
        classifier.train(pos_map)

        log.info("-- Classifying test set")
        m = classifier.classify_async(
            (descriptor_index.get_descriptor(uuids[idx]) for idx in test),
            classification_factory,
            use_multiprocessing=use_mp, ri=1.0)
        # Map descriptor UUID -> label/probability dict.
        # NOTE(review): ``iteritems`` is Python-2-only.
        uuid2c = dict(
            (d.uuid(), c.get_classification())
            for d, c in m.iteritems())

        log.info("-- Pairing truth and computed probabilities")
        # Only considering positive labels
        for t_label in pos_map:
            fold_data[i][t_label] = {
                "truth": [l == t_label for l in truth_labels[test]],
                "proba": [uuid2c[uuid][t_label] for uuid in uuids[test]]
            }

        i += 1

    #
    # Curve generation
    #
    if pr_enabled:
        make_pr_curves(fold_data, pr_output_dir, pr_file_prefix, pr_show)
    if roc_enabled:
        make_roc_curves(fold_data, roc_output_dir, roc_file_prefix,
                        roc_show)
def main():
    """
    Validate a configured classifier against labeled descriptors.

    Optionally trains the classifier first (exiting after training), then
    classifies the labeled descriptors and emits a confusion matrix,
    optional UUID confusion matrix JSON, and optional PR/ROC curve plots.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(config['plugins']['classifier'],
                                           get_classifier_impls())
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls())

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of label to the DescriptorElement instances for that
    # described by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate through UUIDs in specified file """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch descriptors from configured index """
        uuid, truth_label = r
        return truth_label, descriptor_index.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr, iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train classifier if the one given has a ``train`` method and training
    # was turned enabled.
    if do_train:
        if isinstance(classifier, SupervisedClassifier):
            log.info("Training classifier model")
            classifier.train(tlabel2descriptors)
            exit(0)
        else:
            # BUG FIX: the ValueError was previously constructed but never
            # raised, so an unsupported classifier silently fell through to
            # the classification phase below.
            raise ValueError(
                "Configured classifier is not a SupervisedClassifier "
                "type and does not support training.")

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    for tlabel, descriptors in tlabel2descriptors.items():
        tlabel2classifications[tlabel] = \
            set(classifier.classify_async(
                descriptors, classification_factory,
                use_multiprocessing=True,
                procs=config['parallelism']['classification_cores'],
                ri=1.0,
            ).values())
    log.info("Truth label counts:")
    for l in sorted(tlabel2classifications):
        log.info("  %s :: %d", l, len(tlabel2classifications[l]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # CM of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are UUID
        # predicted labels.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[str, set | list]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists
            for plabel in uuid_cm[tlabel]:
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr,
                       plot_ci, plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc,
                        plot_ci, plot_ci_alpha)
def classifier_kfold_validation():
    """
    Cross-validate a supervised classifier configuration over a labeled
    descriptor set, optionally producing PR and ROC curve outputs per fold.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Load configurations / Setup data
    #
    use_mp = config['cross_validation']['classification_use_multiprocessing']

    pr_enabled = config['pr_curves']['enabled']
    pr_output_dir = config['pr_curves']['output_directory']
    # ``or ''`` normalizes a None/empty prefix to the empty string.
    pr_file_prefix = config['pr_curves']['file_prefix'] or ''
    pr_show = config['pr_curves']['show']

    roc_enabled = config['roc_curves']['enabled']
    roc_output_dir = config['roc_curves']['output_directory']
    roc_file_prefix = config['roc_curves']['file_prefix'] or ''
    roc_show = config['roc_curves']['show']

    log.info("Initializing DescriptorIndex (%s)",
             config['plugins']['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls())

    log.info("Loading classifier configuration")
    #: :type: dict
    classifier_config = config['plugins']['supervised_classifier']

    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])

    # Load [uuid, label] rows from the configured truth CSV file.
    log.info("Loading truth data")
    #: :type: list[str]
    uuids = []
    #: :type: list[str]
    truth_labels = []
    with open(config['cross_validation']['truth_labels']) as f:
        f_csv = csv.reader(f)
        for row in f_csv:
            uuids.append(row[0])
            truth_labels.append(row[1])
    #: :type: numpy.ndarray[str]
    uuids = numpy.array(uuids)
    #: :type: numpy.ndarray[str]
    truth_labels = numpy.array(truth_labels)

    #
    # Cross validation
    #
    # NOTE(review): ``sklearn.cross_validation`` is the pre-0.18 scikit-learn
    # module (replaced by ``sklearn.model_selection``) — confirm the pinned
    # sklearn version supports it.
    kfolds = sklearn.cross_validation.StratifiedKFold(
        truth_labels, config['cross_validation']['num_folds'],
        random_state=config['cross_validation']['random_seed'])

    """
    Truth and classification probability results for test data per fold.
    Format:
        {
            0: {
                '<label>':  {
                    "truth": [...],   # Parallel truth and classification
                    "proba": [...],   # probability values
                },
                ...
            },
            ...
        }
    """
    fold_data = {}

    i = 0
    for train, test in kfolds:
        log.info("Fold %d", i)
        log.info("-- %d training examples", len(train))
        log.info("-- %d test examples", len(test))
        fold_data[i] = {}

        # A fresh classifier instance is built per fold so no model state
        # carries over between folds.
        log.info("-- creating classifier")
        #: :type: SupervisedClassifier
        classifier = plugin.from_plugin_config(
            classifier_config,
            get_supervised_classifier_impls())

        # Group training descriptors by their truth label.
        log.info("-- gathering descriptors")
        #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
        pos_map = {}
        for idx in train:
            if truth_labels[idx] not in pos_map:
                pos_map[truth_labels[idx]] = []
            pos_map[truth_labels[idx]].append(
                descriptor_index.get_descriptor(uuids[idx]))

        log.info("-- Training classifier")
        classifier.train(pos_map)

        log.info("-- Classifying test set")
        m = classifier.classify_async(
            (descriptor_index.get_descriptor(uuids[idx]) for idx in test),
            classification_factory,
            use_multiprocessing=use_mp, ri=1.0)
        # NOTE(review): ``iteritems`` is Python-2-only; this script will not
        # run under Python 3 without switching to ``items``.
        uuid2c = dict((d.uuid(), c.get_classification())
                      for d, c in m.iteritems())

        log.info("-- Pairing truth and computed probabilities")
        # Only considering positive labels
        for t_label in pos_map:
            fold_data[i][t_label] = {
                "truth": [l == t_label for l in truth_labels[test]],
                "proba": [uuid2c[uuid][t_label] for uuid in uuids[test]]
            }

        i += 1

    #
    # Curve generation
    #
    if pr_enabled:
        make_pr_curves(fold_data, pr_output_dir, pr_file_prefix, pr_show)
    if roc_enabled:
        make_roc_curves(fold_data, roc_output_dir, roc_file_prefix, roc_show)
    def test_simple_classification(self):
        """
        simple LibSvmClassifier test - 2-class

        Test libSVM classification functionality using random constructed
        data, training the y=0.5 split
        """
        DIM = 2
        N = 1000
        POS_LABEL = 'positive'
        NEG_LABEL = 'negative'
        # Thread pool used only to build descriptor elements in parallel.
        p = multiprocessing.pool.ThreadPool()

        d_factory = DescriptorElementFactory(DescriptorMemoryElement, {})
        c_factory = ClassificationElementFactory(
            MemoryClassificationElement, {})

        def make_element(argtup):
            # argtup: (uuid-int, vector) pair from ``enumerate``.
            (i, v) = argtup
            d = d_factory.new_descriptor('test', i)
            d.set_vector(v)
            return d

        # Constructing artificial descriptors
        # Points with y <= 0.45 are "positive", y >= 0.55 "negative"; the
        # 0.45-0.55 band is excluded to leave a margin around the split.
        x = numpy.random.rand(N, DIM)
        x_pos = x[x[:, 1] <= 0.45]
        x_neg = x[x[:, 1] >= 0.55]

        d_pos = p.map(make_element, enumerate(x_pos))
        # Offset start index so negative UUIDs don't collide with positives.
        d_neg = p.map(make_element, enumerate(x_neg, start=N // 2))

        # Create/Train test classifier
        classifier = LibSvmClassifier(
            train_params={
                '-t': 0,  # linear kernel
                '-b': 1,  # enable probability estimates
                '-c': 2,  # SVM-C parameter C
                '-q': '',  # quite mode
            },
            normalize=None,  # DO NOT normalize descriptors
        )
        classifier.train({POS_LABEL: d_pos, NEG_LABEL: d_neg})

        # Test classifier
        # Fresh random draw for the test set, same margin construction.
        x = numpy.random.rand(N, DIM)
        x_pos = x[x[:, 1] <= 0.45]
        x_neg = x[x[:, 1] >= 0.55]

        d_pos = p.map(make_element, enumerate(x_pos, N))
        d_neg = p.map(make_element, enumerate(x_neg, N + N // 2))

        d_pos_sync = {}  # for comparing to async
        for d in d_pos:
            c = classifier.classify(d, c_factory)
            ntools.assert_equal(
                c.max_label(), POS_LABEL,
                "Found False positive: %s :: %s" %
                (d.vector(), c.get_classification()))
            d_pos_sync[d] = c

        d_neg_sync = {}
        for d in d_neg:
            c = classifier.classify(d, c_factory)
            ntools.assert_equal(
                c.max_label(), NEG_LABEL,
                "Found False negative: %s :: %s" %
                (d.vector(), c.get_classification()))
            d_neg_sync[d] = c

        # test that async classify produces the same results
        # -- d_pos
        m_pos = classifier.classify_async(d_pos, c_factory)
        ntools.assert_equal(
            m_pos, d_pos_sync,
            "Async computation of pos set did not yield "
            "the same results as synchronous "
            "classification.")
        # -- d_neg
        m_neg = classifier.classify_async(d_neg, c_factory)
        ntools.assert_equal(
            m_neg, d_neg_sync,
            "Async computation of neg set did not yield "
            "the same results as synchronous "
            "classification.")
        # -- combined -- threaded
        combined_truth = dict(d_pos_sync.items())
        combined_truth.update(d_neg_sync)
        m_combined = classifier.classify_async(
            d_pos + d_neg, c_factory,
            use_multiprocessing=False,
        )
        ntools.assert_equal(
            m_combined, combined_truth,
            "Async computation of all test descriptors "
            "did not yield the same results as "
            "synchronous classification.")
        # -- combined -- multiprocess
        m_combined = classifier.classify_async(
            d_pos + d_neg, c_factory,
            use_multiprocessing=True,
        )
        ntools.assert_equal(
            m_combined, combined_truth,
            "Async computation of all test descriptors "
            "(mixed order) did not yield the same results "
            "as synchronous classification.")

        # Closing resources
        p.close()
        p.join()
    def test_simple_multiclass_classification(self):
        """
        simple LibSvmClassifier test - 3-class

        Test libSVM classification functionality using random constructed
        data, training the y=0.33 and y=.66 split
        """
        DIM = 2
        N = 1000
        P1_LABEL = 'p1'
        P2_LABEL = 'p2'
        P3_LABEL = 'p3'
        # Thread pool used only to build descriptor elements in parallel.
        p = multiprocessing.pool.ThreadPool()
        d_factory = DescriptorElementFactory(DescriptorMemoryElement, {})
        c_factory = ClassificationElementFactory(
            MemoryClassificationElement, {})
        # Running descriptor UUID offset, advanced after each batch so UUIDs
        # stay unique across classes and across train/test sets.
        di = 0

        def make_element(argtup):
            # argtup: (uuid-int, vector) pair from ``enumerate``.
            (i, v) = argtup
            d = d_factory.new_descriptor('test', i)
            d.set_vector(v)
            return d

        # Constructing artificial descriptors
        # Three horizontal bands with gaps between them as class margins.
        x = numpy.random.rand(N, DIM)
        x_p1 = x[x[:, 1] <= 0.30]
        x_p2 = x[(x[:, 1] >= 0.36) & (x[:, 1] <= 0.63)]
        x_p3 = x[x[:, 1] >= 0.69]

        d_p1 = p.map(make_element, enumerate(x_p1, di))
        di += len(d_p1)
        d_p2 = p.map(make_element, enumerate(x_p2, di))
        di += len(d_p2)
        d_p3 = p.map(make_element, enumerate(x_p3, di))
        di += len(d_p3)

        # Create/Train test classifier
        classifier = LibSvmClassifier(
            train_params={
                '-t': 0,  # linear kernel
                '-b': 1,  # enable probability estimates
                '-c': 2,  # SVM-C parameter C
                '-q': ''  # quite mode
            },
            normalize=None,  # DO NOT normalize descriptors
        )
        classifier.train({P1_LABEL: d_p1, P2_LABEL: d_p2, P3_LABEL: d_p3})

        # Test classifier
        # Fresh random draw for the test set, same band construction.
        x = numpy.random.rand(N, DIM)
        x_p1 = x[x[:, 1] <= 0.30]
        x_p2 = x[(x[:, 1] >= 0.36) & (x[:, 1] <= 0.63)]
        x_p3 = x[x[:, 1] >= 0.69]

        d_p1 = p.map(make_element, enumerate(x_p1, di))
        di += len(d_p1)
        d_p2 = p.map(make_element, enumerate(x_p2, di))
        di += len(d_p2)
        d_p3 = p.map(make_element, enumerate(x_p3, di))
        di += len(d_p3)

        d_p1_sync = {}
        for d in d_p1:
            c = classifier.classify(d, c_factory)
            ntools.assert_equal(
                c.max_label(), P1_LABEL,
                "Incorrect %s label: %s :: %s" %
                (P1_LABEL, d.vector(), c.get_classification()))
            d_p1_sync[d] = c

        d_p2_sync = {}
        for d in d_p2:
            c = classifier.classify(d, c_factory)
            ntools.assert_equal(
                c.max_label(), P2_LABEL,
                "Incorrect %s label: %s :: %s" %
                (P2_LABEL, d.vector(), c.get_classification()))
            d_p2_sync[d] = c

        d_neg_sync = {}
        for d in d_p3:
            c = classifier.classify(d, c_factory)
            ntools.assert_equal(
                c.max_label(), P3_LABEL,
                "Incorrect %s label: %s :: %s" %
                (P3_LABEL, d.vector(), c.get_classification()))
            d_neg_sync[d] = c

        # test that async classify produces the same results
        # -- p1
        async_p1 = classifier.classify_async(d_p1, c_factory)
        ntools.assert_equal(
            async_p1, d_p1_sync,
            "Async computation of p1 set did not yield "
            "the same results as synchronous computation.")
        # -- p2
        async_p2 = classifier.classify_async(d_p2, c_factory)
        ntools.assert_equal(
            async_p2, d_p2_sync,
            "Async computation of p2 set did not yield "
            "the same results as synchronous computation.")
        # -- neg
        async_neg = classifier.classify_async(d_p3, c_factory)
        ntools.assert_equal(
            async_neg, d_neg_sync,
            "Async computation of neg set did not yield "
            "the same results as synchronous computation.")
        # -- combined -- threaded
        sync_combined = dict(d_p1_sync.items())
        sync_combined.update(d_p2_sync)
        sync_combined.update(d_neg_sync)
        async_combined = classifier.classify_async(
            d_p1 + d_p2 + d_p3, c_factory,
            use_multiprocessing=False)
        ntools.assert_equal(
            async_combined, sync_combined,
            "Async computation of all test descriptors "
            "did not yield the same results as "
            "synchronous classification.")
        # -- combined -- multiprocess
        async_combined = classifier.classify_async(
            d_p1 + d_p2 + d_p3, c_factory,
            use_multiprocessing=True)
        ntools.assert_equal(
            async_combined, sync_combined,
            "Async computation of all test descriptors "
            "(mixed order) did not yield the same results "
            "as synchronous classification.")

        # Closing resources
        p.close()
        p.join()
def main():
    """
    Classify descriptors referenced by a UUID list file, streaming the
    classification confidence values out to a CSV data file plus a
    one-row CSV header file.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    # - parallel_map UUIDs to load from the configured index
    # - classify iterated descriptors

    uuids_list_filepath = args.uuids_list
    output_csv_filepath = args.csv_data
    output_csv_header_filepath = args.csv_header
    classify_overwrite = config['utility']['classify_overwrite']

    p_use_multiprocessing = \
        config['utility']['parallel']['use_multiprocessing']
    p_index_extraction_cores = \
        config['utility']['parallel']['index_extraction_cores']
    p_classification_cores = \
        config['utility']['parallel']['classification_cores']

    # Fail fast on missing/invalid required paths before plugin init.
    if not uuids_list_filepath:
        raise ValueError("No uuids_list_filepath specified.")
    elif not os.path.isfile(uuids_list_filepath):
        raise ValueError("Given uuids_list_filepath did not point to a file.")
    if output_csv_header_filepath is None:
        raise ValueError("Need a path to save CSV header labels")
    if output_csv_filepath is None:
        raise ValueError("Need a path to save CSV data.")

    #
    # Initialize configured plugins
    #

    log.info("Initializing descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    log.info("Initializing classification factory")
    c_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory']
    )

    log.info("Initializing classifier")
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        config['plugins']['classifier'], get_classifier_impls()
    )

    #
    # Setup/Process
    #
    def iter_uuids():
        # Yield one stripped UUID string per line of the input file.
        with open(uuids_list_filepath) as f:
            for l in f:
                yield l.strip()

    def descr_for_uuid(uuid):
        """
        :type uuid: collections.Hashable
        :rtype: smqtk.representation.DescriptorElement
        """
        return descriptor_index.get_descriptor(uuid)

    def classify_descr(d):
        """
        :type d: smqtk.representation.DescriptorElement
        :rtype: smqtk.representation.ClassificationElement
        """
        return classifier.classify(d, c_factory, classify_overwrite)

    # Two chained lazy parallel stages: uuid -> descriptor -> classification.
    log.info("Initializing uuid-to-descriptor parallel map")
    #: :type: collections.Iterable[smqtk.representation.DescriptorElement]
    element_iter = parallel.parallel_map(
        descr_for_uuid, iter_uuids(),
        use_multiprocessing=p_use_multiprocessing,
        cores=p_index_extraction_cores,
        name="descr_for_uuid",
    )

    log.info("Initializing descriptor-to-classification parallel map")
    #: :type: collections.Iterable[smqtk.representation.ClassificationElement]
    classification_iter = parallel.parallel_map(
        classify_descr, element_iter,
        use_multiprocessing=p_use_multiprocessing,
        cores=p_classification_cores,
        name='classify_descr',
    )

    #
    # Write/Output files
    #

    c_labels = classifier.get_labels()

    def make_row(e):
        """
        :type e: smqtk.representation.ClassificationElement
        """
        c_m = e.get_classification()
        # Row layout matches the header file: uuid, then one confidence
        # column per classifier label (same label order as the header).
        return [e.uuid] + [c_m[l] for l in c_labels]

    # column labels file
    log.info("Writing CSV column header file: %s", output_csv_header_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_header_filepath))
    # NOTE(review): 'wb' mode with csv.writer is the Python 2 convention;
    # under Python 3 this would need mode 'w' with newline='' — confirm
    # target interpreter.
    with open(output_csv_header_filepath, 'wb') as f_csv:
        w = csv.writer(f_csv)
        w.writerow(['uuid'] + [str(cl) for cl in c_labels])

    # CSV file
    log.info("Writing CSV data file: %s", output_csv_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_filepath))
    # Progress-reporting state expected by bin_utils.report_progress.
    r_state = [0] * 7
    with open(output_csv_filepath, 'wb') as f_csv:
        w = csv.writer(f_csv)
        for c in classification_iter:
            w.writerow(make_row(c))
            bin_utils.report_progress(log.info, r_state, 1.0)

    # Final report
    r_state[1] -= 1
    bin_utils.report_progress(log.info, r_state, 0)

    log.info("Done")
def test_simple_multiclass_classification(self): """ Test libSVM classification functionality using random constructed data, training the y=0.33 and y=.66 split """ DIM = 2 N = 1000 P1_LABEL = 'p1' P2_LABEL = 'p2' p = multiprocessing.pool.ThreadPool() d_factory = DescriptorElementFactory(DescriptorMemoryElement, {}) c_factory = ClassificationElementFactory(MemoryClassificationElement, {}) di = 0 def make_element((i, v)): d = d_factory.new_descriptor('test', i) d.set_vector(v) return d # Constructing artificial descriptors x = numpy.random.rand(N, DIM) x_p1 = x[x[:, 1] <= 0.30] x_p2 = x[(x[:, 1] >= 0.36) & (x[:, 1] <= 0.63)] x_neg = x[x[:, 1] >= 0.69] d_p1 = p.map(make_element, enumerate(x_p1, di)) di += len(d_p1) d_p2 = p.map(make_element, enumerate(x_p2, di)) di += len(d_p2) d_neg = p.map(make_element, enumerate(x_neg, di)) di += len(d_neg) # Create/Train test classifier classifier = LibSvmClassifier( train_params={ '-t': 0, # linear kernel '-b': 1, # enable probability estimates '-c': 2, # SVM-C parameter C '-q': '' # quite mode }, normalize=None, # DO NOT normalize descriptors ) classifier.train({P1_LABEL: d_p1, P2_LABEL: d_p2}, d_neg) # Test classifier x = numpy.random.rand(N, DIM) x_p1 = x[x[:, 1] <= 0.30] x_p2 = x[(x[:, 1] >= 0.36) & (x[:, 1] <= 0.63)] x_neg = x[x[:, 1] >= 0.69] d_p1 = p.map(make_element, enumerate(x_p1, di)) di += len(d_p1) d_p2 = p.map(make_element, enumerate(x_p2, di)) di += len(d_p2) d_neg = p.map(make_element, enumerate(x_neg, di)) di += len(d_neg) for d in d_p1: c = classifier.classify(d, c_factory) ntools.assert_equal(c.max_label(), P1_LABEL, "Incorrect %s label: %s :: %s" % (P1_LABEL, d.vector(), c.get_classification())) for d in d_p2: c = classifier.classify(d, c_factory) ntools.assert_equal(c.max_label(), P2_LABEL, "Incorrect %s label: %s :: %s" % (P2_LABEL, d.vector(), c.get_classification())) for d in d_neg: c = classifier.classify(d, c_factory) ntools.assert_equal(c.max_label(), LibSvmClassifier.NEGATIVE_LABEL, "Incorrect %s 
label: %s :: %s" % (LibSvmClassifier.NEGATIVE_LABEL, d.vector(), c.get_classification())) # Closing resources p.close() p.join()
def test_simple_classification(self): """ Test libSVM classification functionality using random constructed data, training the y=0.5 split """ DIM = 2 N = 1000 POS_LABEL = 'positive' p = multiprocessing.pool.ThreadPool() d_factory = DescriptorElementFactory(DescriptorMemoryElement, {}) c_factory = ClassificationElementFactory(MemoryClassificationElement, {}) def make_element((i, v)): d = d_factory.new_descriptor('test', i) d.set_vector(v) return d # Constructing artificial descriptors x = numpy.random.rand(N, DIM) x_pos = x[x[:, 1] <= 0.45] x_neg = x[x[:, 1] >= 0.55] d_pos = p.map(make_element, enumerate(x_pos)) d_neg = p.map(make_element, enumerate(x_neg, start=N//2)) # Create/Train test classifier classifier = LibSvmClassifier( train_params={ '-t': 0, # linear kernel '-b': 1, # enable probability estimates '-c': 2, # SVM-C parameter C '-q': '', # quite mode }, normalize=None, # DO NOT normalize descriptors ) classifier.train({POS_LABEL: d_pos}, d_neg) # Test classifier x = numpy.random.rand(N, DIM) x_pos = x[x[:, 1] <= 0.45] x_neg = x[x[:, 1] >= 0.55] d_pos = p.map(make_element, enumerate(x_pos, N)) d_neg = p.map(make_element, enumerate(x_neg, N + N//2)) for d in d_pos: c = classifier.classify(d, c_factory) ntools.assert_equal(c.max_label(), POS_LABEL, "Found False positive: %s :: %s" % (d.vector(), c.get_classification())) for d in d_neg: c = classifier.classify(d, c_factory) ntools.assert_equal(c.max_label(), LibSvmClassifier.NEGATIVE_LABEL, "Found False negative: %s :: %s" % (d.vector(), c.get_classification())) # Closing resources p.close() p.join()
def main():
    description = """
    Utility for validating a given classifier implementation's model against
    some labeled testing data, outputting PR and ROC curve plots with
    area-under-curve score values.

    This utility can optionally be used train a supervised classifier model if
    the given classifier model configuration does not exist and a second CSV
    file listing labeled training data is provided. Training will be attempted
    if ``train`` is set to true. If training is performed, we exit after
    training completes. A ``SupervisedClassifier`` sub-classing implementation
    must be configured

    We expect the test and train CSV files in the column format:

        ...
        <UUID>,<label>
        ...

    The UUID is of the descriptor to which the label applies. The label may be
    any arbitrary string value, but all labels must be consistent in
    application.

    Some metrics presented assume the highest confidence class as the single
    predicted class for an element:

        - confusion matrix

    The output UUID confusion matrix is a JSON dictionary where the top-level
    keys are the true labels, and the inner dictionary is the mapping of
    predicted labels to the UUIDs of the classifications/descriptors that
    yielded the prediction. Again, this is based on the maximum probability
    label for a classification result (T=0.5).
    """
    args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(config['plugins']['classifier'],
                                           get_classifier_impls())
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls())

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of label to the DescriptorElement instances for that
    # described by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate through UUIDs in specified file """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch descriptors from configured index """
        uuid, truth_label = r
        return truth_label, descriptor_index.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr, iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train classifier if the one given has a ``train`` method and training
    # was turned enabled.
    if do_train:
        if isinstance(classifier, SupervisedClassifier):
            log.info("Training classifier model")
            classifier.train(tlabel2descriptors)
            exit(0)
        else:
            # BUG FIX: the ValueError was previously constructed but never
            # raised, so an unsupported classifier silently fell through to
            # the classification phase below.
            raise ValueError(
                "Configured classifier is not a SupervisedClassifier "
                "type and does not support training.")

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    # FIX: use ``items()`` (Python 2/3 compatible) instead of the
    # Python-2-only ``iteritems()``, matching the sibling implementation of
    # this utility elsewhere in the codebase.
    for tlabel, descriptors in tlabel2descriptors.items():
        tlabel2classifications[tlabel] = \
            set(classifier.classify_async(
                descriptors, classification_factory,
                use_multiprocessing=True,
                procs=config['parallelism']['classification_cores'],
                ri=1.0,
            ).values())
    log.info("Truth label counts:")
    for l in sorted(tlabel2classifications):
        log.info("  %s :: %d", l, len(tlabel2classifications[l]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # CM of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are UUID
        # predicted labels.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[str, set | list]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists
            for plabel in uuid_cm[tlabel]:
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr,
                       plot_ci, plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc,
                        plot_ci, plot_ci_alpha)
def classifier_kfold_validation():
    description = """
    Helper utility for cross validating a supervised classifier configuration.
    The classifier used should NOT be configured to save its model since this
    process requires us to train the classifier multiple times.

    Configuration
    -------------
    - plugins
        - supervised_classifier
            Supervised Classifier implementation configuration to use. This
            should not be set to use a persistent model if able.

        - descriptor_index
            Index to draw descriptors to classify from.

    - cross_validation
        - truth_labels
            Path to a CSV file containing descriptor UUID the truth label
            associations. This defines what descriptors are used from the
            given index. We error if any descriptor UUIDs listed here are not
            available in the given descriptor index. This file should be in
            [uuid, label] column format.

        - num_folds
            Number of folds to make for cross validation.

        - random_seed
            Optional fixed seed for the

        - classification_use_multiprocessing
            If we should use multiprocessing (vs threading) when classifying
            elements.

    - pr_curves
        - enabled
            If Precision/Recall plots should be generated.

        - show
            If we should attempt to show the graph after it has been generated
            (matplotlib).

        - output_directory
            Directory to save generated plots to. If None, we will not save
            plots. Otherwise we will create the directory (and required parent
            directories) if it does not exist.

        - file_prefix
            String prefix to prepend to standard plot file names.

    - roc_curves
        - enabled
            If ROC curves should be generated

        - show
            If we should attempt to show the plot after it has been generated
            (matplotlib).

        - output_directory
            Directory to save generated plots to. If None, we will not save
            plots. Otherwise we will create the directory (and required parent
            directories) if it does not exist.

        - file_prefix
            String prefix to prepend to standard plot file names.
    """
    args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    #
    # Load configurations / Setup data
    #
    use_mp = config['cross_validation']['classification_use_multiprocessing']

    pr_enabled = config['pr_curves']['enabled']
    pr_output_dir = config['pr_curves']['output_directory']
    # ``or ''`` normalizes a None/empty prefix to the empty string.
    pr_file_prefix = config['pr_curves']['file_prefix'] or ''
    pr_show = config['pr_curves']['show']

    roc_enabled = config['roc_curves']['enabled']
    roc_output_dir = config['roc_curves']['output_directory']
    roc_file_prefix = config['roc_curves']['file_prefix'] or ''
    roc_show = config['roc_curves']['show']

    log.info("Initializing DescriptorIndex (%s)",
             config['plugins']['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    log.info("Loading classifier configuration")
    #: :type: dict
    classifier_config = config['plugins']['supervised_classifier']

    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory']
    )

    # Load [uuid, label] rows from the configured truth CSV file.
    log.info("Loading truth data")
    #: :type: list[str]
    uuids = []
    #: :type: list[str]
    truth_labels = []
    with open(config['cross_validation']['truth_labels']) as f:
        f_csv = csv.reader(f)
        for row in f_csv:
            uuids.append(row[0])
            truth_labels.append(row[1])
    #: :type: numpy.ndarray[str]
    uuids = numpy.array(uuids)
    #: :type: numpy.ndarray[str]
    truth_labels = numpy.array(truth_labels)

    #
    # Cross validation
    #
    # NOTE(review): ``sklearn.cross_validation`` is the pre-0.18 scikit-learn
    # module (replaced by ``sklearn.model_selection``) — confirm the pinned
    # sklearn version supports it.
    kfolds = sklearn.cross_validation.StratifiedKFold(
        truth_labels, config['cross_validation']['num_folds'],
        random_state=config['cross_validation']['random_seed']
    )

    """
    Truth and classification probability results for test data per fold.
    Format:
        {
            0: {
                '<label>':  {
                    "truth": [...],   # Parallel truth and classification
                    "proba": [...],   # probability values
                },
                ...
            },
            ...
        }
    """
    fold_data = {}

    i = 0
    for train, test in kfolds:
        log.info("Fold %d", i)
        log.info("-- %d training examples", len(train))
        log.info("-- %d test examples", len(test))
        fold_data[i] = {}

        # A fresh classifier instance is built per fold so no model state
        # carries over between folds.
        log.info("-- creating classifier")
        #: :type: SupervisedClassifier
        classifier = plugin.from_plugin_config(
            classifier_config,
            get_supervised_classifier_impls()
        )

        # Group training descriptors by their truth label.
        log.info("-- gathering descriptors")
        #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
        pos_map = {}
        for idx in train:
            if truth_labels[idx] not in pos_map:
                pos_map[truth_labels[idx]] = []
            pos_map[truth_labels[idx]].append(
                descriptor_index.get_descriptor(uuids[idx])
            )

        log.info("-- Training classifier")
        classifier.train(pos_map)

        log.info("-- Classifying test set")
        m = classifier.classify_async(
            (descriptor_index.get_descriptor(uuids[idx]) for idx in test),
            classification_factory,
            use_multiprocessing=use_mp, ri=1.0
        )
        # NOTE(review): ``iteritems`` is Python-2-only; this script will not
        # run under Python 3 without switching to ``items``.
        uuid2c = dict((d.uuid(), c.get_classification())
                      for d, c in m.iteritems())

        log.info("-- Pairing truth and computed probabilities")
        # Only considering positive labels
        for t_label in pos_map:
            fold_data[i][t_label] = {
                "truth": [l == t_label for l in truth_labels[test]],
                "proba": [uuid2c[uuid][t_label] for uuid in uuids[test]]
            }

        i += 1

    #
    # Curve generation
    #
    if pr_enabled:
        make_pr_curves(fold_data, pr_output_dir, pr_file_prefix, pr_show)
    if roc_enabled:
        make_roc_curves(fold_data, roc_output_dir, roc_file_prefix, roc_show)
def test_no_save_model_pickle(self):
    """
    Pickling a classifier with no model cache file paths configured should
    still carry a trained model across the pickle boundary via in-state
    ``__LOCAL__*`` entries, and the restored classifier should produce the
    same classification results (to within float round-off).
    """
    # Test model preservation across pickling even without model cache
    # file paths set.
    classifier = LibSvmClassifier(
        train_params={
            '-t': 0,  # linear kernel
            '-b': 1,  # enable probability estimates
            '-c': 2,  # SVM-C parameter C
            '-q': '',  # quiet mode
        },
        normalize=None,  # DO NOT normalize descriptors
    )
    # No model has been trained yet.
    ntools.assert_true(classifier.svm_model is None)

    # Empty model should not trigger __LOCAL__ content in pickle
    ntools.assert_not_in('__LOCAL__', classifier.__getstate__())
    # Round-tripping an untrained classifier should simply not error.
    _ = cPickle.loads(cPickle.dumps(classifier))

    # train arbitrary model (same as ``test_simple_classification``)
    DIM = 2           # descriptor vector dimensionality
    N = 1000          # number of random sample vectors
    POS_LABEL = 'positive'
    NEG_LABEL = 'negative'
    d_factory = DescriptorElementFactory(DescriptorMemoryElement, {})
    c_factory = ClassificationElementFactory(
        MemoryClassificationElement, {})

    def make_element(argtup):
        # Wrap vector ``v`` in a new descriptor element with UID ``i``.
        (i, v) = argtup
        d = d_factory.new_descriptor('test', i)
        d.set_vector(v)
        return d

    # Constructing artificial descriptors
    x = numpy.random.rand(N, DIM)
    # Split into two linearly separable groups on the second axis, leaving
    # a margin gap between 0.45 and 0.55.
    x_pos = x[x[:, 1] <= 0.45]
    x_neg = x[x[:, 1] >= 0.55]

    p = multiprocessing.pool.ThreadPool()
    d_pos = p.map(make_element, enumerate(x_pos))
    # Offset negative-sample UIDs so they don't collide with positives.
    d_neg = p.map(make_element, enumerate(x_neg, start=N // 2))
    p.close()
    p.join()

    # Training
    classifier.train({POS_LABEL: d_pos, NEG_LABEL: d_neg})

    # Test original classifier
    t_v = numpy.random.rand(DIM)
    t = d_factory.new_descriptor('query', 0)
    t.set_vector(t_v)
    c_expected = classifier.classify(t, c_factory)

    # Should see __LOCAL__ content in pickle state now
    p_state = classifier.__getstate__()
    ntools.assert_in('__LOCAL__', p_state)
    ntools.assert_in('__LOCAL_LABELS__', p_state)
    ntools.assert_in('__LOCAL_MODEL__', p_state)
    ntools.assert_true(len(p_state['__LOCAL_LABELS__']) > 0)
    ntools.assert_true(len(p_state['__LOCAL_MODEL__']) > 0)

    # Restored classifier should classify the same test descriptor the
    # same
    #: :type: LibSvmClassifier
    classifier2 = cPickle.loads(cPickle.dumps(classifier))
    c_post_pickle = classifier2.classify(t, c_factory)
    # There may be floating point error, so extract actual confidence
    # values and check post round
    c_pp_positive = c_post_pickle[POS_LABEL]
    c_pp_negative = c_post_pickle[NEG_LABEL]
    c_e_positive = c_expected[POS_LABEL]
    c_e_negative = c_expected[NEG_LABEL]
    ntools.assert_almost_equal(c_e_positive, c_pp_positive, 5)
    ntools.assert_almost_equal(c_e_negative, c_pp_negative, 5)
def main():
    """
    CLI entry point: asynchronously compute classifications for descriptors
    referenced by a UUID list file and write per-label confidences to a CSV
    data file plus a separate CSV column-header file.
    """
    # NOTE: ``description`` is a runtime value passed to the CLI helper (it
    # becomes the --help text), not this function's docstring.
    description = """
    Script for asynchronously computing classifications for DescriptorElements
    in a DescriptorIndex specified via a list of UUIDs. Results are output to a
    CSV file in the format:

        uuid, label1_confidence, label2_confidence, ...

    CSV columns labels are output to the given CSV header file path. Label
    columns will be in the order as reported by the classifier implementations
    ``get_labels`` method.

    Due to using an input file-list of UUIDs, we require that the UUIDs of
    indexed descriptors be strings, or equality comparable to the UUIDs' string
    representation.
    """
    args, config = bin_utils.utility_main_helper(
        default_config,
        description,
        extend_parser,
    )
    log = logging.getLogger(__name__)

    # High-level flow:
    # - parallel_map UUIDs to load from the configured index
    # - classify iterated descriptors

    uuids_list_filepath = args.uuids_list
    output_csv_filepath = args.csv_data
    output_csv_header_filepath = args.csv_header
    # Whether to recompute classifications that already exist.
    classify_overwrite = config['utility']['classify_overwrite']

    p_use_multiprocessing = \
        config['utility']['parallel']['use_multiprocessing']
    p_index_extraction_cores = \
        config['utility']['parallel']['index_extraction_cores']
    p_classification_cores = \
        config['utility']['parallel']['classification_cores']

    # Validate required input/output paths before constructing any plugins.
    if not uuids_list_filepath:
        raise ValueError("No uuids_list_filepath specified.")
    elif not os.path.isfile(uuids_list_filepath):
        raise ValueError("Given uuids_list_filepath did not point to a file.")
    if output_csv_header_filepath is None:
        raise ValueError("Need a path to save CSV header labels")
    if output_csv_filepath is None:
        raise ValueError("Need a path to save CSV data.")

    #
    # Initialize configured plugins
    #

    log.info("Initializing descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    log.info("Initializing classification factory")
    c_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory']
    )

    log.info("Initializing classifier")
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        config['plugins']['classifier'],
        get_classifier_impls()
    )

    #
    # Setup/Process
    #
    def iter_uuids():
        # Yield one stripped UUID string per line of the input file.
        with open(uuids_list_filepath) as f:
            for l in f:
                yield l.strip()

    def descr_for_uuid(uuid):
        """
        Fetch the descriptor element for a UUID from the configured index.

        :type uuid: collections.Hashable
        :rtype: smqtk.representation.DescriptorElement
        """
        return descriptor_index.get_descriptor(uuid)

    def classify_descr(d):
        """
        Classify a single descriptor element via the configured classifier.

        :type d: smqtk.representation.DescriptorElement
        :rtype: smqtk.representation.ClassificationElement
        """
        return classifier.classify(d, c_factory, classify_overwrite)

    # Lazily chain the two parallel stages; nothing executes until the
    # classification iterator is consumed below.
    log.info("Initializing uuid-to-descriptor parallel map")
    #: :type: collections.Iterable[smqtk.representation.DescriptorElement]
    element_iter = parallel.parallel_map(
        descr_for_uuid, iter_uuids(),
        use_multiprocessing=p_use_multiprocessing,
        cores=p_index_extraction_cores,
        name="descr_for_uuid",
    )

    log.info("Initializing descriptor-to-classification parallel map")
    #: :type: collections.Iterable[smqtk.representation.ClassificationElement]
    classification_iter = parallel.parallel_map(
        classify_descr, element_iter,
        use_multiprocessing=p_use_multiprocessing,
        cores=p_classification_cores,
        name='classify_descr',
    )

    #
    # Write/Output files
    #

    # Column order follows the classifier's reported label order.
    c_labels = classifier.get_labels()

    def make_row(c):
        """
        Build a CSV row: [uuid, conf_label1, conf_label2, ...].

        :type c: smqtk.representation.ClassificationElement
        """
        c_m = c.get_classification()
        return [c.uuid] + [c_m[l] for l in c_labels]

    # column labels file
    log.info("Writing CSV column header file: %s", output_csv_header_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_header_filepath))
    with open(output_csv_header_filepath, 'wb') as f_csv:
        w = csv.writer(f_csv)
        w.writerow(['uuid'] + c_labels)

    # CSV file
    log.info("Writing CSV data file: %s", output_csv_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_filepath))
    # Progress-reporting state vector consumed by bin_utils.report_progress.
    r_state = [0] * 7
    with open(output_csv_filepath, 'wb') as f_csv:
        w = csv.writer(f_csv)
        for c in classification_iter:
            w.writerow(make_row(c))
            bin_utils.report_progress(log.info, r_state, 1.0)

    # Final report
    # NOTE(review): the decrement appears to offset the counter so the forced
    # final report (interval=0) does not double-count the last row — confirm
    # against bin_utils.report_progress.
    r_state[1] -= 1
    bin_utils.report_progress(log.info, r_state, 0)

    log.info("Done")
plt.ylabel("Precision") plt.legend(loc='best', fancybox=True, framealpha=0.5) plt.savefig(PLOT_PR_OUTPUT) else: # Using the final trained classifier with open(CLASSIFIER_TRAINING_CONFIG_JSON) as f: classifier_config = json.load(f) log.info("Loading plugins") descriptor_index = MemoryDescriptorIndex( file_cache=DESCRIPTOR_INDEX_FILE_CACHE) #: :type: smqtk.algorithms.Classifier classifier = from_plugin_config(classifier_config['plugins']['classifier'], get_classifier_impls()) c_factory = ClassificationElementFactory(MemoryClassificationElement, {}) #: :type: dict[str, list[str]] phone2shas = json.load(open(PHONE_SHA1_JSON)) #: :type: dict[str, float] phone2score = {} log.info("Classifying phone imagery descriptors") i = 0 descriptor_index_shas = set(descriptor_index.iterkeys()) for p in phone2shas: log.info('%s (%d / %d)', p, i + 1, len(phone2shas)) # Not all source "images" have descriptors since some URLs returned # non-image files. Intersect phone sha's with what was actually # computed. Warn if this reduces descriptors for classification to zero. indexed_shas = set(phone2shas[p]) & descriptor_index_shas
def main():
    """
    Asynchronously classify descriptors named by a UUID list file, writing a
    CSV of per-label confidences and a separate CSV header-label file.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    # - parallel_map UUIDs to load from the configured index
    # - classify iterated descriptors

    uuid_list_path = args.uuids_list
    csv_data_path = args.csv_data
    csv_header_path = args.csv_header
    overwrite_existing = config['utility']['classify_overwrite']

    # Hoist the parallelism sub-config once instead of repeating the lookup.
    parallel_cfg = config['utility']['parallel']
    use_mp = parallel_cfg['use_multiprocessing']
    extraction_cores = parallel_cfg['index_extraction_cores']
    classification_cores = parallel_cfg['classification_cores']

    # Guard-clause validation of required paths.
    if not uuid_list_path:
        raise ValueError("No uuids_list_filepath specified.")
    if not os.path.isfile(uuid_list_path):
        raise ValueError("Given uuids_list_filepath did not point to a file.")
    if csv_header_path is None:
        raise ValueError("Need a path to save CSV header labels")
    if csv_data_path is None:
        raise ValueError("Need a path to save CSV data.")

    #
    # Initialize configured plugins
    #

    log.info("Initializing descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls())

    log.info("Initializing classification factory")
    c_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])

    log.info("Initializing classifier")
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(config['plugins']['classifier'],
                                           get_classifier_impls())

    #
    # Setup/Process
    #
    def iter_uuids():
        # One stripped UUID string per line of the input file.
        with open(uuid_list_path) as uuid_file:
            for line in uuid_file:
                yield line.strip()

    def descr_for_uuid(uuid):
        """
        :type uuid: collections.Hashable
        :rtype: smqtk.representation.DescriptorElement
        """
        return descriptor_index.get_descriptor(uuid)

    def classify_descr(d):
        """
        :type d: smqtk.representation.DescriptorElement
        :rtype: smqtk.representation.ClassificationElement
        """
        return classifier.classify(d, c_factory, overwrite_existing)

    # Two lazily-chained parallel stages; work happens when consumed below.
    log.info("Initializing uuid-to-descriptor parallel map")
    #: :type: collections.Iterable[smqtk.representation.DescriptorElement]
    element_iter = parallel.parallel_map(
        descr_for_uuid, iter_uuids(),
        use_multiprocessing=use_mp,
        cores=extraction_cores,
        name="descr_for_uuid",
    )

    log.info("Initializing descriptor-to-classification parallel map")
    #: :type: collections.Iterable[smqtk.representation.ClassificationElement]
    classification_iter = parallel.parallel_map(
        classify_descr, element_iter,
        use_multiprocessing=use_mp,
        cores=classification_cores,
        name='classify_descr',
    )

    #
    # Write/Output files
    #

    # Column order is the classifier's reported label order.
    c_labels = classifier.get_labels()

    def make_row(c):
        """
        :type c: smqtk.representation.ClassificationElement
        """
        c_m = c.get_classification()
        row = [c.uuid]
        row.extend(c_m[label] for label in c_labels)
        return row

    # column labels file
    log.info("Writing CSV column header file: %s", csv_header_path)
    file_utils.safe_create_dir(os.path.dirname(csv_header_path))
    with open(csv_header_path, 'wb') as f_csv:
        csv.writer(f_csv).writerow(['uuid'] + c_labels)

    # CSV file
    log.info("Writing CSV data file: %s", csv_data_path)
    file_utils.safe_create_dir(os.path.dirname(csv_data_path))
    progress_state = [0] * 7
    with open(csv_data_path, 'wb') as f_csv:
        writer = csv.writer(f_csv)
        for c in classification_iter:
            writer.writerow(make_row(c))
            bin_utils.report_progress(log.info, progress_state, 1.0)

    # Final report
    progress_state[1] -= 1
    bin_utils.report_progress(log.info, progress_state, 0)

    log.info("Done")
def main():
    """
    CLI entry point: evaluate (or, with ``utility.train`` enabled, train) a
    supervised classifier against labeled descriptors, emitting confusion
    matrix logs/plots/JSON and optional PR/ROC curve plots.
    """
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    # Deprecations
    if (config.get('parallelism', {})
              .get('classification_cores', None) is not None):
        warnings.warn("Usage of 'classification_cores' is deprecated. "
                      "Classifier parallelism is not defined on a "
                      "per-implementation basis. See classifier "
                      "implementation parameterization.",
                      category=DeprecationWarning)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = from_config_dict(
        config['plugins']['classifier'],
        SupervisedClassifier.get_impls()
    )
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory']
    )
    #: :type: smqtk.representation.DescriptorSet
    descriptor_set = from_config_dict(
        config['plugins']['descriptor_set'],
        DescriptorSet.get_impls()
    )

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of label to the DescriptorElement instances
    # described by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate (uuid, label) pairs from the configured CSV file. """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch descriptors from the configured descriptor set. """
        uuid, truth_label = r
        return truth_label, descriptor_set.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr, iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    tlabel2descriptors: Dict[str, List[DescriptorElement]] = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train classifier if training was enabled. Training is terminal: the
    # script exits instead of continuing on to evaluation.
    if do_train:
        log.info("Training supervised classifier model")
        classifier.train(tlabel2descriptors)
        exit(0)

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    for tlabel, descriptors in six.iteritems(tlabel2descriptors):
        tlabel2classifications[tlabel] = \
            set(classifier.classify_elements(descriptors,
                                             classification_factory))
    log.info("Truth label counts:")
    for tlabel in sorted(tlabel2classifications):
        log.info("  %s :: %d", tlabel, len(tlabel2classifications[tlabel]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # Confusion Matrix of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are UUID
        # predicted labels.
        log.info("Computing UUID Confusion Matrix")
        # FIX: annotation previously read ``Union[List, List]`` for the inner
        # value type, which is redundant — it collapses to just ``List``.
        uuid_cm: Dict[str, Dict[Hashable, List]] = {}
        for tlabel in tlabel2classifications:
            tlabel_uuid_cm = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                tlabel_uuid_cm[c.max_label()].add(c.uuid)
            # convert sets to lists for master JSON output.
            uuid_cm[tlabel] = {
                plabel: list(uuid_set)
                for plabel, uuid_set in tlabel_uuid_cm.items()
            }
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr,
                       plot_ci, plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc,
                        plot_ci, plot_ci_alpha)
def main():
    """
    CLI entry point: evaluate (or, with ``utility.train`` enabled, train) a
    classifier against labeled descriptors, emitting confusion matrix
    logs/plots/JSON and optional PR/ROC curve plots.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        config['plugins']['classifier'],
        get_classifier_impls()
    )
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory']
    )
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of label to the DescriptorElement instances
    # described by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate (uuid, label) pairs from the configured CSV file. """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch descriptors from the configured index. """
        uuid, truth_label = r
        return truth_label, descriptor_index.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr, iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train classifier if the one given has a ``train`` method and training
    # was enabled. Training is terminal: the script exits instead of
    # continuing on to evaluation.
    if do_train:
        if isinstance(classifier, SupervisedClassifier):
            log.info("Training classifier model")
            classifier.train(tlabel2descriptors)
            exit(0)
        else:
            # BUG FIX: the ValueError was previously constructed but never
            # raised, silently falling through to classify with an untrained
            # classifier when training was requested on a non-supervised type.
            raise ValueError("Configured classifier is not a "
                             "SupervisedClassifier type and does not support "
                             "training.")

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    for tlabel, descriptors in six.iteritems(tlabel2descriptors):
        tlabel2classifications[tlabel] = \
            set(classifier.classify_async(
                descriptors, classification_factory,
                use_multiprocessing=True,
                procs=config['parallelism']['classification_cores'],
                ri=1.0,
            ).values())
    log.info("Truth label counts:")
    # Renamed loop variable from ambiguous ``l`` (easily confused with ``1``).
    for tlabel in sorted(tlabel2classifications):
        log.info("  %s :: %d", tlabel, len(tlabel2classifications[tlabel]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # Confusion Matrix of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are UUID
        # predicted labels.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[collections.Hashable, set]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists for JSON serialization
            for plabel in uuid_cm[tlabel]:
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr,
                       plot_ci, plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc,
                        plot_ci, plot_ci_alpha)