Example #1
0
 def from_config(cls, config_dict, merge_default=True):
     """
     Construct an instance from a configuration dictionary.

     The ``classifier_inst`` entry of the configuration (an empty dictionary
     when absent) is first resolved through ``from_config_dict`` against the
     available ``SupervisedClassifier`` implementations, and the result
     replaces that entry before delegating to the parent ``from_config``.

     :param config_dict: Configuration dictionary. It is not modified; a
         shallow copy is made before the ``classifier_inst`` entry is
         replaced.
     :param merge_default: Forwarded unchanged to the parent ``from_config``.

     :return: Whatever the parent class ``from_config`` returns.
     """
     # Shallow copy so that we do not write to the caller's dictionary.
     local_config = dict(config_dict)
     classifier_config = local_config.get('classifier_inst', {})
     local_config['classifier_inst'] = from_config_dict(
         classifier_config, SupervisedClassifier.get_impls()
     )
     parent = super(SupervisedClassifierRelevancyIndex, cls)
     return parent.from_config(local_config, merge_default=merge_default)
def default_config():
    """
    Build the default configuration dictionary for this utility.

    The result has three sections:

    ``plugins``
        Default plugin configurations for the classifier, the classification
        element factory and the descriptor set.
    ``utility``
        Scalar options: training switch, input CSV file path, optional output
        file paths for plots / UUID confusion matrix, and confidence interval
        plotting options.
    ``parallelism``
        Core-count settings. ``classification_cores`` is deprecated.

    :return: New configuration dictionary.
    """
    plugin_defaults = {
        'classifier':
            make_default_config(SupervisedClassifier.get_impls()),
        'classification_factory':
            ClassificationElementFactory.get_default_config(),
        'descriptor_set':
            make_default_config(DescriptorSet.get_impls()),
    }

    utility_defaults = {
        'train': False,
        'csv_filepath': 'CHAMGEME :: PATH :: a csv file',
        'output_plot_pr': None,
        'output_plot_roc': None,
        'output_plot_confusion_matrix': None,
        'output_uuid_confusion_matrix': None,
        'curve_confidence_interval': False,
        'curve_confidence_interval_alpha': 0.4,
    }

    parallelism_defaults = {
        "descriptor_fetch_cores": 4,
        # DEPRECATED
        "classification_cores": None,
    }

    return {
        'plugins': plugin_defaults,
        'utility': utility_defaults,
        "parallelism": parallelism_defaults,
    }
Example #3
0
    def get_default_config(cls):
        """
        Generate the default configuration for this service.

        Starts from the parent class default configuration and adds the
        service-specific entries, keyed by this class's ``CONFIG_*``
        constants.

        :return: Default configuration dictionary.
        """
        config = super(SmqtkClassifierService, cls).get_default_config()

        # Entries are inserted in the order listed here.
        config.update({
            cls.CONFIG_ENABLE_CLASSIFIER_REMOVAL: False,
            # Static classifier configurations
            cls.CONFIG_CLASSIFIER_COLLECTION:
                ClassifierCollection.get_default_config(),
            # Classification element factory for new classification results.
            cls.CONFIG_CLASSIFICATION_FACTORY:
                ClassificationElementFactory.get_default_config(),
            # Descriptor generator for new content
            cls.CONFIG_DESCRIPTOR_GENERATOR:
                make_default_config(DescriptorGenerator.get_impls()),
            # Descriptor factory for new content descriptors
            cls.CONFIG_DESCRIPTOR_FACTORY:
                DescriptorElementFactory.get_default_config(),
            # Optional Descriptor set for "included" descriptors
            # referenceable by UID.
            cls.CONFIG_DESCRIPTOR_SET:
                make_default_config(DescriptorSet.get_impls()),
            # from-IQR-state *supervised* classifier configuration
            cls.CONFIG_IQR_CLASSIFIER:
                make_default_config(SupervisedClassifier.get_impls()),
            cls.CONFIG_IMMUTABLE_LABELS: [],
        })

        return config
Example #4
0
def train_classifier_iqr(config, iqr_state_fp):
    """
    Train a supervised classifier using the positive and negative descriptors
    recorded in an IQR session state file.

    :param config: Configuration dictionary whose ``classifier`` entry is
        resolved via ``from_config_dict`` against the available
        ``SupervisedClassifier`` implementations.
    :param iqr_state_fp: Path to the IQR session state file. It is read in
        binary mode and leading/trailing whitespace bytes are stripped.
    """
    #: :type: smqtk.algorithms.SupervisedClassifier
    classifier = from_config_dict(
        config['classifier'], SupervisedClassifier.get_impls()
    )

    # Read the serialized session state from disk.
    with open(iqr_state_fp, 'rb') as state_file:
        raw_state = state_file.read().strip()

    # Restore the state into a fresh IqrSession instance.
    element_factory = DescriptorElementFactory(DescriptorMemoryElement, {})
    session = IqrSession()
    session.set_state_bytes(raw_state, element_factory)

    # Each training class is the union of the session's internal and external
    # descriptor sets for that adjudication.
    examples = {
        'positive': (session.positive_descriptors
                     | session.external_positive_descriptors),
        'negative': (session.negative_descriptors
                     | session.external_negative_descriptors),
    }
    classifier.train(class_examples=examples)
Example #5
0
 def get_default_config(cls):
     """
     Generate the default configuration for this class.

     Extends the parent class default configuration with a
     ``classifier_inst`` entry holding the default plugin configuration for
     the available ``SupervisedClassifier`` implementations.

     :return: Default configuration dictionary.
     """
     defaults = super(
         SupervisedClassifierRelevancyIndex, cls
     ).get_default_config()
     classifier_impls = SupervisedClassifier.get_impls()
     defaults['classifier_inst'] = make_default_config(classifier_impls)
     return defaults
Example #6
0
def get_default_config():
    """
    Build the default configuration dictionary.

    :return: Dictionary with a single ``classifier`` entry holding the
        default plugin configuration for the available
        ``SupervisedClassifier`` implementations.
    """
    classifier_impls = SupervisedClassifier.get_impls()
    classifier_defaults = make_default_config(classifier_impls)
    return {"classifier": classifier_defaults}
def main():
    """
    Command-line entry point: optionally train, otherwise validate, a
    supervised classifier against truth-labeled descriptors.

    Steps performed:

    1. Parse CLI arguments and load the configuration.
    2. Instantiate the classifier, classification element factory and
       descriptor set from the ``plugins`` configuration section.
    3. Read ``(uuid, label)`` rows from the configured CSV file and fetch the
       matching descriptors from the descriptor set in parallel.
    4. If ``utility.train`` is set, train the classifier on the labeled
       descriptors and exit with status 0 (no validation is performed).
    5. Otherwise classify all descriptors, log per-label counts and the
       confusion matrix, and write whichever optional outputs are configured
       (confusion matrix plot, UUID confusion matrix JSON, PR and ROC
       curves).

    :raises SystemExit: With code 0 after training when ``utility.train`` is
        enabled.
    """
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    # Deprecations
    if (config.get('parallelism', {}).get('classification_cores', None)
            is not None):
        warnings.warn(
            "Usage of 'classification_cores' is deprecated. "
            "Classifier parallelism is now defined on a "
            "per-implementation basis. See classifier "
            "implementation parameterization.",
            category=DeprecationWarning)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = from_config_dict(config['plugins']['classifier'],
                                  SupervisedClassifier.get_impls())
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])
    #: :type: smqtk.representation.DescriptorSet
    descriptor_set = from_config_dict(config['plugins']['descriptor_set'],
                                      DescriptorSet.get_impls())

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of truth label to the DescriptorElement instances
    # described by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate through (UUID, label) pairs in the specified file """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch the descriptor for a (UUID, label) pair from the index """
        uuid, truth_label = r
        return truth_label, descriptor_set.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr,
        iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train the classifier when training was enabled, then stop: validation
    # is not performed in the same invocation.
    if do_train:
        log.info("Training supervised classifier model")
        classifier.train(tlabel2descriptors)
        # Raise SystemExit directly rather than calling the ``exit`` helper
        # injected by the ``site`` module, which is intended for interactive
        # use and is not guaranteed to be available.
        raise SystemExit(0)

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    for tlabel, descriptors in six.iteritems(tlabel2descriptors):
        tlabel2classifications[tlabel] = \
            set(classifier.classify_elements(descriptors,
                                             classification_factory))
    log.info("Truth label counts:")
    for label in sorted(tlabel2classifications):
        log.info("  %s :: %d", label, len(tlabel2classifications[label]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # Confusion Matrix of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are UUID
        # predicted labels.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[collections.Hashable, set]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists for JSON output.
            for plabel in uuid_cm[tlabel]:
                # noinspection PyTypeChecker
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr, plot_ci,
                       plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc, plot_ci,
                        plot_ci_alpha)
Example #8
0
    def add_iqr_state_classifier(self):
        """
        Train a classifier based on the user-provided IQR state file bytes in
        a base64 encoding, matched with a descriptive label of that
        classifier's topic.

        Since all IQR session classifiers end up only having two result
        classes (positive and negative), the topic of the classifier is
        encoded in the descriptive label the user applies to the classifier.

        Below is an example call to this endpoint via the ``requests`` python
        module, showing how base64 data is sent::

            import base64
            import requests
            data_bytes = "Load some content bytes here."
            requests.post('http://localhost:5000/iqr_classifier',
                          data={'bytes_b64': base64.b64encode(data_bytes),
                                'label': 'some_label'})

        With curl on the command line::

            $ curl -X POST localhost:5000/iqr_classifier \
                -d "label=some_label" \
                --data-urlencode "bytes_b64=$(base64 -w0 /path/to/file)"

            # If this fails, you may wish to encode the file separately and
            # use the file reference syntax instead:

            $ base64 -w0 /path/to/file > /path/to/file.b64
            $ curl -X POST localhost:5000/iqr_classifier -d label=some_label \
                --data-urlencode bytes_b64@/path/to/file.b64

        To lock this classifier and guard it against deletion, add
        "lock_label=true"::

            $ curl -X POST localhost:5000/iqr_classifier \
                -d "label=some_label" \
                -d "lock_label=true" \
                --data-urlencode "bytes_b64=$(base64 -w0 /path/to/file)"

        Form arguments:
            bytes_b64
                base64 encoding of the bytes of the IQR session state save
                file.
            label
                Descriptive label to apply to this classifier. This should not
                conflict with existing classifier labels.
            lock_label
                If 'true', disallow deletion of this label. If 'false', allow
                deletion of this label. Only has an effect if deletion is
                enabled for this service. (Default: 'false')

        Returns 201 on success. Returns 400 when a required argument is
        missing, ``lock_label`` is not a valid JSON value, the base64 data
        cannot be decoded, or the label already exists in the classifier
        collection.

        """
        data_b64 = flask.request.values.get('bytes_b64', default=None)
        label = flask.request.values.get('label', default=None)
        lock_clfr_str = flask.request.values.get('lock_label', default='false')

        # Both the state data and the label are required and must be
        # non-empty.
        if not data_b64:
            return make_response_json("No state base64 data provided.", 400)
        if not label:
            return make_response_json("No descriptive label provided.", 400)

        try:
            lock_clfr = bool(flask.json.loads(lock_clfr_str))
        except JSON_DECODE_EXCEPTION:
            return make_response_json(
                "Invalid boolean value for"
                " 'lock_label'. Was given: '%s'" % lock_clfr_str, 400)

        try:
            # Using urlsafe version because it handles both regular and urlsafe
            # alphabets.
            data_bytes = base64.urlsafe_b64decode(data_b64.encode('utf-8'))
        except (TypeError, binascii.Error) as ex:
            # Pass the status code through ``make_response_json`` like every
            # other return in this method.
            return make_response_json("Invalid base64 input: %s" % str(ex),
                                      400)

        # If the given label conflicts with one already in the collection,
        # fail.
        if label in self.classifier_collection.labels():
            return make_response_json(
                "Label already exists in classifier collection.", 400)

        # Create dummy IqrSession to extract pos/neg descriptors.
        iqrs = IqrSession()
        iqrs.set_state_bytes(data_bytes, self.descriptor_factory)
        pos = iqrs.positive_descriptors | iqrs.external_positive_descriptors
        neg = iqrs.negative_descriptors | iqrs.external_negative_descriptors
        del iqrs

        # Make a classifier instance from the stored config for IQR
        # session-based classifiers.
        #: :type: SupervisedClassifier
        classifier = from_config_dict(self.iqr_state_classifier_config,
                                      SupervisedClassifier.get_impls())
        classifier.train(class_examples={'positive': pos, 'negative': neg})

        try:
            self.classifier_collection.add_classifier(label, classifier)

            # Only record the lock when removal is enabled for this service
            # and the request asked for the label to be locked.
            if self.enable_classifier_removal and lock_clfr:
                self.immutable_labels.add(label)

        except ValueError as e:
            # Guard the message inspection so that an exception without
            # arguments, or with a non-string first argument, cannot raise
            # from inside this handler.
            if e.args and 'JSON' in str(e.args[0]):
                return make_response_json(
                    "Tried to parse malformed JSON in "
                    "form argument.", 400)
            return make_response_json("Duplicate label ('%s') added during "
                                      "classifier training of provided IQR "
                                      "session state." % label,
                                      400,
                                      label=label)

        return make_response_json("Finished training IQR-session-based "
                                  "classifier for label '%s'." % label,
                                  201,
                                  label=label)