コード例 #1
0
def train_classifier_iqr(config, iqr_state_fp):
    """
    Train a configured supervised classifier on the examples stored in a
    saved IQR session state file.

    :param config: Configuration dictionary. Its ``'classifier'`` entry is
        handed to ``from_config_dict`` to construct the classifier.
    :param iqr_state_fp: Path to the file containing the IQR session state
        bytes.

    """
    # Build the classifier first so a bad configuration fails before any
    # file I/O is attempted.
    #: :type: smqtk.algorithms.SupervisedClassifier
    classifier = from_config_dict(
        config['classifier'],
        SupervisedClassifier.get_impls(),
    )

    # Read the raw session state; leading/trailing whitespace bytes are
    # dropped from the file content before it is handed to the session.
    with open(iqr_state_fp, 'rb') as state_file:
        raw_state = state_file.read().strip()

    # Restore the state into a fresh session, backing descriptors with
    # in-memory elements.
    element_factory = DescriptorElementFactory(DescriptorMemoryElement, {})
    session = IqrSession()
    session.set_state_bytes(raw_state, element_factory)

    # Each training class is the union of the session's internal and
    # external example sets.
    training_examples = {
        'positive': (session.positive_descriptors
                     | session.external_positive_descriptors),
        'negative': (session.negative_descriptors
                     | session.external_negative_descriptors),
    }
    classifier.train(class_examples=training_examples)
コード例 #2
0
ファイル: iqrTrainClassifier.py プロジェクト: Kitware/SMQTK
def train_classifier_iqr(config, iqr_state_fp):
    """
    Train a supervised classifier using the positive and negative examples
    recorded in an IQR session state file.

    :param config: Configuration dictionary. The ``'classifier'`` entry is
        passed to ``from_plugin_config`` to instantiate the classifier.
    :param iqr_state_fp: Path to the file holding the serialized IQR session
        state.

    """
    # Only supervised classifier implementations are candidates here.
    candidate_impls = get_classifier_impls(sub_interface=SupervisedClassifier)
    #: :type: smqtk.algorithms.SupervisedClassifier
    classifier = from_plugin_config(config['classifier'], candidate_impls)

    # Pull the serialized state off disk, trimming surrounding whitespace
    # bytes from the file content.
    with open(iqr_state_fp, 'rb') as state_file:
        serialized_state = state_file.read().strip()

    # Populate an empty session from the state, using in-memory descriptor
    # elements.
    memory_factory = DescriptorElementFactory(DescriptorMemoryElement, {})
    session = IqrSession()
    session.set_state_bytes(serialized_state, memory_factory)

    # Training examples for each class combine the session's internal and
    # external descriptor sets.
    positives = (session.positive_descriptors
                 | session.external_positive_descriptors)
    negatives = (session.negative_descriptors
                 | session.external_negative_descriptors)
    classifier.train(
        class_examples={'positive': positives, 'negative': negatives}
    )
コード例 #3
0
    def add_iqr_state_classifier(self):
        r"""
        Train a classifier based on the user-provided IQR state file bytes in
        a base64 encoding, matched with a descriptive label of that
        classifier's topic.

        Since all IQR session classifiers end up only having two result
        classes (positive and negative), the topic of the classifier is
        encoded in the descriptive label the user applies to the classifier.

        Below is an example call to this endpoint via the ``requests`` python
        module, showing how base64 data is sent::

            import base64
            import requests
            data_bytes = b"Load some content bytes here."
            requests.post('http://localhost:5000/iqr_classifier',
                          data={'bytes_b64': base64.b64encode(data_bytes),
                                'label': 'some_label'})

        With curl on the command line::

            $ curl -X POST localhost:5000/iqr_classifier \
                -d "label=some_label" \
                --data-urlencode "bytes_b64=$(base64 -w0 /path/to/file)"

            # If this fails, you may wish to encode the file separately and
            # use the file reference syntax instead:

            $ base64 -w0 /path/to/file > /path/to/file.b64
            $ curl -X POST localhost:5000/iqr_classifier -d label=some_label \
                --data-urlencode bytes_b64@/path/to/file.b64

        To lock this classifier and guard it against deletion, add
        "lock_label=true"::

            $ curl -X POST localhost:5000/iqr_classifier \
                -d "label=some_label" \
                -d "lock_label=true" \
                --data-urlencode "bytes_b64=$(base64 -w0 /path/to/file)"

        Form arguments:
            bytes_b64
                base64 encoding of the bytes of the IQR session state save
                file.
            label
                Descriptive label to apply to this classifier. This should not
                conflict with existing classifier labels.
            lock_label
                If 'true', disallow deletion of this label. If 'false', allow
                deletion of this label. Only has an effect if deletion is
                enabled for this service. (Default: 'false')

        Returns 201 on success.

        Returns 400 when:
            - ``bytes_b64`` or ``label`` is missing or empty,
            - ``lock_label`` is not parsable as JSON,
            - ``bytes_b64`` is not valid base64,
            - ``label`` already exists in the classifier collection, or
            - adding the trained classifier to the collection raises a
              ``ValueError``.

        """
        data_b64 = flask.request.values.get('bytes_b64', default=None)
        label = flask.request.values.get('label', default=None)
        lock_clfr_str = flask.request.values.get('lock_label', default='false')

        # Both None and the empty string count as "not provided".
        if not data_b64:
            return make_response_json("No state base64 data provided.", 400)
        if not label:
            return make_response_json("No descriptive label provided.", 400)

        try:
            lock_clfr = bool(flask.json.loads(lock_clfr_str))
        except JSON_DECODE_EXCEPTION:
            return make_response_json(
                "Invalid boolean value for"
                " 'lock_label'. Was given: '%s'" % lock_clfr_str, 400)

        try:
            # Using urlsafe version because it handles both regular and urlsafe
            # alphabets.
            data_bytes = base64.urlsafe_b64decode(data_b64.encode('utf-8'))
        except (TypeError, binascii.Error) as ex:
            # The status code must be passed *to* make_response_json, like
            # every other error return in this method, rather than being
            # tacked on after the call.
            return make_response_json(
                "Invalid base64 input: %s" % str(ex), 400)

        # If the given label conflicts with one already in the collection,
        # fail.
        if label in self.classifier_collection.labels():
            return make_response_json(
                "Label already exists in classifier collection.", 400)

        # Create dummy IqrSession to extract pos/neg descriptors.
        iqrs = IqrSession()
        iqrs.set_state_bytes(data_bytes, self.descriptor_factory)
        pos = iqrs.positive_descriptors | iqrs.external_positive_descriptors
        neg = iqrs.negative_descriptors | iqrs.external_negative_descriptors
        del iqrs

        # Make a classifier instance from the stored config for IQR
        # session-based classifiers.
        #: :type: SupervisedClassifier
        classifier = from_config_dict(self.iqr_state_classifier_config,
                                      SupervisedClassifier.get_impls())
        classifier.train(class_examples={'positive': pos, 'negative': neg})

        try:
            self.classifier_collection.add_classifier(label, classifier)

            # If we're allowing deletions, get the lock flag from the form and
            # set it for this classifier
            if self.enable_classifier_removal and lock_clfr:
                self.immutable_labels.add(label)

        except ValueError as e:
            # Inspect the rendered message rather than ``e.args[0]`` so an
            # exception with no (or non-string) args cannot raise from inside
            # this handler.
            if 'JSON' in str(e):
                return make_response_json(
                    "Tried to parse malformed JSON in "
                    "form argument.", 400)
            return make_response_json("Duplicate label ('%s') added during "
                                      "classifier training of provided IQR "
                                      "session state." % label,
                                      400,
                                      label=label)

        return make_response_json("Finished training IQR-session-based "
                                  "classifier for label '%s'." % label,
                                  201,
                                  label=label)
コード例 #4
0
ファイル: classifier_server.py プロジェクト: Kitware/SMQTK
    def add_iqr_state_classifier(self):
        r"""
        Train a classifier based on the user-provided IQR state file bytes in
        a base64 encoding, matched with a descriptive label of that
        classifier's topic.

        Since all IQR session classifiers end up only having two result
        classes (positive and negative), the topic of the classifier is
        encoded in the descriptive label the user applies to the classifier.

        Below is an example call to this endpoint via the ``requests`` python
        module, showing how base64 data is sent::

            import base64
            import requests
            data_bytes = b"Load some content bytes here."
            requests.post('http://localhost:5000/iqr_classifier',
                          data={'bytes_b64': base64.b64encode(data_bytes),
                                'label': 'some_label'})

        With curl on the command line::

            $ curl -X POST localhost:5000/iqr_classifier \
                -d "label=some_label" \
                --data-urlencode "bytes_b64=$(base64 -w0 /path/to/file)"

            # If this fails, you may wish to encode the file separately and
            # use the file reference syntax instead:

            $ base64 -w0 /path/to/file > /path/to/file.b64
            $ curl -X POST localhost:5000/iqr_classifier -d label=some_label \
                --data-urlencode bytes_b64@/path/to/file.b64

        To lock this classifier and guard it against deletion, add
        "lock_label=true"::

            $ curl -X POST localhost:5000/iqr_classifier \
                -d "label=some_label" \
                -d "lock_label=true" \
                --data-urlencode "bytes_b64=$(base64 -w0 /path/to/file)"

        Form arguments:
            bytes_b64
                base64 encoding of the bytes of the IQR session state save
                file.
            label
                Descriptive label to apply to this classifier. This should not
                conflict with existing classifier labels.
            lock_label
                If 'true', disallow deletion of this label. If 'false', allow
                deletion of this label. Only has an effect if deletion is
                enabled for this service. (Default: 'false')

        Returns 201 on success.

        Returns 400 when:
            - ``bytes_b64`` or ``label`` is missing or empty,
            - ``lock_label`` is not parsable as JSON,
            - ``bytes_b64`` is not valid base64,
            - ``label`` already exists in the classifier collection, or
            - adding the trained classifier to the collection raises a
              ``ValueError``.

        """
        data_b64 = flask.request.values.get('bytes_b64', default=None)
        label = flask.request.values.get('label', default=None)
        lock_clfr_str = flask.request.values.get('lock_label',
                                                 default='false')

        # Both None and the empty string count as "not provided".
        if not data_b64:
            return make_response_json("No state base64 data provided.", 400)
        if not label:
            return make_response_json("No descriptive label provided.", 400)

        try:
            lock_clfr = bool(flask.json.loads(lock_clfr_str))
        except JSON_DECODE_EXCEPTION:
            return make_response_json("Invalid boolean value for"
                                      " 'lock_label'. Was given: '%s'"
                                      % lock_clfr_str,
                                      400)

        try:
            # Using urlsafe version because it handles both regular and urlsafe
            # alphabets.
            data_bytes = base64.urlsafe_b64decode(data_b64.encode('utf-8'))
        except (TypeError, binascii.Error) as ex:
            # The status code must be passed *to* make_response_json, like
            # every other error return in this method, rather than being
            # tacked on after the call.
            return make_response_json("Invalid base64 input: %s" % str(ex),
                                      400)

        # If the given label conflicts with one already in the collection,
        # fail.
        if label in self.classifier_collection.labels():
            return make_response_json(
                "Label already exists in classifier collection.", 400)

        # Create dummy IqrSession to extract pos/neg descriptors.
        iqrs = IqrSession()
        iqrs.set_state_bytes(data_bytes, self.descriptor_factory)
        pos = iqrs.positive_descriptors | iqrs.external_positive_descriptors
        neg = iqrs.negative_descriptors | iqrs.external_negative_descriptors
        del iqrs

        # Make a classifier instance from the stored config for IQR
        # session-based classifiers.
        #: :type: SupervisedClassifier
        classifier = smqtk.utils.plugin.from_plugin_config(
            self.iqr_state_classifier_config,
            get_classifier_impls(sub_interface=SupervisedClassifier)
        )
        classifier.train(class_examples={'positive': pos, 'negative': neg})

        try:
            self.classifier_collection.add_classifier(label, classifier)

            # If we're allowing deletions, get the lock flag from the form and
            # set it for this classifier
            if self.enable_classifier_removal and lock_clfr:
                self.immutable_labels.add(label)

        except ValueError as e:
            # Inspect the rendered message rather than ``e.args[0]`` so an
            # exception with no (or non-string) args cannot raise from inside
            # this handler.
            if 'JSON' in str(e):
                return make_response_json("Tried to parse malformed JSON in "
                                          "form argument.", 400)
            return make_response_json("Duplicate label ('%s') added during "
                                      "classifier training of provided IQR "
                                      "session state." % label, 400,
                                      label=label)

        return make_response_json("Finished training IQR-session-based "
                                  "classifier for label '%s'." % label,
                                  201,
                                  label=label)