Exemple #1
0
    def get_default_config(cls):
        """
        Build the default configuration dictionary for this service.

        The parent class's default configuration is extended with the
        default settings for each of the service's configuration keys.

        :return: The default configuration with this service's entries
            added to those of the parent class.
        """
        config = super(SmqtkClassifierService, cls).get_default_config()

        # The classifier-removal flag defaults to False.
        config[cls.CONFIG_ENABLE_CLASSIFIER_REMOVAL] = False

        # Static classifier configurations
        collection_defaults = ClassifierCollection.get_default_config()
        config[cls.CONFIG_CLASSIFIER_COLLECTION] = collection_defaults

        # Classification element factory for new classification results.
        classification_defaults = \
            ClassificationElementFactory.get_default_config()
        config[cls.CONFIG_CLASSIFICATION_FACTORY] = classification_defaults

        # Descriptor generator for new content
        generator_impls = get_descriptor_generator_impls()
        config[cls.CONFIG_DESCRIPTOR_GENERATOR] = \
            smqtk.utils.plugin.make_config(generator_impls)

        # Descriptor factory for new content descriptors
        descriptor_defaults = DescriptorElementFactory.get_default_config()
        config[cls.CONFIG_DESCRIPTOR_FACTORY] = descriptor_defaults

        # from-IQR-state *supervised* classifier configuration
        supervised_impls = get_classifier_impls(
            sub_interface=SupervisedClassifier
        )
        config[cls.CONFIG_IQR_CLASSIFIER] = \
            smqtk.utils.plugin.make_config(supervised_impls)

        # No labels are immutable by default.
        config[cls.CONFIG_IMMUTABLE_LABELS] = []

        return config
def default_config():
    """
    Assemble the default configuration dictionary for this utility.

    The result has three sections: ``plugins`` (default plugin
    configurations), ``utility`` (training flag, input CSV path and optional
    output paths/plot options) and ``parallelism`` (core counts).

    :return: Default configuration dictionary.
    :rtype: dict
    """
    plugins_section = {
        'classifier': plugin.make_config(get_classifier_impls()),
        'classification_factory':
            ClassificationElementFactory.get_default_config(),
        'descriptor_index': plugin.make_config(get_descriptor_index_impls()),
    }

    utility_section = {
        'train': False,
        # Placeholder value that the user is expected to replace.
        'csv_filepath': 'CHAMGEME :: PATH :: a csv file',
        'output_plot_pr': None,
        'output_plot_roc': None,
        'output_plot_confusion_matrix': None,
        'output_uuid_confusion_matrix': None,
        'curve_confidence_interval': False,
        'curve_confidence_interval_alpha': 0.4,
    }

    parallelism_section = {
        "descriptor_fetch_cores": 4,
        "classification_cores": None,
    }

    return {
        'plugins': plugins_section,
        'utility': utility_section,
        "parallelism": parallelism_section,
    }
def default_config():
    """
    Return the default configuration for this utility.

    Plugin defaults are generated first, then combined with the static
    ``utility`` and ``parallelism`` settings into a single dictionary.

    :return: Default configuration dictionary.
    :rtype: dict
    """
    # Generate plugin defaults in the same order they appear in the result.
    classifier_defaults = plugin.make_config(get_classifier_impls())
    classification_defaults = \
        ClassificationElementFactory.get_default_config()
    index_defaults = plugin.make_config(get_descriptor_index_impls())

    config = {}
    config['plugins'] = {
        'classifier': classifier_defaults,
        'classification_factory': classification_defaults,
        'descriptor_index': index_defaults,
    }
    config['utility'] = {
        'train': False,
        'csv_filepath': 'CHAMGEME :: PATH :: a csv file',
        'output_plot_pr': None,
        'output_plot_roc': None,
        'output_plot_confusion_matrix': None,
        'output_uuid_confusion_matrix': None,
        'curve_confidence_interval': False,
        'curve_confidence_interval_alpha': 0.4,
    }
    config["parallelism"] = {
        "descriptor_fetch_cores": 4,
        "classification_cores": None,
    }
    return config
Exemple #4
0
    def get_default_config(cls):
        """
        Generate the default configuration for this service.

        The parent class's defaults are extended with an ``iqr_service``
        section holding the default positive-seed neighbor count, advisory
        notes for each plugin slot and the default plugin configurations.

        :return: The parent's default configuration with the ``iqr_service``
            section merged in.
        """
        config = super(IqrService, cls).get_default_config()

        # Relevancy index plugin defaults, merged with the IQR session
        # module's default relevancy index configuration.
        rel_index_defaults = plugin.make_config(get_relevancy_index_impls())
        merge_dict(rel_index_defaults, iqr_session.DFLT_REL_INDEX_CONFIG)

        # Human-readable guidance for each entry of the "plugins" section.
        plugin_notes = {
            "relevancy_index_config":
                "The relevancy index config provided should not have "
                "persistent storage configured as it will be used in "
                "such a way that instances are created, built and "
                "destroyed often.",
            "descriptor_index":
                "This is the index from which given positive and "
                "negative example descriptors are retrieved from. "
                "Not used for nearest neighbor querying. "
                "This index must contain all descriptors that could "
                "possibly be used as positive/negative examples and "
                "updated accordingly.",
            "neighbor_index":
                "This is the neighbor index to pull initial near-"
                "positive descriptors from.",
            "classifier_config":
                "The configuration to use for training and using "
                "classifiers for the /classifier endpoint. "
                "When configuring a classifier for use, don't fill "
                "out model persistence values as many classifiers "
                "may be created and thrown away during this service's "
                "operation.",
            "classification_factory":
                "Selection of the backend in which classifications "
                "are stored. The in-memory version is recommended "
                "because normal caching mechanisms will not account "
                "for the variety of classifiers that can potentially "
                "be created via this utility.",
        }

        # Default configuration for each plugin slot.
        plugin_defaults = {
            "relevancy_index_config": rel_index_defaults,
            "descriptor_index":
                plugin.make_config(get_descriptor_index_impls()),
            "neighbor_index": plugin.make_config(get_nn_index_impls()),
            "classifier_config": plugin.make_config(get_classifier_impls()),
            "classification_factory":
                ClassificationElementFactory.get_default_config(),
        }

        service_section = {
            "positive_seed_neighbors": 500,
            "plugin_notes": plugin_notes,
            "plugins": plugin_defaults,
        }
        merge_dict(config, {"iqr_service": service_section})
        return config
Exemple #5
0
    def get_default_config(cls):
        """
        Produce this service's default configuration.

        Takes the parent class's default configuration and merges in an
        ``iqr_service`` section: the positive-seed neighbor count, a note per
        plugin slot and the default configuration of each plugin.

        :return: The merged default configuration.
        """
        base_config = super(IqrService, cls).get_default_config()

        relevancy_defaults = plugin.make_config(get_relevancy_index_impls())
        merge_dict(relevancy_defaults, iqr_session.DFLT_REL_INDEX_CONFIG)

        # The section is filled in key by key to keep each part readable;
        # key order matches the order in which the entries are added.
        iqr_section = {"positive_seed_neighbors": 500}

        notes = {}
        notes["relevancy_index_config"] = (
            "The relevancy index config provided should not have "
            "persistent storage configured as it will be used in "
            "such a way that instances are created, built and "
            "destroyed often."
        )
        notes["descriptor_index"] = (
            "This is the index from which given positive and "
            "negative example descriptors are retrieved from. "
            "Not used for nearest neighbor querying. "
            "This index must contain all descriptors that could "
            "possibly be used as positive/negative examples and "
            "updated accordingly."
        )
        notes["neighbor_index"] = (
            "This is the neighbor index to pull initial near-"
            "positive descriptors from."
        )
        notes["classifier_config"] = (
            "The configuration to use for training and using "
            "classifiers for the /classifier endpoint. "
            "When configuring a classifier for use, don't fill "
            "out model persistence values as many classifiers "
            "may be created and thrown away during this service's "
            "operation."
        )
        notes["classification_factory"] = (
            "Selection of the backend in which classifications "
            "are stored. The in-memory version is recommended "
            "because normal caching mechanisms will not account "
            "for the variety of classifiers that can potentially "
            "be created via this utility."
        )
        iqr_section["plugin_notes"] = notes

        plugins = {}
        plugins["relevancy_index_config"] = relevancy_defaults
        plugins["descriptor_index"] = \
            plugin.make_config(get_descriptor_index_impls())
        plugins["neighbor_index"] = plugin.make_config(get_nn_index_impls())
        plugins["classifier_config"] = \
            plugin.make_config(get_classifier_impls())
        plugins["classification_factory"] = \
            ClassificationElementFactory.get_default_config()
        iqr_section["plugins"] = plugins

        merge_dict(base_config, {"iqr_service": iqr_section})
        return base_config
Exemple #6
0
def get_default_config():
    """
    Build the default configuration dictionary.

    Holds default configurations for the descriptor element factory, the
    descriptor generator plugin, the classification element factory and the
    classifier plugin.

    :return: Default configuration dictionary.
    :rtype: dict
    """
    defaults = {}
    defaults["descriptor_factory"] = \
        DescriptorElementFactory.get_default_config()
    defaults["descriptor_generator"] = \
        plugin.make_config(get_descriptor_generator_impls())
    defaults["classification_factory"] = \
        ClassificationElementFactory.get_default_config()
    defaults["classifier"] = plugin.make_config(get_classifier_impls())
    return defaults
Exemple #7
0
def get_default_config():
    """
    Return default configurations for the four configurable components.

    :return: Dictionary with ``descriptor_factory``,
        ``descriptor_generator``, ``classification_factory`` and
        ``classifier`` entries.
    :rtype: dict
    """
    # Each default is generated in the order it appears in the result.
    descriptor_factory_cfg = DescriptorElementFactory.get_default_config()
    generator_cfg = plugin.make_config(get_descriptor_generator_impls())
    classification_factory_cfg = \
        ClassificationElementFactory.get_default_config()
    classifier_cfg = plugin.make_config(get_classifier_impls())

    return {
        "descriptor_factory": descriptor_factory_cfg,
        "descriptor_generator": generator_cfg,
        "classification_factory": classification_factory_cfg,
        "classifier": classifier_cfg,
    }
def default_config():
    """
    Assemble the default configuration for this utility.

    The ``utility`` section carries the overwrite flag and parallelism
    options; the ``plugins`` section carries default plugin configurations.

    :return: Default configuration dictionary.
    :rtype: dict
    """
    parallel_options = {
        "use_multiprocessing": False,
        "index_extraction_cores": None,
        "classification_cores": None,
    }
    utility_section = {
        "classify_overwrite": False,
        "parallel": parallel_options,
    }

    plugins_section = {
        "classifier": plugin.make_config(get_classifier_impls()),
        "classification_factory":
            plugin.make_config(get_classification_element_impls()),
        "descriptor_index": plugin.make_config(get_descriptor_index_impls()),
    }

    return {
        "utility": utility_section,
        "plugins": plugins_section,
    }
def train_classifier_iqr(config, iqr_state_fp):
    """
    Train the configured supervised classifier from a saved IQR state file.

    :param config: Configuration dictionary; its ``classifier`` entry is the
        plugin configuration of the classifier to train.
    :param iqr_state_fp: Path of the IQR session state file to load.
    """
    # Only supervised classifier implementations are candidates here.
    classifier_config = config['classifier']
    supervised_impls = get_classifier_impls(sub_interface=SupervisedClassifier)
    #: :type: smqtk.algorithms.SupervisedClassifier
    classifier = from_plugin_config(classifier_config, supervised_impls)

    # Read the raw state and load it into an empty IqrSession instance.
    with open(iqr_state_fp, 'rb') as state_file:
        raw_state = state_file.read()
    state_bytes = raw_state.strip()
    element_factory = DescriptorElementFactory(DescriptorMemoryElement, {})
    session = IqrSession()
    session.set_state_bytes(state_bytes, element_factory)

    # Training examples for each class are the union of the session's
    # internal and external descriptor sets for that class.
    positives = \
        session.positive_descriptors | session.external_positive_descriptors
    negatives = \
        session.negative_descriptors | session.external_negative_descriptors
    examples = {'positive': positives, 'negative': negatives}
    classifier.train(class_examples=examples)
def train_classifier_iqr(config, iqr_state_fp):
    """
    Train the configured classifier from a zipped IQR state file.

    The state file must be a zip archive containing exactly one member: a
    JSON object with exactly the keys ``pos`` and ``neg``, each a list of
    descriptor vectors. Duplicate vectors within each list are dropped
    before training.

    :param config: Configuration dictionary; its ``classifier`` entry is the
        plugin configuration of the classifier to train.
    :param iqr_state_fp: Path of the zipped IQR state file.

    :raises RuntimeError: The configured classifier is not a
        ``SupervisedClassifier``, or the state file does not have the
        expected structure.
    """
    log = logging.getLogger(__name__)

    #: :type: smqtk.algorithms.SupervisedClassifier
    classifier = from_plugin_config(config['classifier'],
                                    get_classifier_impls())

    if not isinstance(classifier, SupervisedClassifier):
        raise RuntimeError("Configured classifier must be of the "
                           "SupervisedClassifier type in order to train.")

    # Get pos/neg descriptors out of iqr state zip. Letting ZipFile open the
    # path itself guarantees binary mode, and the context manager guarantees
    # the archive is closed even when validation fails.
    with zipfile.ZipFile(iqr_state_fp, 'r') as z:
        member_names = z.namelist()
        if len(member_names) != 1:
            raise RuntimeError("Invalid IqrState file!")
        iqrs = json.loads(z.read(member_names[0]))
    if len(iqrs) != 2:
        raise RuntimeError("Invalid IqrState file!")
    if 'pos' not in iqrs or 'neg' not in iqrs:
        raise RuntimeError("Invalid IqrState file!")

    log.info("Loading pos/neg descriptors")
    #: :type: list[smqtk.representation.DescriptorElement]
    pos = []
    #: :type: list[smqtk.representation.DescriptorElement]
    neg = []
    # ``i`` runs across both lists so every element gets a distinct UUID.
    i = 0
    for v in set(map(tuple, iqrs['pos'])):
        d = DescriptorMemoryElement('train', i)
        d.set_vector(numpy.array(v))
        pos.append(d)
        i += 1
    for v in set(map(tuple, iqrs['neg'])):
        d = DescriptorMemoryElement('train', i)
        d.set_vector(numpy.array(v))
        neg.append(d)
        i += 1
    log.info('    positive -> %d', len(pos))
    log.info('    negative -> %d', len(neg))

    classifier.train(positive=pos, negative=neg)
def default_config():
    """
    Return the default configuration dictionary for this utility.

    :return: Dictionary with a ``utility`` section (overwrite flag and
        parallelism options) and a ``plugins`` section (default plugin
        configurations).
    :rtype: dict
    """
    config = {}

    config["utility"] = {
        "classify_overwrite": False,
        "parallel": {
            "use_multiprocessing": False,
            "index_extraction_cores": None,
            "classification_cores": None,
        }
    }

    # Plugin defaults, generated in the order they appear in the section.
    classifier_cfg = plugin.make_config(get_classifier_impls())
    classification_cfg = plugin.make_config(
        get_classification_element_impls()
    )
    index_cfg = plugin.make_config(get_descriptor_index_impls())
    config["plugins"] = {
        "classifier": classifier_cfg,
        "classification_factory": classification_cfg,
        "descriptor_index": index_cfg,
    }

    return config
def get_supervised_classifier_impls():
    """
    Look up classifier implementations restricted to the
    ``SupervisedClassifier`` sub-interface.

    :return: Whatever ``get_classifier_impls`` returns for that
        sub-interface.
    """
    supervised_impls = get_classifier_impls(
        sub_interface=SupervisedClassifier
    )
    return supervised_impls
def get_default_config():
    """
    Return the default configuration: a single ``classifier`` entry holding
    the default classifier plugin configuration.

    :return: Default configuration dictionary.
    :rtype: dict
    """
    classifier_defaults = make_config(get_classifier_impls())
    return {"classifier": classifier_defaults}
def main():
    """
    Entry point: train a classifier, or validate one against labeled data.

    Reads the CLI arguments and configuration, instantiates the configured
    classifier, classification element factory and descriptor index, and
    loads ``(uuid, label)`` rows from the configured CSV file. When training
    is enabled, the classifier is trained on the labeled descriptors and the
    process exits. Otherwise the descriptors are classified, the confusion
    matrix is logged, and the confusion-matrix plot, UUID confusion-matrix
    JSON and PR/ROC curve plots are each written when their output path is
    configured.

    :raises ValueError: Training was requested but the configured classifier
        is not a ``SupervisedClassifier``.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(config['plugins']['classifier'],
                                           get_classifier_impls())
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'], get_descriptor_index_impls())

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of truth label to the DescriptorElement instances
    # described by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate through (UUID, label) pairs in the specified file """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch the descriptor for a (UUID, label) pair from the index """
        uuid, truth_label = r
        return truth_label, descriptor_index.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr,
        iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train the classifier when training is enabled. Only supervised
    # classifiers can be trained; anything else is a configuration error.
    if do_train:
        if isinstance(classifier, SupervisedClassifier):
            log.info("Training classifier model")
            classifier.train(tlabel2descriptors)
            exit(0)
        else:
            # The exception must actually be raised; merely constructing it
            # would let execution fall through to the classification step.
            raise ValueError("Configured classifier is not a "
                             "SupervisedClassifier type and does not support "
                             "training.")

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    for tlabel, descriptors in tlabel2descriptors.items():
        tlabel2classifications[tlabel] = \
            set(classifier.classify_async(
                descriptors, classification_factory,
                use_multiprocessing=True,
                procs=config['parallelism']['classification_cores'],
                ri=1.0,
            ).values())
    log.info("Truth label counts:")
    for l in sorted(tlabel2classifications):
        log.info("  %s :: %d", l, len(tlabel2classifications[l]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # CM of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are UUID
        # predicted labels.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[str, set | list]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists so the structure is JSON serializable
            for plabel in uuid_cm[tlabel]:
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr, plot_ci,
                       plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc, plot_ci,
                        plot_ci_alpha)
Exemple #15
0
def classify_files(config, label, file_globs):
    """
    Classify files and print the paths whose top label matches ``label``.

    When ``label`` is None, or is not one of the classifier's labels, the
    available labels are logged and nothing is classified.

    :param config: Configuration dictionary with ``classifier``,
        ``descriptor_factory``, ``descriptor_generator`` and
        ``classification_factory`` entries.
    :param label: Classification label to filter on, or None to only list
        the available labels.
    :param file_globs: Iterable of file paths and/or glob patterns.

    :raises RuntimeError: No files were found for the given paths/globs.
    """
    log = logging.getLogger(__name__)

    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        config['classifier'], get_classifier_impls()
    )

    def report_labels():
        """ Log every label the classifier can produce. """
        log.info("Available classifier labels:")
        for known_label in classifier.get_labels():
            log.info("- %s", known_label)

    # Guard clauses: nothing to classify without a valid label.
    if label is None:
        report_labels()
        return
    if label not in classifier.get_labels():
        log.error("Invalid classification label provided to compute and filter "
                  "on: '%s'", label)
        report_labels()
        return

    log.info("Collecting files from globs")
    #: :type: list[DataFileElement]
    data_elements = []
    uuid2filepath = {}

    def register(filepath):
        """ Wrap a file path in a data element and record its UUID. """
        element = DataFileElement(filepath)
        data_elements.append(element)
        uuid2filepath[element.uuid()] = filepath

    for entry in file_globs:
        if os.path.isfile(entry):
            # A literal path to an existing file is used as-is.
            register(entry)
            continue
        log.debug("expanding glob: %s", entry)
        for matched_path in glob.iglob(entry):
            register(matched_path)

    if not data_elements:
        raise RuntimeError("No files provided for classification.")

    log.info("Computing descriptors")
    descriptor_factory = DescriptorElementFactory.from_config(
        config['descriptor_factory']
    )
    #: :type: smqtk.algorithms.DescriptorGenerator
    descriptor_generator = plugin.from_plugin_config(
        config['descriptor_generator'], get_descriptor_generator_impls()
    )
    descr_map = descriptor_generator.compute_descriptor_async(
        data_elements, descriptor_factory
    )

    log.info("Classifying descriptors")
    classification_factory = ClassificationElementFactory.from_config(
        config['classification_factory']
    )
    descriptors = list(descr_map.values())
    classification_map = classifier.classify_async(
        descriptors, classification_factory
    )

    log.info("Printing input file paths that classified as the given label.")
    # Index classification results by UUID for lookup per data element.
    uuid2c = {c.uuid: c for c in six.itervalues(classification_map)}
    for element in data_elements:
        element_uuid = element.uuid()
        filepath = uuid2filepath[element_uuid]
        log.debug("'{}' classification map: {}".format(
            filepath, uuid2c[element_uuid].get_classification()
        ))
        if uuid2c[element_uuid].max_label() == label:
            print(filepath)
def main():
    """
    Entry point: classify descriptors listed by UUID and write CSV output.

    UUIDs are read one per line from the ``uuids_list`` argument, their
    descriptors are fetched from the configured descriptor index and
    classified with the configured classifier. Two files are written: a
    header CSV with ``uuid`` followed by the classifier's labels, and a data
    CSV with one row per classification in the same column order.

    :raises ValueError: The UUID list path is missing or not a file, or
        either output CSV path was not provided.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    # - parallel_map UUIDs to load from the configured index
    # - classify iterated descriptors

    # Input/output paths from the command line.
    uuids_path = args.uuids_list
    data_csv_path = args.csv_data
    header_csv_path = args.csv_header

    # Behaviour switches from the configuration.
    utility_cfg = config['utility']
    overwrite = utility_cfg['classify_overwrite']
    parallel_cfg = utility_cfg['parallel']
    use_mp = parallel_cfg['use_multiprocessing']
    extraction_cores = parallel_cfg['index_extraction_cores']
    classification_cores = parallel_cfg['classification_cores']

    # Validate arguments before any plugin is constructed.
    if not uuids_path:
        raise ValueError("No uuids_list_filepath specified.")
    if not os.path.isfile(uuids_path):
        raise ValueError("Given uuids_list_filepath did not point to a file.")
    if header_csv_path is None:
        raise ValueError("Need a path to save CSV header labels")
    if data_csv_path is None:
        raise ValueError("Need a path to save CSV data.")

    #
    # Initialize configured plugins
    #
    plugins_cfg = config['plugins']

    log.info("Initializing descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        plugins_cfg['descriptor_index'], get_descriptor_index_impls()
    )

    log.info("Initializing classification factory")
    c_factory = ClassificationElementFactory.from_config(
        plugins_cfg['classification_factory']
    )

    log.info("Initializing classifier")
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        plugins_cfg['classifier'], get_classifier_impls()
    )

    #
    # Setup/Process
    #
    def iter_uuids():
        """ Yield each line of the UUID list file, stripped. """
        with open(uuids_path) as uuid_file:
            for line in uuid_file:
                yield line.strip()

    def descr_for_uuid(uuid):
        """
        :type uuid: collections.Hashable
        :rtype: smqtk.representation.DescriptorElement
        """
        return descriptor_index.get_descriptor(uuid)

    def classify_descr(d):
        """
        :type d: smqtk.representation.DescriptorElement
        :rtype: smqtk.representation.ClassificationElement
        """
        return classifier.classify(d, c_factory, overwrite)

    log.info("Initializing uuid-to-descriptor parallel map")
    #: :type: collections.Iterable[smqtk.representation.DescriptorElement]
    element_iter = parallel.parallel_map(
        descr_for_uuid, iter_uuids(),
        use_multiprocessing=use_mp,
        cores=extraction_cores,
        name="descr_for_uuid",
    )

    log.info("Initializing descriptor-to-classification parallel map")
    #: :type: collections.Iterable[smqtk.representation.ClassificationElement]
    classification_iter = parallel.parallel_map(
        classify_descr, element_iter,
        use_multiprocessing=use_mp,
        cores=classification_cores,
        name='classify_descr',
    )

    #
    # Write/Output files
    #

    # Label order fixes the column order of both output files.
    c_labels = classifier.get_labels()

    def make_row(c):
        """
        :type c: smqtk.representation.ClassificationElement
        """
        scores = c.get_classification()
        row = [c.uuid]
        row.extend(scores[column_label] for column_label in c_labels)
        return row

    # column labels file
    log.info("Writing CSV column header file: %s", header_csv_path)
    file_utils.safe_create_dir(os.path.dirname(header_csv_path))
    with open(header_csv_path, 'wb') as header_file:
        header_writer = csv.writer(header_file)
        header_writer.writerow(['uuid'] + c_labels)

    # CSV file
    log.info("Writing CSV data file: %s", data_csv_path)
    file_utils.safe_create_dir(os.path.dirname(data_csv_path))
    progress_state = [0] * 7
    with open(data_csv_path, 'wb') as data_file:
        data_writer = csv.writer(data_file)
        for classification in classification_iter:
            data_writer.writerow(make_row(classification))
            bin_utils.report_progress(log.info, progress_state, 1.0)

    # Final report
    progress_state[1] -= 1
    bin_utils.report_progress(log.info, progress_state, 0)

    log.info("Done")
Exemple #17
0
    plt.title("PR - HT Positive")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.legend(loc='best', fancybox=True, framealpha=0.5)
    plt.savefig(PLOT_PR_OUTPUT)

else:
    # Using the final trained classifier
    with open(CLASSIFIER_TRAINING_CONFIG_JSON) as f:
        classifier_config = json.load(f)

    log.info("Loading plugins")
    descriptor_index = MemoryDescriptorIndex(file_cache=DESCRIPTOR_INDEX_FILE_CACHE)
    #: :type: smqtk.algorithms.Classifier
    classifier = from_plugin_config(classifier_config['plugins']['classifier'],
                                    get_classifier_impls())
    c_factory = ClassificationElementFactory(MemoryClassificationElement, {})

    #: :type: dict[str, list[str]]
    phone2shas = json.load(open(PHONE_SHA1_JSON))
    #: :type: dict[str, float]
    phone2score = {}

    log.info("Classifying phone imagery descriptors")
    i = 0
    descriptor_index_shas = set(descriptor_index.iterkeys())
    for p in phone2shas:
        log.info('%s (%d / %d)', p, i + 1, len(phone2shas))
        # Not all source "images" have descriptors since some URLs returned
        # non-image files. Intersect phone sha's with what was actually
        # computed. Warn if this reduces descriptors for classification to zero.
def main():
    """
    Entry point: train a classifier, or validate one against labeled data.

    Reads the CLI arguments and configuration, instantiates the configured
    classifier, classification element factory and descriptor index, and
    loads ``(uuid, label)`` rows from the configured CSV file. When training
    is enabled, the classifier is trained on the labeled descriptors and the
    process exits. Otherwise the descriptors are classified, the confusion
    matrix is logged, and the confusion-matrix plot, UUID confusion-matrix
    JSON and PR/ROC curve plots are each written when their output path is
    configured.

    :raises ValueError: Training was requested but the configured classifier
        is not a ``SupervisedClassifier``.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        config['plugins']['classifier'],
        get_classifier_impls()
    )
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory']
    )
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of truth label to the DescriptorElement instances
    # described by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate through (UUID, label) pairs in the specified file """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch the descriptor for a (UUID, label) pair from the index """
        uuid, truth_label = r
        return truth_label, descriptor_index.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr, iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train the classifier when training is enabled. Only supervised
    # classifiers can be trained; anything else is a configuration error.
    if do_train:
        if isinstance(classifier, SupervisedClassifier):
            log.info("Training classifier model")
            classifier.train(tlabel2descriptors)
            exit(0)
        else:
            # The exception must actually be raised; merely constructing it
            # would let execution fall through to the classification step.
            raise ValueError("Configured classifier is not a "
                             "SupervisedClassifier type and does not support "
                             "training.")

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    for tlabel, descriptors in six.iteritems(tlabel2descriptors):
        tlabel2classifications[tlabel] = \
            set(classifier.classify_async(
                descriptors, classification_factory,
                use_multiprocessing=True,
                procs=config['parallelism']['classification_cores'],
                ri=1.0,
            ).values())
    log.info("Truth label counts:")
    for l in sorted(tlabel2classifications):
        log.info("  %s :: %d", l, len(tlabel2classifications[l]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # Confusion Matrix of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are UUID
        # predicted labels.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[collections.Hashable, set]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists so the structure is JSON serializable
            for plabel in uuid_cm[tlabel]:
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr,
                       plot_ci, plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc,
                        plot_ci, plot_ci_alpha)
Exemple #19
0
    def add_iqr_state_classifier(self):
        """
        Train a classifier based on the user-provided IQR state file bytes in
        a base64 encoding, matched with a descriptive label of that
        classifier's topic.

        Since all IQR session classifiers end up only having two result
        classes (positive and negative), the topic of the classifier is
        encoded in the descriptive label the user applies to the classifier.

        Below is an example call to this endpoint via the ``requests`` python
        module, showing how base64 data is sent (POST, matching the curl
        examples below)::

            import base64
            import requests
            data_bytes = "Load some content bytes here."
            requests.post('http://localhost:5000/iqr_classifier',
                          data={'bytes_b64': base64.b64encode(data_bytes),
                                'label': 'some_label'})

        With curl on the command line::

            $ curl -X POST localhost:5000/iqr_classifier \
                -d "label=some_label" \
                --data-urlencode "bytes_b64=$(base64 -w0 /path/to/file)"

            # If this fails, you may wish to encode the file separately and
            # use the file reference syntax instead:

            $ base64 -w0 /path/to/file > /path/to/file.b64
            $ curl -X POST localhost:5000/iqr_classifier -d label=some_label \
                --data-urlencode bytes_b64@/path/to/file.b64

        To lock this classifier and guard it against deletion, add
        "lock_label=true"::

            $ curl -X POST localhost:5000/iqr_classifier \
                -d "label=some_label" \
                -d "lock_label=true" \
                --data-urlencode "bytes_b64=$(base64 -w0 /path/to/file)"

        Form arguments:
            bytes_b64
                base64 encoding of the bytes of the IQR session state save
                file.
            label
                Descriptive label to apply to this classifier. This should not
                conflict with existing classifier labels.
            lock_label
                If 'true', disallow deletion of this label. If 'false', allow
                deletion of this label. Only has an effect if deletion is
                enabled for this service. (Default: 'false')

        Returns 201 on success. Returns 400 when a required argument is
        missing or empty, when ``lock_label`` cannot be parsed, when the
        base64 data is invalid, or when the label already exists in the
        classifier collection.

        """
        data_b64 = flask.request.values.get('bytes_b64', default=None)
        label = flask.request.values.get('label', default=None)
        lock_clfr_str = flask.request.values.get('lock_label',
                                                 default='false')

        # Missing and empty values are both rejected.
        if not data_b64:
            return make_response_json("No state base64 data provided.", 400)
        elif not label:
            return make_response_json("No descriptive label provided.", 400)
        try:
            # The flag is parsed as JSON and then coerced to a boolean.
            lock_clfr = bool(flask.json.loads(lock_clfr_str))
        except JSON_DECODE_EXCEPTION:
            return make_response_json("Invalid boolean value for"
                                      " 'lock_label'. Was given: '%s'"
                                      % lock_clfr_str,
                                      400)
        try:
            # Using urlsafe version because it handles both regular and urlsafe
            # alphabets.
            data_bytes = base64.urlsafe_b64decode(data_b64.encode('utf-8'))
        except (TypeError, binascii.Error) as ex:
            # Status code is passed to ``make_response_json`` like every other
            # return in this method, instead of wrapping its result in an
            # additional ``(..., 400)`` tuple.
            return make_response_json("Invalid base64 input: %s" % str(ex),
                                      400)

        # If the given label conflicts with one already in the collection,
        # fail.
        if label in self.classifier_collection.labels():
            return make_response_json(
                "Label already exists in classifier collection.", 400)

        # Create dummy IqrSession to extract pos/neg descriptors.
        iqrs = IqrSession()
        iqrs.set_state_bytes(data_bytes, self.descriptor_factory)
        pos = iqrs.positive_descriptors | iqrs.external_positive_descriptors
        neg = iqrs.negative_descriptors | iqrs.external_negative_descriptors
        del iqrs

        # Make a classifier instance from the stored config for IQR
        # session-based classifiers.
        #: :type: SupervisedClassifier
        classifier = smqtk.utils.plugin.from_plugin_config(
            self.iqr_state_classifier_config,
            get_classifier_impls(sub_interface=SupervisedClassifier)
        )
        classifier.train(class_examples={'positive': pos, 'negative': neg})

        try:
            self.classifier_collection.add_classifier(label, classifier)

            # If we're allowing deletions, get the lock flag from the form and
            # set it for this classifier
            if self.enable_classifier_removal and lock_clfr:
                self.immutable_labels.add(label)

        except ValueError as e:
            # A ValueError whose message mentions "JSON" is reported as a
            # malformed-JSON problem; any other ValueError is reported as a
            # duplicate label.
            if e.args[0].find('JSON') > -1:
                return make_response_json("Tried to parse malformed JSON in "
                                          "form argument.", 400)
            return make_response_json("Duplicate label ('%s') added during "
                                      "classifier training of provided IQR "
                                      "session state." % label, 400,
                                      label=label)

        return make_response_json("Finished training IQR-session-based "
                                  "classifier for label '%s'." % label,
                                  201,
                                  label=label)
def main():
    """
    Classify the descriptors named in a UUID list file and write the results
    as CSV.

    UUIDs are read one per line from ``args.uuids_list``, resolved to
    descriptors through the configured descriptor index and classified by the
    configured classifier, using two chained ``parallel.parallel_map`` stages.
    Column labels are written to ``args.csv_header`` and one row per
    classification (UUID followed by the value for each classifier label) is
    written to ``args.csv_data``.

    :raises ValueError: If the UUID list path is unset or is not a file, or
        if either output CSV path is ``None``.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    # Command line inputs
    uuid_list_path = args.uuids_list
    data_csv_path = args.csv_data
    header_csv_path = args.csv_header

    # Utility configuration values
    utility_cfg = config['utility']
    overwrite = utility_cfg['classify_overwrite']
    parallel_cfg = utility_cfg['parallel']
    use_mp = parallel_cfg['use_multiprocessing']
    extraction_cores = parallel_cfg['index_extraction_cores']
    classification_cores = parallel_cfg['classification_cores']

    # Input validation
    if not uuid_list_path:
        raise ValueError("No uuids_list_filepath specified.")
    if not os.path.isfile(uuid_list_path):
        raise ValueError("Given uuids_list_filepath did not point to a file.")
    if header_csv_path is None:
        raise ValueError("Need a path to save CSV header labels")
    if data_csv_path is None:
        raise ValueError("Need a path to save CSV data.")

    #
    # Plugin construction
    #
    log.info("Initializing descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'], get_descriptor_index_impls())

    log.info("Initializing classification factory")
    c_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])

    log.info("Initializing classifier")
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        config['plugins']['classifier'], get_classifier_impls())

    #
    # Processing pipeline pieces
    #
    def iter_uuids():
        """ Yield each line of the UUID list file, whitespace-stripped. """
        with open(uuid_list_path) as uuid_file:
            for line in uuid_file:
                yield line.strip()

    def descr_for_uuid(uuid):
        """
        :type uuid: collections.Hashable
        :rtype: smqtk.representation.DescriptorElement
        """
        return descriptor_index.get_descriptor(uuid)

    def classify_descr(d):
        """
        :type d: smqtk.representation.DescriptorElement
        :rtype: smqtk.representation.ClassificationElement
        """
        return classifier.classify(d, c_factory, overwrite)

    log.info("Initializing uuid-to-descriptor parallel map")
    #: :type: collections.Iterable[smqtk.representation.DescriptorElement]
    element_iter = parallel.parallel_map(
        descr_for_uuid, iter_uuids(),
        use_multiprocessing=use_mp,
        cores=extraction_cores,
        name="descr_for_uuid",
    )

    log.info("Initializing descriptor-to-classification parallel map")
    #: :type: collections.Iterable[smqtk.representation.ClassificationElement]
    classification_iter = parallel.parallel_map(
        classify_descr, element_iter,
        use_multiprocessing=use_mp,
        cores=classification_cores,
        name='classify_descr',
    )

    #
    # Output
    #
    c_labels = classifier.get_labels()

    def to_csv_row(c_elem):
        """
        Build the CSV row for one classification: UUID, then one value per
        classifier label in ``c_labels`` order.

        :type c_elem: smqtk.representation.ClassificationElement
        """
        label_values = c_elem.get_classification()
        row = [c_elem.uuid]
        row.extend(label_values[label] for label in c_labels)
        return row

    # Header file holding the column labels
    log.info("Writing CSV column header file: %s", header_csv_path)
    file_utils.safe_create_dir(os.path.dirname(header_csv_path))
    header_row = ['uuid']
    header_row.extend(str(cl) for cl in c_labels)
    with open(header_csv_path, 'wb') as header_file:
        csv.writer(header_file).writerow(header_row)

    # Data file, one row per classification result
    log.info("Writing CSV data file: %s", data_csv_path)
    file_utils.safe_create_dir(os.path.dirname(data_csv_path))
    progress_state = [0] * 7
    with open(data_csv_path, 'wb') as data_file:
        data_writer = csv.writer(data_file)
        for c_elem in classification_iter:
            data_writer.writerow(to_csv_row(c_elem))
            bin_utils.report_progress(log.info, progress_state, 1.0)

    # Final report
    progress_state[1] -= 1
    bin_utils.report_progress(log.info, progress_state, 0)

    log.info("Done")
def main():
    """
    Validate, or optionally train, the configured classifier using labeled
    descriptors listed in a CSV file.

    Rows of the CSV named by ``config['utility']['csv_filepath']`` are read as
    ``<UUID>,<label>`` pairs and the matching descriptors are fetched from the
    configured descriptor index. If ``config['utility']['train']`` is true the
    classifier is trained on those descriptors and the process exits.
    Otherwise the descriptors are classified and, depending on which output
    paths are configured, a confusion matrix plot, a UUID confusion matrix
    JSON file, and PR/ROC curve plots are produced.

    :raises ValueError: If training is requested but the configured
        classifier is not a ``SupervisedClassifier``.
    """
    description = """
    Utility for validating a given classifier implementation's model against
    some labeled testing data, outputting PR and ROC curve plots with
    area-under-curve score values.

    This utility can optionally be used train a supervised classifier model if
    the given classifier model configuration does not exist and a second CSV
    file listing labeled training data is provided. Training will be attempted
    if ``train`` is set to true. If training is performed, we exit after
    training completes. A ``SupervisedClassifier`` sub-classing implementation
    must be configured

    We expect the test and train CSV files in the column format:

        ...
        <UUID>,<label>
        ...

    The UUID is of the descriptor to which the label applies. The label may be
    any arbitrary string value, but all labels must be consistent in
    application.

    Some metrics presented assume the highest confidence class as the single
    predicted class for an element:

        - confusion matrix

    The output UUID confusion matrix is a JSON dictionary where the top-level
    keys are the true labels, and the inner dictionary is the mapping of
    predicted labels to the UUIDs of the classifications/descriptors that
    yielded the prediction. Again, this is based on the maximum probability
    label for a classification result (T=0.5).
    """
    # Parsed arguments are not needed beyond configuration loading.
    _args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        config['plugins']['classifier'],
        get_classifier_impls()
    )
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory']
    )
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of label to the DescriptorElement instances for that
    # described by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate through (UUID, label) pairs in the specified CSV file """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch descriptors from configured index """
        uuid, truth_label = r
        return truth_label, descriptor_index.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr, iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train classifier if training was enabled. Only ``SupervisedClassifier``
    # instances can be trained.
    if do_train:
        if not isinstance(classifier, SupervisedClassifier):
            # Previously this exception was constructed but never raised,
            # which let execution fall through into validation.
            raise ValueError("Configured classifier is not a "
                             "SupervisedClassifier type and does not support "
                             "training.")
        log.info("Training classifier model")
        classifier.train(tlabel2descriptors)
        # Equivalent to ``exit(0)`` without relying on the ``site`` helper.
        raise SystemExit(0)

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    # ``items()`` is available on both Python 2 and 3 dictionaries.
    for tlabel, descriptors in tlabel2descriptors.items():
        tlabel2classifications[tlabel] = \
            set(classifier.classify_async(
                descriptors, classification_factory,
                use_multiprocessing=True,
                procs=config['parallelism']['classification_cores'],
                ri=1.0,
            ).values())
    log.info("Truth label counts:")
    for tlabel in sorted(tlabel2classifications):
        log.info("  %s :: %d", tlabel, len(tlabel2classifications[tlabel]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # CM of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are UUID
        # predicted labels.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[str, set | list]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists (sets are not JSON serializable)
            for plabel in uuid_cm[tlabel]:
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr,
                       plot_ci, plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc,
                        plot_ci, plot_ci_alpha)
Exemple #22
0
    def classify(self):
        """
        Given a refined session ID and some number of descriptor UUIDs, create
        a classifier according to the current state and classify the given
        descriptors.

        This will fail if the session has not been given adjudications
        (refined) yet.

        URI Args:
            sid
                UUID of the session to utilize
            uuids
                JSON-encoded list of descriptor UUIDs to classify. Return
                list of results will be in the same order as this list.

        Response status codes:
            200
                Classification finished; the response carries parallel
                ``uuids`` and ``proba`` (positive class score) lists.
            400
                ``uuids`` is not valid JSON, ``sid`` or ``uuids`` is missing
                or empty, or the session has no positive or no negative
                descriptors.
            404
                A requested descriptor UUID is not in the descriptor index,
                or a ``KeyError`` was raised while using the session (reported
                as the session id not being found).

        """
        sid = flask.request.args.get('sid', None)
        uuids = flask.request.args.get('uuids', None)

        # ``json.loads(None)`` raises TypeError, not ValueError, so only
        # decode when the argument was actually supplied. A missing ``uuids``
        # then reaches the "No descriptor UUIDs provided" response below.
        if uuids is not None:
            try:
                uuids = json.loads(uuids)
            except ValueError:
                return make_response_json(
                    "Failed to decode uuids as json. Given '%s'" % uuids), 400

        if sid is None:
            return make_response_json("No session id (sid) provided"), 400
        if not uuids:
            return make_response_json(
                "No descriptor UUIDs provided",
                sid=sid,
            ), 400

        try:
            with self.controller.get_session(sid) as iqrs:
                if not iqrs.positive_descriptors:
                    return make_response_json(
                        "No positive labels in current session", sid=sid), 400
                if not iqrs.negative_descriptors:
                    return make_response_json(
                        "No negative labels in current session", sid=sid), 400

                # Get descriptor elements for classification
                try:
                    descriptors = list(
                        self.descriptor_index.get_many_descriptors(uuids))
                except KeyError as ex:
                    err_uuid = str(ex)
                    self._log.warn(traceback.format_exc())
                    return make_response_json(
                        "Descriptor UUID '%s' cannot be found in the "
                        "configured descriptor index." % err_uuid,
                        sid=sid,
                        uuid=err_uuid,
                    ), 404

                classifier = self.session_classifiers.get(sid, None)

                pos_label = "positive"
                neg_label = "negative"
                # The clean/dirty status is recorded after training so we
                # don't train a new classifier when we don't have to.
                if self.session_classifier_dirty[sid] or classifier is None:
                    self._log.debug("Training new classifier for current "
                                    "refine state")

                    #: :type: SupervisedClassifier
                    classifier = plugin.from_plugin_config(
                        self.classifier_config,
                        get_classifier_impls(
                            sub_interface=SupervisedClassifier))
                    classifier.train({
                        pos_label: iqrs.positive_descriptors,
                        neg_label: iqrs.negative_descriptors
                    })

                    self.session_classifiers[sid] = classifier
                    self.session_classifier_dirty[sid] = False

                classifications = classifier.classify_async(
                    descriptors,
                    self.classification_factory,
                    use_multiprocessing=True,
                    ri=1.0)

                # Format output to be parallel lists of UUIDs input and
                # positive class classification scores.
                o_uuids = []
                o_proba = []
                for d in descriptors:
                    o_uuids.append(d.uuid())
                    o_proba.append(classifications[d][pos_label])

                assert uuids == o_uuids, \
                    "Output UUID list is not congruent with INPUT list."

                return make_response_json(
                    "Finished classification",
                    sid=sid,
                    uuids=o_uuids,
                    proba=o_proba,
                ), 200

        except KeyError:
            # NOTE(review): this also catches any KeyError raised inside the
            # ``with`` body that is not handled above, not only an unknown
            # session id -- confirm that is intended.
            return make_response_json("session id '%s' not found" % sid,
                                      sid=sid), 404
Exemple #23
0
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.legend(loc='best', fancybox=True, framealpha=0.5)
    plt.savefig(PLOT_PR_OUTPUT)

else:
    # Using the final trained classifier
    with open(CLASSIFIER_TRAINING_CONFIG_JSON) as f:
        classifier_config = json.load(f)

    log.info("Loading plugins")
    descriptor_index = MemoryDescriptorIndex(
        file_cache=DESCRIPTOR_INDEX_FILE_CACHE)
    #: :type: smqtk.algorithms.Classifier
    classifier = from_plugin_config(classifier_config['plugins']['classifier'],
                                    get_classifier_impls())
    c_factory = ClassificationElementFactory(MemoryClassificationElement, {})

    #: :type: dict[str, list[str]]
    phone2shas = json.load(open(PHONE_SHA1_JSON))
    #: :type: dict[str, float]
    phone2score = {}

    log.info("Classifying phone imagery descriptors")
    i = 0
    descriptor_index_shas = set(descriptor_index.iterkeys())
    for p in phone2shas:
        log.info('%s (%d / %d)', p, i + 1, len(phone2shas))
        # Not all source "images" have descriptors since some URLs returned
        # non-image files. Intersect phone sha's with what was actually
        # computed. Warn if this reduces descriptors for classification to zero.
Exemple #24
0
                sid=sid,
                uuid=err_uuid,
            ), 404

        classifier = self.session_classifiers.get(sid, None)

        pos_label = "positive"
        neg_label = "negative"
        if self.session_classifier_dirty[sid] or classifier is None:
            self._log.debug("Training new classifier for current "
                            "refine state")

            #: :type: SupervisedClassifier
            classifier = plugin.from_plugin_config(
                self.classifier_config,
                get_classifier_impls(sub_interface=SupervisedClassifier))
            classifier.train({
                pos_label: iqrs.positive_descriptors,
                neg_label: iqrs.negative_descriptors
            })

            self.session_classifiers[sid] = classifier
            self.session_classifier_dirty[sid] = False

        classifications = classifier.classify_async(
            descriptors,
            self.classification_factory,
            use_multiprocessing=True,
            ri=1.0)

        # Format output to be parallel lists of UUIDs input and
def main():
    """
    Compute classifications for descriptors selected by a UUID list file and
    write them to CSV.

    UUIDs are read one per line from ``args.uuids_list``, resolved through the
    configured descriptor index and classified by the configured classifier
    via two chained ``parallel.parallel_map`` stages. The column labels are
    written to ``args.csv_header`` and one row per classification to
    ``args.csv_data``.

    :raises ValueError: If the UUID list path is unset or is not a file, or
        if either output CSV path is ``None``.
    """
    description = """
    Script for asynchronously computing classifications for DescriptorElements
    in a DescriptorIndex specified via a list of UUIDs. Results are output to a
    CSV file in the format:

        uuid, label1_confidence, label2_confidence, ...

    CSV columns labels are output to the given CSV header file path. Label
    columns will be in the order as reported by the classifier implementations
    ``get_labels`` method.

    Due to using an input file-list of UUIDs, we require that the UUIDs of
    indexed descriptors be strings, or equality comparable to the UUIDs' string
    representation.
    """

    args, config = bin_utils.utility_main_helper(
        default_config,
        description,
        extend_parser,
    )
    log = logging.getLogger(__name__)

    # - parallel_map UUIDs to load from the configured index
    # - classify iterated descriptors

    uuids_list_filepath = args.uuids_list
    output_csv_filepath = args.csv_data
    output_csv_header_filepath = args.csv_header
    classify_overwrite = config['utility']['classify_overwrite']

    p_use_multiprocessing = \
        config['utility']['parallel']['use_multiprocessing']
    p_index_extraction_cores = \
        config['utility']['parallel']['index_extraction_cores']
    p_classification_cores = \
        config['utility']['parallel']['classification_cores']

    if not uuids_list_filepath:
        raise ValueError("No uuids_list_filepath specified.")
    elif not os.path.isfile(uuids_list_filepath):
        raise ValueError("Given uuids_list_filepath did not point to a file.")
    if output_csv_header_filepath is None:
        raise ValueError("Need a path to save CSV header labels")
    if output_csv_filepath is None:
        raise ValueError("Need a path to save CSV data.")

    #
    # Initialize configured plugins
    #

    log.info("Initializing descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    log.info("Initializing classification factory")
    c_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory']
    )

    log.info("Initializing classifier")
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        config['plugins']['classifier'], get_classifier_impls()
    )

    #
    # Setup/Process
    #
    def iter_uuids():
        """ Yield each line of the UUID list file, whitespace-stripped. """
        with open(uuids_list_filepath) as f:
            for line in f:
                yield line.strip()

    def descr_for_uuid(uuid):
        """
        :type uuid: collections.Hashable
        :rtype: smqtk.representation.DescriptorElement
        """
        return descriptor_index.get_descriptor(uuid)

    def classify_descr(d):
        """
        :type d: smqtk.representation.DescriptorElement
        :rtype: smqtk.representation.ClassificationElement
        """
        return classifier.classify(d, c_factory, classify_overwrite)

    log.info("Initializing uuid-to-descriptor parallel map")
    #: :type: collections.Iterable[smqtk.representation.DescriptorElement]
    element_iter = parallel.parallel_map(
        descr_for_uuid, iter_uuids(),
        use_multiprocessing=p_use_multiprocessing,
        cores=p_index_extraction_cores,
        name="descr_for_uuid",
    )

    log.info("Initializing descriptor-to-classification parallel map")
    #: :type: collections.Iterable[smqtk.representation.ClassificationElement]
    classification_iter = parallel.parallel_map(
        classify_descr, element_iter,
        use_multiprocessing=p_use_multiprocessing,
        cores=p_classification_cores,
        name='classify_descr',
    )

    #
    # Write/Output files
    #

    c_labels = classifier.get_labels()

    def make_row(c):
        """
        Build the CSV row for one classification: UUID, then one value per
        classifier label in ``c_labels`` order.

        :type c: smqtk.representation.ClassificationElement
        """
        c_m = c.get_classification()
        return [c.uuid] + [c_m[label] for label in c_labels]

    # column labels file
    log.info("Writing CSV column header file: %s", output_csv_header_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_header_filepath))
    with open(output_csv_header_filepath, 'wb') as f_csv:
        w = csv.writer(f_csv)
        # Build the header from an explicit list so this works for any
        # iterable of labels, not only when ``get_labels`` returns a ``list``
        # (``['uuid'] + c_labels`` raises TypeError for e.g. a tuple).
        w.writerow(['uuid'] + [str(cl) for cl in c_labels])

    # CSV file
    log.info("Writing CSV data file: %s", output_csv_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_filepath))
    r_state = [0] * 7
    with open(output_csv_filepath, 'wb') as f_csv:
        w = csv.writer(f_csv)
        for c in classification_iter:
            w.writerow(make_row(c))
            bin_utils.report_progress(log.info, r_state, 1.0)

    # Final report
    r_state[1] -= 1
    bin_utils.report_progress(log.info, r_state, 0)

    log.info("Done")
Exemple #26
0
def get_default_config():
    """
    Build the default configuration dictionary.

    :return: Dictionary with a single "classifier" entry holding the result
        of ``make_config`` applied to ``get_classifier_impls()``.
    :rtype: dict
    """
    classifier_config = make_config(get_classifier_impls())
    return dict(classifier=classifier_config)
Exemple #27
0
def classify_files(config, label, file_globs):
    """
    Print the paths of input files whose highest-confidence classification
    label equals ``label``.

    When ``label`` is ``None`` or is not one of the classifier's labels, the
    available labels are logged and nothing is classified.

    :param config: Configuration dictionary providing the "classifier",
        "descriptor_factory", "descriptor_generator" and
        "classification_factory" sections read below.
    :param label: Classification label to filter on, or ``None`` to only list
        the available labels.
    :param file_globs: Iterable of file paths and/or glob patterns. Entries
        that are existing files are used directly; anything else is expanded
        with ``glob.iglob``.

    :raises RuntimeError: If no files result from ``file_globs``.
    """
    log = logging.getLogger(__name__)

    #: :type: smqtk.algorithms.Classifier
    classifier = \
        plugin.from_plugin_config(config['classifier'],
                                  get_classifier_impls())

    def log_available_labels():
        log.info("Available classifier labels:")
        for available in classifier.get_labels():
            log.info("- %s", available)

    if label is None:
        log_available_labels()
        return
    elif label not in classifier.get_labels():
        log.error(
            "Invalid classification label provided to compute and filter "
            "on: '%s'", label)
        log_available_labels()
        return

    log.info("Collecting files from globs")
    #: :type: list[DataFileElement]
    data_elements = []
    uuid2filepath = {}
    for g in file_globs:
        if os.path.isfile(g):
            filepaths = [g]
        else:
            log.debug("expanding glob: %s", g)
            filepaths = glob.iglob(g)
        # Single registration path for both direct files and glob matches.
        for fp in filepaths:
            d = DataFileElement(fp)
            data_elements.append(d)
            uuid2filepath[d.uuid()] = fp
    if not data_elements:
        raise RuntimeError("No files provided for classification.")

    log.info("Computing descriptors")
    descriptor_factory = \
        DescriptorElementFactory.from_config(config['descriptor_factory'])
    #: :type: smqtk.algorithms.DescriptorGenerator
    descriptor_generator = \
        plugin.from_plugin_config(config['descriptor_generator'],
                                  get_descriptor_generator_impls())
    descr_map = descriptor_generator\
        .compute_descriptor_async(data_elements, descriptor_factory)

    log.info("Classifying descriptors")
    classification_factory = ClassificationElementFactory \
        .from_config(config['classification_factory'])
    classification_map = classifier\
        .classify_async(descr_map.values(), classification_factory)

    log.info("Printing input file paths that classified as the given label.")
    # Map of UUID to classification element. ``values()`` is used instead of
    # the Python-2-only ``itervalues()``.
    uuid2c = {c.uuid: c for c in classification_map.values()}
    for data in data_elements:
        data_uuid = data.uuid()
        if uuid2c[data_uuid].max_label() == label:
            # Single-argument call form prints identically under the Python 2
            # print statement and the Python 3 print function.
            print(uuid2filepath[data_uuid])
Exemple #28
0
    def classify(self):
        """
        Given a refined session ID and some number of descriptor UUIDs, create
        a classifier according to the current state and classify the given
        descriptors.

        This will fail if the session has not been given adjudications
        (refined) yet.

        URI Args:
            sid
                UUID of the session to utilize
            uuids
                JSON-encoded list of descriptor UUIDs to classify. Return
                list of results will be in the same order as this list.

        Response status codes:
            200
                Classification finished; the response carries parallel
                ``uuids`` and ``proba`` (positive class score) lists.
            400
                ``uuids`` is not valid JSON, ``sid`` or ``uuids`` is missing
                or empty, or the session has no positive or no negative
                descriptors.
            404
                A requested descriptor UUID is not in the descriptor index,
                or a ``KeyError`` was raised while using the session (reported
                as the session id not being found).

        """
        sid = flask.request.args.get('sid', None)
        uuids = flask.request.args.get('uuids', None)

        # Only decode when the argument was supplied: ``json.loads(None)``
        # raises TypeError, which the ValueError handler would not catch. A
        # missing ``uuids`` is reported by the emptiness check further down.
        if uuids is not None:
            try:
                uuids = json.loads(uuids)
            except ValueError:
                return make_response_json(
                    "Failed to decode uuids as json. Given '%s'"
                    % uuids
                ), 400

        if sid is None:
            return make_response_json("No session id (sid) provided"), 400
        if not uuids:
            return make_response_json(
                "No descriptor UUIDs provided",
                sid=sid,
            ), 400

        try:
            with self.controller.get_session(sid) as iqrs:
                if not iqrs.positive_descriptors:
                    return make_response_json(
                        "No positive labels in current session",
                        sid=sid
                    ), 400
                if not iqrs.negative_descriptors:
                    return make_response_json(
                        "No negative labels in current session",
                        sid=sid
                    ), 400

                # Get descriptor elements for classification
                try:
                    descriptors = list(self.descriptor_index
                                           .get_many_descriptors(uuids))
                except KeyError as ex:
                    err_uuid = str(ex)
                    self._log.warn(traceback.format_exc())
                    return make_response_json(
                        "Descriptor UUID '%s' cannot be found in the "
                        "configured descriptor index."
                        % err_uuid,
                        sid=sid,
                        uuid=err_uuid,
                    ), 404

                classifier = self.session_classifiers.get(sid, None)

                pos_label = "positive"
                neg_label = "negative"
                # The clean/dirty status is recorded after training so we
                # don't train a new classifier when we don't have to.
                if self.session_classifier_dirty[sid] or classifier is None:
                    self._log.debug("Training new classifier for current "
                                    "refine state")

                    #: :type: SupervisedClassifier
                    classifier = plugin.from_plugin_config(
                        self.classifier_config,
                        get_classifier_impls(sub_interface=SupervisedClassifier)
                    )
                    classifier.train(
                        {pos_label: iqrs.positive_descriptors,
                         neg_label: iqrs.negative_descriptors}
                    )

                    self.session_classifiers[sid] = classifier
                    self.session_classifier_dirty[sid] = False

                classifications = classifier.classify_async(
                    descriptors, self.classification_factory,
                    use_multiprocessing=True, ri=1.0
                )

                # Format output to be parallel lists of UUIDs input and
                # positive class classification scores.
                o_uuids = []
                o_proba = []
                for d in descriptors:
                    o_uuids.append(d.uuid())
                    o_proba.append(classifications[d][pos_label])

                assert uuids == o_uuids, \
                    "Output UUID list is not congruent with INPUT list."

                return make_response_json(
                    "Finished classification",
                    sid=sid,
                    uuids=o_uuids,
                    proba=o_proba,
                ), 200

        except KeyError:
            # NOTE(review): this also catches any KeyError raised inside the
            # ``with`` body that is not handled above, not only an unknown
            # session id -- confirm that is intended.
            return make_response_json("session id '%s' not found" % sid,
                                      sid=sid), 404
def get_supervised_classifier_impls():
    """
    Discover classifier implementations restricted to the
    ``SupervisedClassifier`` sub-interface.

    :return: The result of ``get_classifier_impls`` when called with
        ``sub_interface=SupervisedClassifier``.
    """
    supervised_impls = get_classifier_impls(sub_interface=SupervisedClassifier)
    return supervised_impls
Exemple #30
0
def main():
    """
    Command-line entry point: validate (or optionally train) a configured
    classifier against labeled descriptors listed in a CSV file.

    Configuration is loaded via ``bin_utils.utility_main_helper`` using
    ``default_config``. Depending on the ``utility`` section of that
    configuration this logs a confusion matrix and optionally writes a
    confusion matrix plot, a UUID confusion matrix JSON file, and PR / ROC
    curve plots.

    :raises ValueError: Training was requested (``train`` is true) but the
        configured classifier is not a ``SupervisedClassifier``.
    :raises SystemExit: With code 0 after training completes; no validation
        is performed in a training run.
    """
    description = """
    Utility for validating a given classifier implementation's model against
    some labeled testing data, outputting PR and ROC curve plots with
    area-under-curve score values.

    This utility can optionally be used train a supervised classifier model if
    the given classifier model configuration does not exist and a second CSV
    file listing labeled training data is provided. Training will be attempted
    if ``train`` is set to true. If training is performed, we exit after
    training completes. A ``SupervisedClassifier`` sub-classing implementation
    must be configured

    We expect the test and train CSV files in the column format:

        ...
        <UUID>,<label>
        ...

    The UUID is of the descriptor to which the label applies. The label may be
    any arbitrary string value, but all labels must be consistent in
    application.

    Some metrics presented assume the highest confidence class as the single
    predicted class for an element:

        - confusion matrix

    The output UUID confusion matrix is a JSON dictionary where the top-level
    keys are the true labels, and the inner dictionary is the mapping of
    predicted labels to the UUIDs of the classifications/descriptors that
    yielded the prediction. Again, this is based on the maximum probability
    label for a classification result (T=0.5).
    """
    args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(config['plugins']['classifier'],
                                           get_classifier_impls())
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'], get_descriptor_index_impls())

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of truth label to the DescriptorElement instances
    # described by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate (UUID, label) pairs from the configured CSV file """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch the descriptor for a (UUID, label) pair from the index """
        uuid, truth_label = r
        return truth_label, descriptor_index.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr,
        iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train the classifier if training was enabled. Only
    # ``SupervisedClassifier`` instances can be trained.
    if do_train:
        if not isinstance(classifier, SupervisedClassifier):
            # Previously this exception was constructed but never raised, so
            # a misconfigured training run silently fell through to
            # validation instead of reporting the problem.
            raise ValueError("Configured classifier is not a "
                             "SupervisedClassifier type and does not support "
                             "training.")
        log.info("Training classifier model")
        classifier.train(tlabel2descriptors)
        # Training runs stop here. Raising SystemExit directly is equivalent
        # to ``exit(0)`` without relying on the ``site``-provided helper.
        raise SystemExit(0)

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    # ``items()`` exists on both Python 2 and 3; ``iteritems()`` is Python 2
    # only.
    for tlabel, descriptors in tlabel2descriptors.items():
        tlabel2classifications[tlabel] = \
            set(classifier.classify_async(
                descriptors, classification_factory,
                use_multiprocessing=True,
                procs=config['parallelism']['classification_cores'],
                ri=1.0,
            ).values())
    log.info("Truth label counts:")
    for label in sorted(tlabel2classifications):
        log.info("  %s :: %d", label, len(tlabel2classifications[label]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # CM of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are
        # predicted labels mapping to UUIDs.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[str, set | list]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists (sets are not JSON serializable)
            for plabel in uuid_cm[tlabel]:
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr, plot_ci,
                       plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc, plot_ci,
                        plot_ci_alpha)