def test_get_config(self):
    """
    We should be able to get the configuration of the current factory.

    The returned configuration should name the element implementation under
    the ``type`` key and map that same name to the parameters the factory was
    constructed with.
    """
    test_params = {'p1': 'some dir', 'vec': 1}
    factory = DescriptorElementFactory(DummyElementImpl, test_params)
    factory_config = factory.get_config()
    assert factory_config == {
        "type": "DummyElementImpl",
        "DummyElementImpl": test_params
    }
    def test_configuration(self):
        """
        Selecting DescriptorMemoryElement in the default factory config should
        yield a factory for that type with an empty implementation config.
        """
        default_cfg = DescriptorElementFactory.get_default_config()
        # No implementation is chosen by default, but the memory element is
        # one of the offered choices.
        self.assertIsNone(default_cfg['type'])
        self.assertIn('DescriptorMemoryElement', default_cfg)

        default_cfg['type'] = 'DescriptorMemoryElement'
        built = DescriptorElementFactory.from_config(default_cfg)
        self.assertEqual(built._d_type.__name__,
                         DescriptorMemoryElement.__name__)
        self.assertEqual(built._d_type_config, {})

        descriptor = built.new_descriptor('test', 'foo')
        self.assertEqual(descriptor.type(), 'test')
        self.assertEqual(descriptor.uuid(), 'foo')
    def test_configuration(self):
        """
        Selecting DescriptorMemoryElement in the default factory config should
        yield a factory for that type with an empty implementation config.
        """
        default_cfg = DescriptorElementFactory.get_default_config()
        # Nothing is selected by default; the memory element is available.
        ntools.assert_is_none(default_cfg['type'])
        ntools.assert_in('DescriptorMemoryElement', default_cfg)

        default_cfg['type'] = 'DescriptorMemoryElement'
        built = DescriptorElementFactory.from_config(default_cfg)
        ntools.assert_equal(built._d_type.__name__,
                            DescriptorMemoryElement.__name__)
        ntools.assert_equal(built._d_type_config, {})

        descriptor = built.new_descriptor('test', 'foo')
        ntools.assert_equal(descriptor.type(), 'test')
        ntools.assert_equal(descriptor.uuid(), 'foo')
    def test_configuration(self):
        """
        Selecting DescriptorMemoryElement in the default factory config should
        yield a factory for that type with an empty implementation config.

        The class-level ``DescriptorMemoryElement.MEMORY_CACHE`` is reset
        afterwards, whether or not the assertions pass.
        """
        try:
            c = DescriptorElementFactory.get_default_config()
            ntools.assert_is_none(c['type'])
            ntools.assert_in('DescriptorMemoryElement', c)

            c['type'] = 'DescriptorMemoryElement'
            factory = DescriptorElementFactory.from_config(c)
            ntools.assert_equal(factory._d_type.__name__,
                                DescriptorMemoryElement.__name__)
            ntools.assert_equal(factory._d_type_config, {})

            d = factory.new_descriptor('test', 'foo')
            ntools.assert_equal(d.type(), 'test')
            ntools.assert_equal(d.uuid(), 'foo')
        finally:
            # Previously this reset was skipped when an assertion failed,
            # leaving cached state behind for subsequent tests.
            DescriptorMemoryElement.MEMORY_CACHE = {}
    def test_configuration(self):
        """
        Selecting the fully-qualified DescriptorMemoryElement key in the
        default factory config should yield a factory for that type with an
        empty implementation config.
        """
        memory_key = 'smqtk.representation.descriptor_element.local_elements.DescriptorMemoryElement'
        default_cfg = DescriptorElementFactory.get_default_config()
        # Nothing is selected by default; the memory element is offered under
        # its fully-qualified name.
        self.assertIsNone(default_cfg['type'])
        self.assertIn(memory_key, default_cfg)

        default_cfg['type'] = memory_key
        built = DescriptorElementFactory.from_config(default_cfg)
        self.assertEqual(built._d_type.__name__,
                         DescriptorMemoryElement.__name__)
        self.assertEqual(built._d_type_config, {})

        descriptor = built.new_descriptor('test', 'foo')
        self.assertEqual(descriptor.type(), 'test')
        self.assertEqual(descriptor.uuid(), 'foo')
示例#6
0
    def from_config(cls, config_dict, type_str, uuid, merge_default=True):
        """
        Construct an instance from a configuration dictionary.

        The nested ``wrapped_element_factory`` configuration is converted into
        a ``DescriptorElementFactory`` instance before delegating to the
        parent class ``from_config``.

        :param config_dict: Configuration dictionary; must contain a
            ``wrapped_element_factory`` entry. It is not modified.
        :type config_dict: dict
        :param type_str: Descriptor type label, forwarded to the parent.
        :param uuid: Descriptor UUID, forwarded to the parent.
        :param merge_default: Forwarded to the parent ``from_config``.
        :type merge_default: bool

        :return: Instance built by the parent class ``from_config``.
        """
        # Shallow-copy so the caller's dictionary keeps its plain nested
        # configuration; previously it was overwritten in place with a
        # factory instance, which broke re-use of the same dict.
        config_dict = dict(config_dict)
        config_dict["wrapped_element_factory"] = DescriptorElementFactory.from_config(
            config_dict["wrapped_element_factory"]
        )

        return super(CachingDescriptorElement, cls).from_config(
            config_dict, type_str, uuid, merge_default
        )
示例#7
0
def main():
    """
    Command-line entry point: compute a descriptor vector for one input file.

    The vector is written with ``numpy.save`` when an output path is given;
    otherwise it is printed to stdout as space-separated ``%15f`` values.
    Exits with status 1 when no input file is given or the given path is not
    a file.
    """
    parser = cli_parser()
    args = parser.parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    output_filepath = args.output_filepath
    overwrite = args.overwrite

    # ``exit`` is injected by the ``site`` module for interactive use and is
    # not guaranteed to exist in programs; raise SystemExit directly.
    if not args.input_file:
        log.error("Failed to provide an input file path")
        raise SystemExit(1)
    elif not os.path.isfile(args.input_file):
        log.error("Given path does not point to a file.")
        raise SystemExit(1)

    input_filepath = args.input_file
    data_element = DataFileElement(input_filepath)

    factory = DescriptorElementFactory.from_config(config['descriptor_factory'])
    #: :type: smqtk.algorithms.descriptor_generator.DescriptorGenerator
    cd = from_config_dict(config['content_descriptor'],
                          DescriptorGenerator.get_impls())

    vec = generate_vector(log, cd, data_element, factory, overwrite)

    if output_filepath:
        numpy.save(output_filepath, vec)
    else:
        # Format each element explicitly instead of relying on numpy's own
        # array printing.
        # noinspection PyTypeChecker
        print(' '.join('%15f' % f for f in vec))
示例#8
0
    def get_default_config(cls):
        """
        Build the default configuration for this class.

        Starts from the parent's default configuration, drops ``parent_app``
        (it is supplied explicitly at construction time), and fills the plugin
        slots with plugin configuration templates. Any parent-provided
        relevancy index config or descriptor factory is layered on top of the
        corresponding template.
        """
        cfg = super(IqrSearch, cls).get_default_config()

        # parent_app is given explicitly to the constructor, not configured.
        del cfg['parent_app']

        cfg['data_set'] = plugin.make_config(get_data_set_impls())
        cfg['descr_generator'] = plugin.make_config(
            get_descriptor_generator_impls()
        )
        cfg['nn_index'] = plugin.make_config(get_nn_index_impls())

        rel_index_cfg = plugin.make_config(get_relevancy_index_impls())
        inherited_ri = cfg['rel_index_config']
        if inherited_ri:
            rel_index_cfg.update(inherited_ri)
        cfg['rel_index_config'] = rel_index_cfg

        factory_cfg = DescriptorElementFactory.get_default_config()
        inherited_factory = cfg['descriptor_factory']
        if inherited_factory:
            factory_cfg.update(inherited_factory.get_config())
        cfg['descriptor_factory'] = factory_cfg

        return cfg
    def test_no_params(self, dei_init):
        # dei_init presumably mocks DummyElementImpl.__init__ (the patch
        # decorator is not shown here); __init__ must return None or Python
        # raises TypeError on instantiation.
        dei_init.return_value = None

        factory = DescriptorElementFactory(DummyElementImpl, {})

        type_str, uuid = 'type', 'uuid'
        # The factory is expected to instantiate a DummyElementImpl internally.
        elem = factory.new_descriptor(type_str, uuid)

        ntools.assert_true(dei_init.called)
        dei_init.assert_called_once_with(type_str, uuid)
        ntools.assert_is_instance(elem, DummyElementImpl)
示例#10
0
    def from_config(cls, config, parent_app):
        """
        Create a new instance from a JSON-compliant configuration dictionary.

        ``config`` is laid over this class's default configuration, nested
        plugin configurations are turned into instances, and the result is
        passed as keyword arguments to the constructor.

        :param config: JSON compliant dictionary encapsulating
            a configuration.
        :type config: dict

        :param parent_app: Parent containing flask app instance
        :type parent_app: smqtk.web.search_app.app.search_app

        :return: Constructed instance from the provided config.
        :rtype: IqrSearch

        """
        cfg = cls.get_default_config()
        cfg.update(config)

        # Replace each nested plugin configuration with a constructed instance.
        plugin_slots = (
            ('data_set', get_data_set_impls),
            ('descr_generator', get_descriptor_generator_impls),
            ('nn_index', get_nn_index_impls),
        )
        for slot, impls_getter in plugin_slots:
            cfg[slot] = plugin.from_plugin_config(cfg[slot], impls_getter())

        cfg['descriptor_factory'] = DescriptorElementFactory.from_config(
            cfg['descriptor_factory']
        )

        return cls(parent_app, **cfg)
示例#11
0
    def get_default_config(cls):
        """
        Return this class's default configuration dictionary.

        The parent default is taken as a base, ``parent_app`` is removed
        (callers provide it explicitly), and plugin slots receive plugin
        configuration templates. Parent-supplied relevancy index settings and
        descriptor factory configuration override the matching templates.
        """
        defaults = super(IqrSearch, cls).get_default_config()

        # parent_app is not configurable; it is passed in at construction.
        del defaults['parent_app']

        defaults['data_set'] = plugin.make_config(get_data_set_impls())
        defaults['descr_generator'] = \
            plugin.make_config(get_descriptor_generator_impls())
        defaults['nn_index'] = plugin.make_config(get_nn_index_impls())

        ri_template = plugin.make_config(get_relevancy_index_impls())
        if defaults['rel_index_config']:
            ri_template.update(defaults['rel_index_config'])
        defaults['rel_index_config'] = ri_template

        df_template = DescriptorElementFactory.get_default_config()
        if defaults['descriptor_factory']:
            df_template.update(defaults['descriptor_factory'].get_config())
        defaults['descriptor_factory'] = df_template

        return defaults
示例#12
0
    def get_default_config(cls):
        """
        Return the default JSON configuration for the classifier service.

        Extends the parent configuration with:
        - classifier removal switch (off),
        - static classifier collection,
        - classification element factory for new results,
        - descriptor generator and descriptor element factory for new content,
        - optional descriptor set for descriptors referenced by UID,
        - supervised classifier template used for uploaded IQR states,
        - an empty list of immutable labels.
        """
        cfg = super(SmqtkClassifierService, cls).get_default_config()
        cfg.update({
            cls.CONFIG_ENABLE_CLASSIFIER_REMOVAL: False,
            cls.CONFIG_CLASSIFIER_COLLECTION:
                ClassifierCollection.get_default_config(),
            cls.CONFIG_CLASSIFICATION_FACTORY:
                ClassificationElementFactory.get_default_config(),
            cls.CONFIG_DESCRIPTOR_GENERATOR:
                make_default_config(DescriptorGenerator.get_impls()),
            cls.CONFIG_DESCRIPTOR_FACTORY:
                DescriptorElementFactory.get_default_config(),
            cls.CONFIG_DESCRIPTOR_SET:
                make_default_config(DescriptorSet.get_impls()),
            cls.CONFIG_IQR_CLASSIFIER:
                make_default_config(SupervisedClassifier.get_impls()),
            cls.CONFIG_IMMUTABLE_LABELS: [],
        })
        return cfg
示例#13
0
    def get_default_config(cls):
        """
        Return the default JSON configuration for the classifier service.

        Extends the parent configuration with the classifier removal switch
        (off), the static classifier collection, the classification element
        factory, the descriptor generator and descriptor element factory for
        new content, the supervised classifier template used for uploaded IQR
        states, and an empty immutable label list.
        """
        make_config = smqtk.utils.plugin.make_config
        cfg = super(SmqtkClassifierService, cls).get_default_config()
        cfg.update({
            cls.CONFIG_ENABLE_CLASSIFIER_REMOVAL: False,
            cls.CONFIG_CLASSIFIER_COLLECTION:
                ClassifierCollection.get_default_config(),
            cls.CONFIG_CLASSIFICATION_FACTORY:
                ClassificationElementFactory.get_default_config(),
            cls.CONFIG_DESCRIPTOR_GENERATOR:
                make_config(get_descriptor_generator_impls()),
            cls.CONFIG_DESCRIPTOR_FACTORY:
                DescriptorElementFactory.get_default_config(),
            cls.CONFIG_IQR_CLASSIFIER:
                make_config(
                    get_classifier_impls(sub_interface=SupervisedClassifier)
                ),
            cls.CONFIG_IMMUTABLE_LABELS: [],
        })
        return cfg
示例#14
0
    def from_config(cls, config, parent_app):
        """
        Create a new instance from a JSON-compliant configuration dictionary.

        The given ``config`` overrides this class's defaults. Data set,
        descriptor generator and nearest-neighbor index plugin configs, plus
        the descriptor factory config, are converted to instances before the
        constructor is called.

        :param config: JSON compliant dictionary encapsulating
            a configuration.
        :type config: dict

        :param parent_app: Parent containing flask app instance
        :type parent_app: smqtk.web.search_app.app.search_app

        :return: Constructed instance from the provided config.
        :rtype: IqrSearch

        """
        kwargs = cls.get_default_config()
        kwargs.update(config)

        # Instantiate nested objects from their configurations.
        data_set = plugin.from_plugin_config(kwargs['data_set'],
                                             get_data_set_impls())
        kwargs['data_set'] = data_set

        generator = plugin.from_plugin_config(kwargs['descr_generator'],
                                              get_descriptor_generator_impls())
        kwargs['descr_generator'] = generator

        index = plugin.from_plugin_config(kwargs['nn_index'],
                                          get_nn_index_impls())
        kwargs['nn_index'] = index

        descr_factory = DescriptorElementFactory.from_config(
            kwargs['descriptor_factory']
        )
        kwargs['descriptor_factory'] = descr_factory

        return cls(parent_app, **kwargs)
示例#15
0
    def get_default_config(cls):
        """
        Return the default configuration dictionary for this class.

        Useful for showing what a configuration looks like without creating an
        instance. The parent default is extended with descriptor factory,
        descriptor generator, nearest-neighbor index and descriptor index
        templates, and ``update_descriptor_index`` defaults to False.

        :return: Default configuration dictionary for the class.
        :rtype: dict

        """
        cfg = super(NearestNeighborServiceServer, cls).get_default_config()
        additions = {
            "descriptor_factory": DescriptorElementFactory.get_default_config(),
            "descriptor_generator": plugin.make_config(
                get_descriptor_generator_impls()
            ),
            "nn_index": plugin.make_config(get_nn_index_impls()),
            "descriptor_index": plugin.make_config(
                get_descriptor_index_impls()
            ),
            "update_descriptor_index": False,
        }
        merge_dict(cfg, additions)
        return cfg
示例#16
0
def run_file_list(c,
                  filelist_filepath,
                  checkpoint_filepath,
                  batch_size=None,
                  check_image=False):
    """
    Top level function handling configuration and inputs/outputs.

    :param c: Configuration dictionary (JSON)
    :type c: dict

    :param filelist_filepath: Path to a text file that lists paths to data
        files, separated by new lines.
    :type filelist_filepath: str

    :param checkpoint_filepath: Output file to which we write input filepath to
        SHA1 (UUID) relationships.
    :type checkpoint_filepath: str

    :param batch_size: Optional batch size (None default) of data elements to
        process / descriptors to compute at a time. This causes files and
        stores to be written to incrementally during processing instead of
        one single batch transaction at a time.
    :type batch_size: int | None

    :param check_image: Enable checking image loading from file before queueing
        that file for processing. If the check fails, the file is skipped
        instead of a halting exception being raised.
    :type check_image: bool

    """
    log = logging.getLogger(__name__)

    # Use a context manager so the list file handle is closed promptly.
    with open(filelist_filepath) as filelist:
        file_paths = [line.strip() for line in filelist]

    log.info("Making descriptor factory")
    factory = DescriptorElementFactory.from_config(c['descriptor_factory'])

    log.info("Making descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(c['descriptor_index'],
                                                 get_descriptor_index_impls())

    log.info("Making descriptor generator '%s'",
             c['descriptor_generator']['type'])
    #: :type: smqtk.algorithms.DescriptorGenerator
    generator = plugin.from_plugin_config(c['descriptor_generator'],
                                          get_descriptor_generator_impls())

    def test_image_load(dfe):
        """Return True if PIL can open dfe's bytes as an image, else False."""
        try:
            PIL.Image.open(io.BytesIO(dfe.get_bytes()))
            return True
        # ``except X as e`` works on Python 2.6+ and 3; the comma form is
        # Python-2-only syntax.
        except IOError as ex:
            # noinspection PyProtectedMember
            log.warning(
                "Failed to convert '%s' bytes into an image "
                "(error: %s). Skipping", dfe._filepath, str(ex))
            return False
示例#17
0
def default_config():
    """
    Return the default configuration for this script: descriptor generator
    and descriptor index plugin templates plus the descriptor element factory
    defaults.
    """
    generator_cfg = plugin.make_config(get_descriptor_generator_impls())
    factory_cfg = DescriptorElementFactory.get_default_config()
    index_cfg = plugin.make_config(get_descriptor_index_impls())
    return dict(
        descriptor_generator=generator_cfg,
        descriptor_factory=factory_cfg,
        descriptor_index=index_cfg,
    )
    def test_no_params(self):
        """
        A factory with no extra parameters should build an element with only
        the type label and UUID set, and no extra positional or keyword
        arguments.
        """
        factory = DescriptorElementFactory(DummyElementImpl, {})

        type_str, uuid = 'type', 'uuid'
        # The factory should construct a DummyElementImpl behind the scenes.
        elem = factory.new_descriptor(type_str, uuid)

        ntools.assert_is_instance(elem, DummyElementImpl)
        ntools.assert_equal(elem._type_label, type_str)
        ntools.assert_equal(elem._uuid, uuid)
        ntools.assert_equal(elem.args, ())
        ntools.assert_equal(elem.kwds, {})
    def test_no_params(self):
        """
        Building an element from a parameter-less factory should forward only
        the type label and UUID to the implementation.
        """
        no_extra_args, no_extra_kwds = (), {}
        factory = DescriptorElementFactory(DummyElementImpl, {})

        label, uid = 'type', 'uuid'
        # A DummyElementImpl should be created internally by the factory.
        created = factory.new_descriptor(label, uid)

        ntools.assert_is_instance(created, DummyElementImpl)
        ntools.assert_equal(created._type_label, label)
        ntools.assert_equal(created._uuid, uid)
        ntools.assert_equal(created.args, no_extra_args)
        ntools.assert_equal(created.kwds, no_extra_kwds)
示例#20
0
def get_default_config():
    """
    Return the default configuration: descriptor factory, descriptor
    generator, classification factory and classifier sections.
    """
    cfg = dict()
    cfg["descriptor_factory"] = DescriptorElementFactory.get_default_config()
    # NOTE(review): the impl getters are passed uncalled here, unlike other
    # call sites that pass ``get_..._impls()`` -- confirm against the
    # ``plugin.make_config`` signature in use.
    cfg["descriptor_generator"] = \
        plugin.make_config(get_descriptor_generator_impls)
    cfg["classification_factory"] = \
        ClassificationElementFactory.get_default_config()
    cfg["classifier"] = plugin.make_config(get_classifier_impls)
    return cfg
    def test_with_params(self):
        """
        Parameters given to the factory should reach the constructed element
        as keyword arguments, with no extra positional arguments.
        """
        vec = numpy.random.randint(0, 10, 10)
        params = {'p1': 'some dir', 'vec': vec}

        factory = DescriptorElementFactory(DummyElementImpl, params)

        label, uid = 'type', 'uuid'
        # The factory should build a DummyElementImpl behind the scenes.
        elem = factory.new_descriptor(label, uid)

        ntools.assert_is_instance(elem, DummyElementImpl)
        ntools.assert_equal(elem._type_label, label)
        ntools.assert_equal(elem._uuid, uid)
        ntools.assert_equal(elem.args, ())
        ntools.assert_equal(elem.kwds, params)
示例#22
0
    def from_config(cls, config_dict, type_str, uuid, merge_default=True):
        """
        Construct an instance from a configuration dictionary.

        The nested ``wrapped_element_factory`` configuration is converted into
        a ``DescriptorElementFactory`` instance before delegating to the
        parent class ``from_config``.

        :param config_dict: Configuration dictionary; must contain a
            ``wrapped_element_factory`` entry. It is not modified.
        :type config_dict: dict
        :param type_str: Descriptor type label, forwarded to the parent.
        :param uuid: Descriptor UUID, forwarded to the parent.
        :param merge_default: Forwarded to the parent ``from_config``.
        :type merge_default: bool

        :return: Instance built by the parent class ``from_config``.
        """
        # Copy first: previously the caller's dict had its nested factory
        # configuration overwritten in place by a factory instance.
        config_dict = dict(config_dict)
        config_dict['wrapped_element_factory'] = \
            DescriptorElementFactory.from_config(
                config_dict['wrapped_element_factory']
            )

        return super(CachingDescriptorElement,
                     cls).from_config(config_dict, type_str, uuid,
                                      merge_default)
示例#23
0
def run_file_list(c, filelist_filepath, checkpoint_filepath, batch_size=256):
    """
    Compute descriptors for every supported file listed in a text file.

    :param c: Configuration dictionary; the ``descriptor_factory`` and
        ``descriptor_generator`` entries are used.
    :type c: dict
    :param filelist_filepath: Path to a text file listing one data file path
        per line.
    :type filelist_filepath: str
    :param checkpoint_filepath: File to which ``<filepath>,<uuid>`` lines are
        appended as descriptors are computed.
    :type checkpoint_filepath: str
    :param batch_size: Batch size forwarded to ``compute_many_descriptors``
        (defaults to 256, the value previously hard-coded).
    :type batch_size: int

    Files whose content type the generator accepts are recorded in
    ``valid_file_map.pickle``; the rest go to ``invalid_file_map.pickle``
    (both as path -> content type maps, relative to the working directory).
    """
    log = logging.getLogger(__name__)

    # Close the list file as soon as it has been read.
    with open(filelist_filepath) as filelist:
        file_paths = [line.strip() for line in filelist]

    log.info("Making memory factory")
    factory = DescriptorElementFactory.from_config(c['descriptor_factory'])

    log.info("Making descriptor generator '%s'",
             c['descriptor_generator']['type'])
    #: :type: smqtk.algorithms.DescriptorGenerator
    generator = from_plugin_config(c['descriptor_generator'],
                                   get_descriptor_generator_impls)
    log.info("Making descriptor generator -- Done")

    valid_file_paths = dict()
    invalid_file_paths = dict()

    def iter_valid_files():
        """Yield paths the generator supports, recording each file's type."""
        for fp in file_paths:
            dfe = DataFileElement(fp)
            ct = dfe.content_type()
            if ct in generator.valid_content_types():
                valid_file_paths[fp] = ct
                yield fp
            else:
                invalid_file_paths[fp] = ct

    log.info("Computing descriptors")
    m = compute_many_descriptors(iter_valid_files(),
                                 generator,
                                 factory,
                                 batch_size=batch_size,
                                 )

    # Recording computed file paths and associated file UUIDs (SHA1)
    with open(checkpoint_filepath, 'a') as cf:
        for fp, descr in m:
            cf.write("{:s},{:s}\n".format(
                fp, descr.uuid()
            ))
            # Flush per record so progress survives an interrupted run.
            cf.flush()

    # The valid/invalid maps are filled lazily while ``m`` is consumed above,
    # so they are complete only at this point.
    log.info("Writing valid filepaths map")
    with open('valid_file_map.pickle', 'wb') as f:
        cPickle.dump(valid_file_paths, f)
    log.info("Writing invalid filepaths map")
    with open('invalid_file_map.pickle', 'wb') as f:
        cPickle.dump(invalid_file_paths, f)

    log.info("Done")
示例#24
0
def default_config():
    """
    Return the default configuration: descriptor generator, descriptor
    element factory, descriptor set and optional data set sections.
    """
    cfg = dict()
    cfg["descriptor_generator"] = make_default_config(
        DescriptorGenerator.get_impls())
    cfg["descriptor_factory"] = DescriptorElementFactory.get_default_config()
    cfg["descriptor_set"] = make_default_config(DescriptorSet.get_impls())
    cfg["optional_data_set"] = make_default_config(DataSet.get_impls())
    return cfg
示例#25
0
def get_default_config():
    """
    Return the default configuration: descriptor factory, descriptor
    generator, classification factory and classifier sections.
    """
    descr_factory_cfg = DescriptorElementFactory.get_default_config()
    generator_cfg = plugin.make_config(get_descriptor_generator_impls())
    class_factory_cfg = ClassificationElementFactory.get_default_config()
    classifier_cfg = plugin.make_config(get_classifier_impls())
    return {
        "descriptor_factory": descr_factory_cfg,
        "descriptor_generator": generator_cfg,
        "classification_factory": class_factory_cfg,
        "classifier": classifier_cfg,
    }
    def test_with_params(self, dei_init):
        # dei_init presumably mocks DummyElementImpl.__init__ (the patch
        # decorator is not shown here); __init__ must return None or Python
        # raises TypeError on instantiation.
        dei_init.return_value = None

        vec = numpy.random.randint(0, 10, 10)
        params = dict(p1='some dir', vec=vec)

        factory = DescriptorElementFactory(DummyElementImpl, params)

        label, uid = 'type', 'uuid'
        # The stored parameters should be forwarded as keyword arguments.
        elem = factory.new_descriptor(label, uid)

        ntools.assert_true(dei_init.called)
        dei_init.assert_called_once_with(label, uid, p1='some dir', vec=vec)
        ntools.assert_is_instance(elem, DummyElementImpl)
示例#27
0
def get_default_config():
    """
    Return the default configuration: descriptor factory, descriptor
    generator, classification factory and classifier sections.
    """
    cfg = dict()
    cfg["descriptor_factory"] = DescriptorElementFactory.get_default_config()
    cfg["descriptor_generator"] = make_default_config(
        DescriptorGenerator.get_impls())
    cfg["classification_factory"] = \
        ClassificationElementFactory.get_default_config()
    cfg["classifier"] = make_default_config(Classifier.get_impls())
    return cfg
    def test_with_params(self):
        """
        Factory parameters should arrive at the built element as keyword
        arguments, with no positional extras.
        """
        vec = numpy.random.randint(0, 10, 10)
        params = dict(p1='some dir', vec=vec)

        factory = DescriptorElementFactory(DummyElementImpl, params)

        label, uid = 'type', 'uuid'
        # A DummyElementImpl should be created internally by the factory.
        created = factory.new_descriptor(label, uid)

        ntools.assert_is_instance(created, DummyElementImpl)
        ntools.assert_equal(created._type_label, label)
        ntools.assert_equal(created._uuid, uid)
        ntools.assert_equal(created.args, ())
        ntools.assert_equal(created.kwds, params)
示例#29
0
    def __init__(self, json_config):
        """
        Build the classifier service from its JSON configuration.

        Configuration sections are turned into SMQTK instances: the
        classification element factory, the static classifier collection, the
        descriptor element factory and generator, and an optional descriptor
        set (an empty MemoryDescriptorSet when that section cannot be
        constructed). The IQR classifier section is kept as raw configuration
        so a classifier can be built per uploaded IQR state. Routes are
        registered last.

        :param json_config: Service configuration dictionary.
        :type json_config: dict
        """
        super(SmqtkClassifierService, self).__init__(json_config)

        cfg = json_config
        self.enable_classifier_removal = bool(
            cfg[self.CONFIG_ENABLE_CLASSIFIER_REMOVAL]
        )
        self.immutable_labels = set(cfg[self.CONFIG_IMMUTABLE_LABELS])

        # Factory for classification results, then the static classifiers.
        self.classification_factory = ClassificationElementFactory.from_config(
            cfg[self.CONFIG_CLASSIFICATION_FACTORY]
        )
        #: :type: ClassifierCollection
        self.classifier_collection = ClassifierCollection.from_config(
            cfg[self.CONFIG_CLASSIFIER_COLLECTION]
        )

        # Descriptor factory and generator for newly submitted content.
        self.descriptor_factory = DescriptorElementFactory.from_config(
            cfg[self.CONFIG_DESCRIPTOR_FACTORY]
        )
        #: :type: smqtk.algorithms.DescriptorGenerator
        self.descriptor_gen = from_config_dict(
            cfg[self.CONFIG_DESCRIPTOR_GENERATOR],
            smqtk.algorithms.DescriptorGenerator.get_impls()
        )

        # Descriptor set used for classification-by-UID. A ValueError from
        # construction (e.g. the section is absent and ``{}`` is used) falls
        # back to an empty in-memory set.
        try:
            self.descriptor_set = from_config_dict(
                cfg.get(self.CONFIG_DESCRIPTOR_SET, {}),
                DescriptorSet.get_impls()
            )
        except ValueError:
            self.descriptor_set = MemoryDescriptorSet()

        # Raw classifier configuration applied to uploaded IQR states.
        self.iqr_state_classifier_config = cfg[self.CONFIG_IQR_CLASSIFIER]

        self.add_routes()
示例#30
0
    def from_config(cls, config_dict, type_str, uuid):
        """
        Construct an instance from ``config_dict`` layered over the default
        configuration, with ``wrapped_element_factory`` converted into a
        DescriptorElementFactory instance before delegating to the parent.
        """
        cfg = cls.get_default_config()
        cfg.update(config_dict)

        wrapped_factory = DescriptorElementFactory.from_config(
            cfg['wrapped_element_factory']
        )
        cfg['wrapped_element_factory'] = wrapped_factory

        parent = super(CachingDescriptorElement, cls)
        return parent.from_config(cfg, type_str, uuid)
示例#31
0
    def from_config(cls, config_dict, type_str, uuid):
        """
        Build an instance from the default configuration overridden by
        ``config_dict``. The nested ``wrapped_element_factory`` configuration
        is replaced by a DescriptorElementFactory instance before the parent
        class's ``from_config`` is invoked.
        """
        effective = cls.get_default_config()
        effective.update(config_dict)

        factory_cfg = effective['wrapped_element_factory']
        effective['wrapped_element_factory'] = \
            DescriptorElementFactory.from_config(factory_cfg)

        return super(CachingDescriptorElement, cls).from_config(
            effective, type_str, uuid
        )
示例#32
0
def run_file_list(c, filelist_filepath, checkpoint_filepath, batch_size=None,
                  check_image=False):
    """
    Top level function handling configuration and inputs/outputs.

    :param c: Configuration dictionary (JSON)
    :type c: dict

    :param filelist_filepath: Path to a text file that lists paths to data
        files, separated by new lines.
    :type filelist_filepath: str

    :param checkpoint_filepath: Output file to which we write input filepath to
        SHA1 (UUID) relationships.
    :type checkpoint_filepath: str

    :param batch_size: Optional batch size (None default) of data elements to
        process / descriptors to compute at a time. This causes files and
        stores to be written to incrementally during processing instead of
        one single batch transaction at a time.
    :type batch_size: int | None

    :param check_image: Enable checking image loading from file before queueing
        that file for processing.
    :type check_image: bool

    """
    log = logging.getLogger(__name__)

    # Use a context manager so the list file handle is closed promptly.
    with open(filelist_filepath) as filelist:
        file_paths = [line.strip() for line in filelist]

    log.info("Making descriptor factory")
    factory = DescriptorElementFactory.from_config(c['descriptor_factory'])

    log.info("Making descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(c['descriptor_index'],
                                                 get_descriptor_index_impls())

    log.info("Making descriptor generator '%s'",
             c['descriptor_generator']['type'])
    #: :type: smqtk.algorithms.DescriptorGenerator
    generator = plugin.from_plugin_config(c['descriptor_generator'],
                                          get_descriptor_generator_impls())

    def test_image_load(dfe):
        """Return True if PIL can open dfe's bytes as an image, else False."""
        try:
            PIL.Image.open(io.BytesIO(dfe.get_bytes()))
            return True
        # ``except X as e`` is valid on Python 2.6+ and 3, unlike the comma
        # form.
        except IOError as ex:
            # noinspection PyProtectedMember
            log.warning("Failed to convert '%s' bytes into an image "
                        "(error: %s). Skipping",
                        dfe._filepath, str(ex))
            return False
示例#33
0
def compute_descriptors(task, folderId, dataElementUris, **kwargs):
    """
    Celery task that computes descriptors for data element URIs belonging to
    one Girder folder.

    Once descriptors are computed, each corresponding Girder item receives
    the descriptor's UUID as ``smqtk_uuid`` metadata.

    :param task: Celery provided task object.
    :param folderId: Folder the images belong to; used to select the
        descriptor index table.
    :param dataElementUris: Sequence of ``(item_id, data_element_uri)`` pairs;
        the URIs are assumed to be GirderDataElement URIs.
    """
    task.job_manager.updateProgress(message='Computing descriptors',
                                    forceFlush=True)

    def uri_for_setting(setting_name):
        # Resolve a Girder setting into a URI usable by this task.
        return girderUriFromTask(
            task, getSetting(task.girder_client, setting_name))

    prototxt_uri = uri_for_setting('caffe_network_prototxt')
    model_uri = uri_for_setting('caffe_network_model')
    mean_uri = uri_for_setting('caffe_image_mean')
    generator = CaffeDescriptorGenerator(prototxt_uri, model_uri, mean_uri)

    db_params = {
        key: getSetting(task.girder_client, key)
        for key in ('db_name', 'db_host', 'db_user', 'db_pass')
    }
    factory = DescriptorElementFactory(PostgresDescriptorElement, db_params)

    index = descriptorIndexFromFolderId(task.girder_client, folderId)

    uris = [entry[1] for entry in dataElementUris]
    valid_elements = iter_valid_elements(uris,
                                         generator.valid_content_types())

    descriptors = compute_functions.compute_many_descriptors(valid_elements,
                                                             generator,
                                                             factory,
                                                             index,
                                                             use_mp=False)

    # Map the last URI path component (file id) back to its Girder item id.
    item_id_by_file = {
        uri.split('/')[-1]: item_id for item_id, uri in dataElementUris
    }

    for element, descriptor in descriptors:
        # TODO Catch errors that could occur here
        with task.girder_client.session():
            task.girder_client.addMetadataToItem(
                item_id_by_file[element.file_id],
                {'smqtk_uuid': descriptor.uuid()})
示例#34
0
def run_file_list(c, filelist_filepath, checkpoint_filepath, batch_size):
    """
    Compute descriptors for every supported file listed in a text file.

    :param c: Configuration dictionary; the ``descriptor_factory`` and
        ``descriptor_generator`` entries are used.
    :type c: dict
    :param filelist_filepath: Path to a text file listing one data file path
        per line.
    :type filelist_filepath: str
    :param checkpoint_filepath: File to which one comma-separated line per
        computed descriptor is appended.
    :type checkpoint_filepath: str
    :param batch_size: Batch size forwarded to ``compute_many_descriptors``.

    Path -> content type maps of accepted and rejected files are pickled to
    ``file_map.valid.pickle`` and ``file_map.invalid.pickle`` (relative to the
    working directory).
    """
    log = logging.getLogger(__name__)

    # Close the list file as soon as it has been read.
    with open(filelist_filepath) as filelist:
        file_paths = [line.strip() for line in filelist]

    log.info("Making descriptor factory")
    factory = DescriptorElementFactory.from_config(c["descriptor_factory"])

    log.info("Making descriptor generator '%s'", c["descriptor_generator"]["type"])
    #: :type: smqtk.algorithms.DescriptorGenerator
    generator = from_plugin_config(c["descriptor_generator"], get_descriptor_generator_impls)
    log.info("Making descriptor generator -- Done")

    valid_file_paths = dict()
    invalid_file_paths = dict()

    def iter_valid_elements():
        """Yield elements the generator supports, recording each file's type."""
        for fp in file_paths:
            dfe = DataFileElement(fp)
            ct = dfe.content_type()
            if ct in generator.valid_content_types():
                valid_file_paths[fp] = ct
                yield dfe
            else:
                invalid_file_paths[fp] = ct

    log.info("Computing descriptors")
    m = compute_many_descriptors(iter_valid_elements(), generator, factory, batch_size=batch_size)

    # Recording computed file paths and associated file UUIDs (SHA1)
    # NOTE(review): iter_valid_elements yields DataFileElement objects, so the
    # first item of each pair is presumably the element rather than a path;
    # '{:s}' then relies on its string conversion (and fails on Python 3).
    # Confirm what the checkpoint should contain.
    with open(checkpoint_filepath, "a") as cf:
        for fp, descr in m:
            cf.write("{:s},{:s}\n".format(fp, descr.uuid()))
            # Flush per record so progress survives an interrupted run.
            cf.flush()

    # The maps are filled lazily while ``m`` is consumed above.
    log.info("Writing valid filepaths map")
    with open("file_map.valid.pickle", "wb") as f:
        cPickle.dump(valid_file_paths, f)
    log.info("Writing invalid filepaths map")
    with open("file_map.invalid.pickle", "wb") as f:
        cPickle.dump(invalid_file_paths, f)

    log.info("Done")
示例#35
0
def train_classifier_iqr(config, iqr_state_fp):
    """
    Train a supervised classifier from the adjudications in a saved IQR state.

    The IQR state bytes are loaded into a fresh ``IqrSession`` backed by
    in-memory descriptor elements. The classifier is trained with two
    classes:

    - ``'positive'``: the union of the session's positive and external
      positive descriptors.
    - ``'negative'``: the union of its negative and external negative
      descriptors.

    :param config: Configuration dictionary with a ``classifier`` section
        describing the ``SupervisedClassifier`` implementation to build.
    :type config: dict
    :param iqr_state_fp: Path to a file containing serialized IQR state bytes.
    :type iqr_state_fp: str

    :return: The trained classifier instance, so callers can use or persist
        it directly. Whether the implementation also saves its model on its
        own depends on its configuration.
    :rtype: smqtk.algorithms.SupervisedClassifier

    """
    #: :type: smqtk.algorithms.SupervisedClassifier
    classifier = from_config_dict(config['classifier'],
                                  SupervisedClassifier.get_impls())

    # Load state into an empty IqrSession instance.
    # NOTE(review): strip() trims leading/trailing ASCII whitespace bytes from
    # binary state data; presumably harmless for the state format -- confirm.
    with open(iqr_state_fp, 'rb') as f:
        state_bytes = f.read().strip()
    descr_factory = DescriptorElementFactory(DescriptorMemoryElement, {})
    iqrs = IqrSession()
    iqrs.set_state_bytes(state_bytes, descr_factory)

    # Positive descriptor examples for training are composed of those from
    # external and internal sets. Same for negative descriptor examples.
    pos = iqrs.positive_descriptors | iqrs.external_positive_descriptors
    neg = iqrs.negative_descriptors | iqrs.external_negative_descriptors
    classifier.train(class_examples={'positive': pos, 'negative': neg})
    return classifier
示例#36
0
文件: __init__.py 项目: Kitware/SMQTK
    def get_default_config(cls):
        """
        Build the default configuration dictionary for this server class.

        Starts from the parent class's default configuration. It then merges
        in a default descriptor element factory configuration and a
        ``descriptor_generators`` section with a single ``"example"`` entry.
        This lets users see the expected configuration layout without
        instantiating the class.

        :return: Default configuration dictionary for the class.
        :rtype: dict

        """
        config = super(DescriptorServiceServer, cls).get_default_config()
        factory_default = DescriptorElementFactory.get_default_config()
        generator_defaults = {
            "example": plugin.make_config(get_descriptor_generator_impls()),
        }
        merge_dict(config, {
            "descriptor_factory": factory_default,
            "descriptor_generators": generator_defaults,
        })
        return config
示例#37
0
    def get_default_config(cls):
        """
        Produce this class's default configuration dictionary.

        The parent's defaults are extended (via ``merge_configs``) with the
        descriptor element factory defaults and an ``"example"`` descriptor
        generator configuration under ``descriptor_generators``. This version
        passes the implementation getter itself to ``plugin.make_config``.

        :return: Default configuration dictionary for the class.
        :rtype: dict

        """
        base = super(DescriptorServiceServer, cls).get_default_config()
        df_default = DescriptorElementFactory.get_default_config()
        example_cfg = plugin.make_config(get_descriptor_generator_impls)
        merge_configs(base, {
            "descriptor_factory": df_default,
            "descriptor_generators": {"example": example_cfg},
        })
        return base
    def test_call(self):
        """
        Calling the factory directly must build an element like
        ``new_descriptor`` would.

        Mirrors ``test_with_params`` but goes through the ``__call__`` entry
        point.
        """
        vec = numpy.random.randint(0, 10, 10)
        params = {'p1': 'some dir', 'vec': vec}
        factory = DescriptorElementFactory(DummyElementImpl, params)

        expected_type, expected_uuid = 'type', 'uuid'
        # The factory should construct a DummyElementImpl under the hood.
        elem = factory(expected_type, expected_uuid)

        self.assertIsInstance(elem, DummyElementImpl)
        self.assertEqual(elem._type_label, expected_type)
        self.assertEqual(elem._uuid, expected_uuid)
        # No positional args; the configured params arrive as keywords.
        self.assertEqual(elem.args, ())
        self.assertEqual(elem.kwds, params)
示例#39
0
    def __init__(self, json_config):
        """
        Initialize the classifier service from a JSON configuration.

        The following are built from their configuration sections:

        - the classification element factory;
        - the classifier collection;
        - the descriptor element factory;
        - the descriptor generator.

        The from-IQR-state classifier configuration is stored as-is. It is
        meant to contain at least the default key, used when no specific
        classifier type is given at state POST. Routes are registered last.

        :param json_config: JSON configuration dictionary.
        :type json_config: dict
        """
        super(SmqtkClassifierService, self).__init__(json_config)
        cfg = json_config

        removal_enabled = cfg[self.CONFIG_ENABLE_CLASSIFIER_REMOVAL]
        self.enable_classifier_removal = bool(removal_enabled)
        self.immutable_labels = set(cfg[self.CONFIG_IMMUTABLE_LABELS])

        # Classification element factory + classifier collection
        classification_cfg = cfg[self.CONFIG_CLASSIFICATION_FACTORY]
        self.classification_factory = \
            ClassificationElementFactory.from_config(classification_cfg)
        collection_cfg = cfg[self.CONFIG_CLASSIFIER_COLLECTION]
        self.classifier_collection = \
            ClassifierCollection.from_config(collection_cfg)

        # Descriptor element factory + descriptor generator
        descr_factory_cfg = cfg[self.CONFIG_DESCRIPTOR_FACTORY]
        self.descriptor_factory = \
            DescriptorElementFactory.from_config(descr_factory_cfg)
        generator_cfg = cfg[self.CONFIG_DESCRIPTOR_GENERATOR]
        #: :type: smqtk.algorithms.DescriptorGenerator
        self.descriptor_gen = smqtk.utils.plugin.from_plugin_config(
            generator_cfg,
            smqtk.algorithms.get_descriptor_generator_impls()
        )

        # Classifier config applied to uploaded IQR states.
        self.iqr_state_classifier_config = cfg[self.CONFIG_IQR_CLASSIFIER]

        self.add_routes()
示例#40
0
文件: __init__.py 项目: Kitware/SMQTK
    def __init__(self, json_config):
        """
        Initialize application based on supplied JSON configuration.

        Builds the following from the configuration:

        - the descriptor element factory;
        - a descriptor index, only when ``update_descriptor_index`` is true;
        - the nearest-neighbors index;
        - the descriptor generator.

        It then registers the ``/count``, ``/compute/<uri>`` and ``/nn/...``
        routes.

        :param json_config: JSON configuration dictionary
        :type json_config: dict

        """
        super(NearestNeighborServiceServer, self).__init__(json_config)

        self.update_index = json_config['update_descriptor_index']

        # Descriptor factory setup
        self._log.info("Initializing DescriptorElementFactory")
        self.descr_elem_factory = DescriptorElementFactory.from_config(
            self.json_config['descriptor_factory']
        )

        #: :type: smqtk.representation.DescriptorIndex | None
        self.descr_index = None
        if self.update_index:
            self._log.info("Initializing DescriptorIndex to update")
            #: :type: smqtk.representation.DescriptorIndex | None
            self.descr_index = plugin.from_plugin_config(
                json_config['descriptor_index'],
                get_descriptor_index_impls()
            )

        #: :type: smqtk.algorithms.NearestNeighborsIndex
        self.nn_index = plugin.from_plugin_config(
            json_config['nn_index'],
            get_nn_index_impls()
        )

        #: :type: smqtk.algorithms.DescriptorGenerator
        self.descriptor_generator_inst = plugin.from_plugin_config(
            self.json_config['descriptor_generator'],
            get_descriptor_generator_impls()
        )

        @self.route("/count", methods=['GET'])
        def count():
            """
            Return the number of elements represented in this index.
            """
            return flask.jsonify(count=self.nn_index.count())

        @self.route("/compute/<path:uri>", methods=["POST"])
        def compute(uri):
            """
            Compute the descriptor for a URI specified data element using the
            configured descriptor generator.

            See ``compute_nearest_neighbors`` method docstring for URI
            specifications accepted.

            This route does not itself add the descriptor to the descriptor
            index. NOTE(review): presumably ``generate_descriptor_for_uri``
            does so when ``update_descriptor_index`` is enabled -- confirm.

            JSON Return format::
                {
                    "success": <bool>

                    "message": <str>

                    "descriptor": <None|list[float]>

                    "reference_uri": <str>
                }

            :param uri: URI data specification.

            """
            # Only ever holds a JSON-serializable list (or None), never the
            # descriptor element itself, so jsonify below cannot choke on it.
            descriptor = None
            try:
                descr_elem = self.generate_descriptor_for_uri(uri)
                vector = descr_elem.vector()
                if vector is None:
                    message = "Descriptor generator produced no vector"
                else:
                    message = "Descriptor generated"
                    descriptor = list(map(float, vector))
            except ValueError as ex:
                message = "Input value issue: %s" % str(ex)
            except RuntimeError as ex:
                message = "Descriptor extraction failure: %s" % str(ex)

            return flask.jsonify(
                success=descriptor is not None,
                message=message,
                descriptor=descriptor,
                reference_uri=uri,
            )

        @self.route("/nn/<path:uri>")
        @self.route("/nn/n=<int:n>/<path:uri>")
        @self.route("/nn/n=<int:n>/<int:start_i>:<int:end_i>/<path:uri>")
        def compute_nearest_neighbors(uri, n=10, start_i=None, end_i=None):
            """
            Data modes for upload/use:

                - local filepath
                - base64
                - http/s URL
                - existing data/descriptor UUID

            The following sub-sections detail how different URI's can be used.

            Local Filepath
            --------------
            The URI string must be prefixed with ``file://``, followed by the
            full path to the data file to describe.

            Base 64 data
            ------------
            The URI string must be prefixed with "base64://", followed by the
            base64 encoded string. This mode also requires an additional
            ``?content_type=`` to provide data content type information. This
            mode saves the encoded data to temporary file for processing.

            HTTP/S address
            --------------
            This is the default mode when the URI prefix is none of the above.
            This uses the requests module to locally download a data file
            for processing.

            Existing Data/Descriptor by UUID
            --------------------------------
            When given a uri prefixed with "uuid://", we interpret the remainder
            of the uri as the UUID of a descriptor already present in the
            configured descriptor index. If the given UUID is not present in the
            index, a KeyError is raised (it is not caught by this route).

            JSON Return format
            ------------------
                {
                    "success": <bool>

                    "message": <str>

                    "neighbors": <list of neighbor descriptor UUIDs>

                    "distances": <list of distances to those neighbors>

                    "reference_uri": <str>
                }

            ``success`` is true only when both descriptor generation and the
            neighbor query succeeded.

            :param uri: URI data specification.
            :param n: Number of neighbors to query for
            :param start_i: The starting index of the neighbor vectors to slice
                into for return.
            :param end_i: The ending index of the neighbor vectors to slice
                into for return.
            :type uri: str

            """
            descriptor = None
            try:
                descriptor = self.generate_descriptor_for_uri(uri)
                message = "descriptor computed"
            except ValueError as ex:
                message = "Input data issue: %s" % str(ex)
            except RuntimeError as ex:
                message = "Descriptor generation failure: %s" % str(ex)

            # Base pagination slicing based on provided start and end indices,
            # otherwise clamp to beginning/ending of queried neighbor sequence.
            page_slice = slice(start_i or 0, end_i or n)
            neighbors = []
            dists = []
            # Success requires the neighbor query itself to succeed, not just
            # descriptor generation.
            success = False
            if descriptor is not None:
                try:
                    neighbors, dists = \
                        self.nn_index.nn(descriptor, n)
                    success = True
                except ValueError as ex:
                    message = "Descriptor or index related issue: %s" % str(ex)

            # TODO: Return the optional descriptor vectors for the neighbors
            # noinspection PyTypeChecker
            d = {
                "success": success,
                "message": message,
                # Named ``neighbor`` (not ``n``) so a Python 2 comprehension
                # cannot leak into and rebind the ``n`` parameter.
                "neighbors": [neighbor.uuid()
                              for neighbor in neighbors[page_slice]],
                "distances": dists[page_slice],
                "reference_uri": uri
            }
            return flask.jsonify(d)
示例#41
0
from smqtk.representation import DescriptorElementFactory
from smqtk.utils.bin_utils import logging, initialize_logging
from smqtk.utils.jsmin import jsmin

from load_algo import load_algo


# ``json`` is used below but is not among this script's imports above.
import json

if not logging.getLogger().handlers:
    initialize_logging(logging.getLogger(), logging.DEBUG)
log = logging.getLogger(__name__)


log.info("Loading descriptor elements")
# Read the small input files via context managers so their handles close.
with open("descriptor_type_name.txt") as type_file:
    d_type_str = type_file.read().strip()
with open('descriptor_factory_config.json') as config_file:
    # jsmin is applied before parsing, presumably so the config may contain
    # JS-style comments.
    df_config = json.loads(jsmin(config_file.read()))
factory = DescriptorElementFactory.from_config(df_config)

#
# Sample code for finding non-NaN descriptors in parallel
#
# def add_non_nan_uuid(uuid):
#     d = factory.new_descriptor(d_type_str, uuid)
#     if d.vector().sum() > 0:
#         return uuid
#     return None
#
# import multiprocessing
# p = multiprocessing.Pool()
# non_nan_uuids = \
#     p.map(add_non_nan_uuid,
#           (l.strip() for l in open('descriptor_uuids.txt')))
示例#42
0
文件: __init__.py 项目: Kitware/SMQTK
    def __init__(self, json_config):
        """
        Initialize application based on supplied JSON configuration.

        Builds the descriptor element factory and keeps the per-label
        descriptor generator configurations. It then registers these routes:

        - ``/``: list generator labels.
        - ``/all/content_types``: content types per generator.
        - ``/all/compute/<uri>``: compute with every applicable generator.
        - ``/<label>/<uri>``: compute with a single generator.

        :param json_config: JSON configuration dictionary
        :type json_config: dict

        """
        super(DescriptorServiceServer, self).__init__(json_config)

        # Descriptor factory setup
        self._log.info("Initializing DescriptorElementFactory")
        self.descr_elem_factory = DescriptorElementFactory.from_config(
            self.json_config['descriptor_factory']
        )

        # Descriptor generator configuration labels
        #: :type: dict[str, dict]
        self.generator_label_configs = self.json_config['descriptor_generators']

        # Cache of DescriptorGenerator instances so we don't have to
        # continuously initialize them as we get requests.
        self.descriptor_cache = {}
        self.descriptor_cache_lock = multiprocessing.RLock()

        @self.route("/")
        def list_ingest_labels():
            return flask.jsonify({
                "labels": sorted(self.generator_label_configs)
            })

        @self.route("/all/content_types")
        def all_content_types():
            """
            Of available descriptors, what content types are processable, and
            what types are associated to which available descriptor generator.
            """
            all_types = set()
            # Mapping of configuration label to content types that generator
            # can handle
            r = {}
            for label in self.generator_label_configs:
                d = self.get_descriptor_inst(label)
                all_types.update(d.valid_content_types())
                r[label] = sorted(d.valid_content_types())

            return flask.jsonify({
                "all": sorted(all_types),
                "by-label": r
            })

        @self.route("/all/compute/<path:uri>")
        def all_compute(uri):
            """
            Compute descriptors over the specified content for all generators
            that function over the data's content type.

            JSON Return format::

                {
                    "success": <bool>

                    "content_type": <str> | None  (None if the URI could not
                                                    be resolved)

                    "message": <str>

                    "descriptors": {  "<label>":  <list[float]|None>, ... }

                    "reference_uri": <str>
                }

            """
            message = "execution nominal"

            data_elem = None
            try:
                data_elem = self.resolve_data_element(uri)
            except ValueError as ex:
                message = "Failed URI resolution: %s" % str(ex)

            # Resolved once and reused below. Stays None when URI resolution
            # failed, so the response can still be built instead of raising
            # AttributeError on ``None.content_type()``.
            content_type = (data_elem.content_type()
                            if data_elem is not None else None)

            descriptors = {}
            finished_loop = False
            if data_elem:
                for label in self.generator_label_configs:
                    if content_type in \
                            self.get_descriptor_inst(label).valid_content_types():
                        d = None
                        try:
                            d = self.generate_descriptor(data_elem, label)
                        except RuntimeError as ex:
                            message = "Descriptor extraction failure: %s" \
                                      % str(ex)
                        except ValueError as ex:
                            message = "Data content type issue: %s" % str(ex)

                        descriptors[label] = \
                            d.vector().tolist() if d is not None else None
                if not descriptors:
                    message = "No descriptors can handle URI content type: %s" \
                              % content_type
                else:
                    finished_loop = True

            return flask.jsonify({
                "success": finished_loop,
                "content_type": content_type,
                "message": message,
                "descriptors": descriptors,
                "reference_uri": uri
            })

        @self.route("/<string:descriptor_label>/<path:uri>")
        def compute_descriptor(descriptor_label, uri):
            """

            Data modes for upload/use::

                - local filepath
                - base64
                - http/s URL

            The following sub-sections detail how different URI's can be used.

            Local Filepath
            --------------

            The URI string must be prefixed with ``file://``, followed by the
            full path to the data file to describe.

            Base 64 data
            ------------

            The URI string must be prefixed with "base64://", followed by the
            base64 encoded string. This mode also requires an additional
            ``?content_type=`` to provide data content type information. This
            mode saves the encoded data to temporary file for processing.

            HTTP/S address
            --------------

            This is the default mode when the URI prefix is none of the above.
            This uses the requests module to locally download a data file
            for processing.

            JSON Return format::

                {
                    "success": <bool>

                    "message": <str>

                    "descriptor": <None|list[float]>

                    "reference_uri": <str>
                }

            :type descriptor_label: str
            :type uri: str

            """
            message = "execution nominal"
            descriptor = None

            de = None
            try:
                de = self.resolve_data_element(uri)
            except ValueError as ex:
                message = "URI resolution issue: %s" % str(ex)

            if de:
                try:
                    descriptor = self.generate_descriptor(de, descriptor_label)
                except RuntimeError as ex:
                    message = "Descriptor extraction failure: %s" % str(ex)
                except ValueError as ex:
                    message = "Data content type issue: %s" % str(ex)

            return flask.jsonify({
                "success": descriptor is not None,
                "message": message,
                # Conditional expression rather than ``and/or`` so an empty
                # vector is reported as [] instead of being coerced to None.
                "descriptor": (descriptor.vector().tolist()
                               if descriptor is not None else None),
                "reference_uri": uri
            })
示例#43
0
def main():
    """
    Compute a descriptor vector for a single input file.

    Steps:

    1. Load the JSON configuration given on the command line. It is merged
       shallowly (top-level keys) over ``default_config()``.
    2. Pass the effective configuration to ``bin_utils.output_config``.
    3. Compute a descriptor for ``args.input_file`` using the configured
       descriptor factory and ``content_descriptor`` generator.
    4. Save the vector with ``numpy.save`` when an output path is given;
       otherwise print it as space-separated ``%15f`` values.

    Exits with status 1 in any of these cases:

    - the configuration path is invalid or missing;
    - the input file is missing;
    - no descriptor vector could be generated.
    """
    parser = cli_parser()
    args = parser.parse_args()

    output_filepath = args.output_filepath
    overwrite = args.overwrite
    verbose = args.verbose

    llevel = logging.DEBUG if verbose else logging.INFO
    bin_utils.initialize_logging(logging.getLogger(), llevel)
    log = logging.getLogger("main")

    # Merge loaded config with default (shallow: top-level keys only)
    config_loaded = False
    config = default_config()
    if args.config:
        if os.path.isfile(args.config):
            with open(args.config, 'r') as f:
                config.update(json.load(f))
            config_loaded = True
        else:
            log.error("Configuration file path not valid.")
            exit(1)

    bin_utils.output_config(args.output_config, config, log, True)

    # Configuration must have been loaded at this point since we can't normally
    # trust the default.
    if not config_loaded:
        log.error("No configuration provided")
        exit(1)

    if not args.input_file:
        log.error("Failed to provide an input file path")
        exit(1)
    elif not os.path.isfile(args.input_file):
        log.error("Given path does not point to a file.")
        exit(1)

    input_filepath = args.input_file
    data_element = DataFileElement(input_filepath)

    factory = DescriptorElementFactory.from_config(config['descriptor_factory'])
    #: :type: smqtk.algorithms.descriptor_generator.DescriptorGenerator
    cd = plugin.from_plugin_config(config['content_descriptor'],
                                   get_descriptor_generator_impls())
    descr_elem = cd.compute_descriptor(data_element, factory, overwrite)
    vec = descr_elem.vector()

    if vec is None:
        log.error("Failed to generate a descriptor vector for the input data!")
        # Nothing to save or print; continuing would fail on ``None``.
        exit(1)

    if output_filepath:
        numpy.save(output_filepath, vec)
    else:
        # Construct string, because numpy's own repr is not a plain listing.
        # print(...) with one argument behaves the same in Python 2 and 3.
        # noinspection PyTypeChecker
        print(' '.join('%15f' % f for f in vec))
示例#44
0
def classify_files(config, label, file_globs):
    """
    Print the paths of input files whose maximum classification label is
    ``label``.

    If ``label`` is None, or is not one of the classifier's labels, the
    available labels are logged and nothing is classified.

    :param config: Configuration dictionary with ``classifier``,
        ``descriptor_factory``, ``descriptor_generator`` and
        ``classification_factory`` sections.
    :type config: dict
    :param label: Classification label to filter on, or None to list the
        available labels.
    :type label: str | None
    :param file_globs: File paths and/or glob patterns of files to classify.
    :type file_globs: collections.Iterable[str]

    :raises RuntimeError: No files were found for the given paths/globs.

    """
    log = logging.getLogger(__name__)

    #: :type: smqtk.algorithms.Classifier
    classifier = \
        plugin.from_plugin_config(config['classifier'],
                                  get_classifier_impls())

    def log_available_labels():
        """Log each label the configured classifier can produce."""
        log.info("Available classifier labels:")
        for l in classifier.get_labels():
            log.info("- %s", l)

    if label is None:
        log_available_labels()
        return
    elif label not in classifier.get_labels():
        log.error(
            "Invalid classification label provided to compute and filter "
            "on: '%s'", label)
        log_available_labels()
        return

    log.info("Collecting files from globs")
    #: :type: list[DataFileElement]
    data_elements = []
    uuid2filepath = {}
    for g in file_globs:
        # Existing file paths are used directly; anything else is treated as
        # a glob pattern.
        if os.path.isfile(g):
            d = DataFileElement(g)
            data_elements.append(d)
            uuid2filepath[d.uuid()] = g
        else:
            log.debug("expanding glob: %s", g)
            for fp in glob.iglob(g):
                d = DataFileElement(fp)
                data_elements.append(d)
                uuid2filepath[d.uuid()] = fp
    if not data_elements:
        raise RuntimeError("No files provided for classification.")

    log.info("Computing descriptors")
    descriptor_factory = \
        DescriptorElementFactory.from_config(config['descriptor_factory'])
    #: :type: smqtk.algorithms.DescriptorGenerator
    descriptor_generator = \
        plugin.from_plugin_config(config['descriptor_generator'],
                                  get_descriptor_generator_impls())
    descr_map = descriptor_generator\
        .compute_descriptor_async(data_elements, descriptor_factory)

    log.info("Classifying descriptors")
    classification_factory = ClassificationElementFactory \
        .from_config(config['classification_factory'])
    classification_map = classifier\
        .classify_async(descr_map.values(), classification_factory)

    log.info("Printing input file paths that classified as the given label.")
    # Map of classification element UUID to classification element. The
    # lookup below uses data UUIDs, so these are expected to coincide.
    uuid2c = {c.uuid: c for c in classification_map.values()}
    for data in data_elements:
        if uuid2c[data.uuid()].max_label() == label:
            print(uuid2filepath[data.uuid()])
示例#45
0
        def test_no_save_model_pickle(self):
            """
            A LibSvmClassifier with no model cache file paths must survive a
            pickle round-trip. Before training, the pickled state carries no
            local model payload. After training, the model travels inside the
            pickle state, and the restored classifier classifies identically.
            """
            svm_params = {
                '-t': 0,  # linear kernel
                '-b': 1,  # enable probability estimates
                '-c': 2,  # SVM-C parameter C
                '-q': '',  # quiet mode
            }
            classifier = LibSvmClassifier(
                train_params=svm_params,
                normalize=None,  # DO NOT normalize descriptors
            )
            ntools.assert_true(classifier.svm_model is None)
            # Untrained: no __LOCAL__ content expected in the pickle state.
            ntools.assert_not_in('__LOCAL__', classifier.__getstate__())
            cPickle.loads(cPickle.dumps(classifier))

            # Train an arbitrary model (same setup as
            # ``test_simple_classification``).
            dim, count = 2, 1000
            pos_label, neg_label = 'positive', 'negative'
            descr_factory = DescriptorElementFactory(DescriptorMemoryElement, {})
            class_factory = ClassificationElementFactory(
                MemoryClassificationElement, {})

            def to_descriptor(pair):
                uid, vec = pair
                elem = descr_factory.new_descriptor('test', uid)
                elem.set_vector(vec)
                return elem

            # Random 2-D points split around y=0.5 with a 0.1 gap.
            points = numpy.random.rand(count, dim)
            pos_points = points[points[:, 1] <= 0.45]
            neg_points = points[points[:, 1] >= 0.55]
            pool = multiprocessing.pool.ThreadPool()
            pos_descrs = pool.map(to_descriptor, enumerate(pos_points))
            neg_descrs = pool.map(to_descriptor,
                                  enumerate(neg_points, start=count // 2))
            pool.close()
            pool.join()

            classifier.train({pos_label: pos_descrs, neg_label: neg_descrs})

            # Reference classification from the original classifier.
            query_vec = numpy.random.rand(dim)
            query = descr_factory.new_descriptor('query', 0)
            query.set_vector(query_vec)
            expected = classifier.classify(query, class_factory)

            # Trained: the pickle state must now carry the local model.
            state = classifier.__getstate__()
            for key in ('__LOCAL__', '__LOCAL_LABELS__', '__LOCAL_MODEL__'):
                ntools.assert_in(key, state)
            ntools.assert_true(len(state['__LOCAL_LABELS__']) > 0)
            ntools.assert_true(len(state['__LOCAL_MODEL__']) > 0)

            # The restored classifier must classify the query the same way,
            # up to floating point noise (compared to 5 decimal places).
            #: :type: LibSvmClassifier
            restored = cPickle.loads(cPickle.dumps(classifier))
            after = restored.classify(query, class_factory)
            labels = (pos_label, neg_label)
            post = [after[lbl] for lbl in labels]
            pre = [expected[lbl] for lbl in labels]
            for pre_conf, post_conf in zip(pre, post):
                ntools.assert_almost_equal(pre_conf, post_conf, 5)
示例#46
0
        def test_simple_classification(self):
            """
            simple LibSvmClassifier test - 2-class

            Test libSVM classification functionality using random constructed
            data, training the y=0.5 split. Synchronous classification must
            label every test descriptor correctly. Async classification
            (threaded and multiprocess) must reproduce the synchronous
            results exactly.
            """
            DIM = 2
            N = 1000
            POS_LABEL = 'positive'
            NEG_LABEL = 'negative'
            p = multiprocessing.pool.ThreadPool()
            d_factory = DescriptorElementFactory(DescriptorMemoryElement, {})
            c_factory = ClassificationElementFactory(
                MemoryClassificationElement, {})

            def make_element(argtup):
                (i, v) = argtup
                d = d_factory.new_descriptor('test', i)
                d.set_vector(v)
                return d

            # Constructing artificial descriptors
            x = numpy.random.rand(N, DIM)
            x_pos = x[x[:, 1] <= 0.45]
            x_neg = x[x[:, 1] >= 0.55]

            d_pos = p.map(make_element, enumerate(x_pos))
            d_neg = p.map(make_element, enumerate(x_neg, start=N // 2))

            # Create/Train test classifier
            classifier = LibSvmClassifier(
                train_params={
                    '-t': 0,  # linear kernel
                    '-b': 1,  # enable probability estimates
                    '-c': 2,  # SVM-C parameter C
                    '-q': '',  # quiet mode
                },
                normalize=None,  # DO NOT normalize descriptors
            )
            classifier.train({POS_LABEL: d_pos, NEG_LABEL: d_neg})

            # Test classifier
            x = numpy.random.rand(N, DIM)
            x_pos = x[x[:, 1] <= 0.45]
            x_neg = x[x[:, 1] >= 0.55]

            # UUID offsets keep test descriptors distinct from training ones.
            d_pos = p.map(make_element, enumerate(x_pos, N))
            d_neg = p.map(make_element, enumerate(x_neg, N + N // 2))

            # Element construction is done. Close the pool now so a failing
            # assertion below cannot leave its worker threads running.
            p.close()
            p.join()

            d_pos_sync = {}  # for comparing to async
            for d in d_pos:
                c = classifier.classify(d, c_factory)
                ntools.assert_equal(
                    c.max_label(), POS_LABEL,
                    "Found False positive: %s :: %s" %
                    (d.vector(), c.get_classification()))
                d_pos_sync[d] = c

            d_neg_sync = {}
            for d in d_neg:
                c = classifier.classify(d, c_factory)
                ntools.assert_equal(
                    c.max_label(), NEG_LABEL,
                    "Found False negative: %s :: %s" %
                    (d.vector(), c.get_classification()))
                d_neg_sync[d] = c

            # test that async classify produces the same results
            # -- d_pos
            m_pos = classifier.classify_async(d_pos, c_factory)
            ntools.assert_equal(
                m_pos, d_pos_sync,
                "Async computation of pos set did not yield "
                "the same results as synchronous "
                "classification.")
            # -- d_neg
            m_neg = classifier.classify_async(d_neg, c_factory)
            ntools.assert_equal(
                m_neg, d_neg_sync,
                "Async computation of neg set did not yield "
                "the same results as synchronous "
                "classification.")
            # -- combined -- threaded
            combined_truth = dict(d_pos_sync.items())
            combined_truth.update(d_neg_sync)
            m_combined = classifier.classify_async(
                d_pos + d_neg,
                c_factory,
                use_multiprocessing=False,
            )
            ntools.assert_equal(
                m_combined, combined_truth,
                "Async computation of all test descriptors "
                "did not yield the same results as "
                "synchronous classification.")
            # -- combined -- multiprocess
            m_combined = classifier.classify_async(
                d_pos + d_neg,
                c_factory,
                use_multiprocessing=True,
            )
            ntools.assert_equal(
                m_combined, combined_truth,
                "Async computation of all test descriptors "
                "(mixed order) did not yield the same results "
                "as synchronous classification.")
示例#47
0
import abc
import numpy

from smqtk.algorithms import SmqtkAlgorithm
from smqtk.representation import DescriptorElementFactory
from smqtk.representation.descriptor_element.local_elements import \
    DescriptorMemoryElement
from smqtk.utils import ContentTypeValidator
from smqtk.utils.parallel import parallel_map


# Default descriptor element factory for DescriptorMemoryElement with an empty
# type configuration. It is used as the default ``descr_factory`` argument of
# ``DescriptorGenerator.compute_descriptor``. This single module-level
# instance is shared by every call that relies on that default.
DFLT_DESCRIPTOR_FACTORY = DescriptorElementFactory(DescriptorMemoryElement, {})


class DescriptorGenerator (SmqtkAlgorithm, ContentTypeValidator):
    """
    Base abstract Feature Descriptor interface
    """

    def compute_descriptor(self, data, descr_factory=DFLT_DESCRIPTOR_FACTORY,
                           overwrite=False):
        """
        Given some data, return a descriptor element containing a descriptor
        vector.

        :raises RuntimeError: Descriptor extraction failure of some kind.
        :raises ValueError: Given data element content was not of a valid type
            with respect to this descriptor.

        :param data: Some kind of input data for the feature descriptor.
        :type data: smqtk.representation.DataElement
示例#48
0
        def test_simple_multiclass_classification(self):
            """
            simple LibSvmClassifier test - 3-class

            Test libSVM classification functionality using randomly
            constructed 2-D data. Points are split into three classes by
            their second coordinate: p1 (<= 0.30), p2 (0.36 to 0.63) and
            p3 (>= 0.69). This leaves unlabeled margins around the y=0.33 and
            y=0.66 boundaries. Synchronous and asynchronous (threaded and
            multiprocess) classification must agree.
            """
            DIM = 2
            N = 1000
            P1_LABEL = 'p1'
            P2_LABEL = 'p2'
            P3_LABEL = 'p3'
            p = multiprocessing.pool.ThreadPool()
            try:
                d_factory = DescriptorElementFactory(DescriptorMemoryElement,
                                                     {})
                c_factory = ClassificationElementFactory(
                    MemoryClassificationElement, {})

                def make_element(argtup):
                    (i, v) = argtup
                    d = d_factory.new_descriptor('test', i)
                    d.set_vector(v)
                    return d

                def make_class_sets(x, start_i):
                    """
                    Split vectors ``x`` into the p1/p2/p3 bands and wrap each
                    band as descriptor elements with consecutive UUIDs
                    starting at ``start_i``.

                    :return: (p1 elements, p2 elements, p3 elements,
                        next unused UUID)
                    """
                    bands = (x[x[:, 1] <= 0.30],
                             x[(x[:, 1] >= 0.36) & (x[:, 1] <= 0.63)],
                             x[x[:, 1] >= 0.69])
                    sets = []
                    for band in bands:
                        elements = p.map(make_element,
                                         enumerate(band, start_i))
                        start_i += len(elements)
                        sets.append(elements)
                    return sets[0], sets[1], sets[2], start_i

                # Constructing artificial descriptors. UUIDs keep counting
                # across the training and test sets so they never collide.
                d_p1, d_p2, d_p3, di = make_class_sets(
                    numpy.random.rand(N, DIM), 0)

                # Create/Train test classifier
                classifier = LibSvmClassifier(
                    train_params={
                        '-t': 0,  # linear kernel
                        '-b': 1,  # enable probability estimates
                        '-c': 2,  # SVM-C parameter C
                        '-q': ''  # quiet mode
                    },
                    normalize=None,  # DO NOT normalize descriptors
                )
                classifier.train(
                    {P1_LABEL: d_p1, P2_LABEL: d_p2, P3_LABEL: d_p3})

                # Test classifier on freshly generated data
                d_p1, d_p2, d_p3, _ = make_class_sets(
                    numpy.random.rand(N, DIM), di)

                def classify_sync(elements, label):
                    """
                    Classify each element one at a time, asserting that
                    ``label`` is the top label, and return the
                    element -> classification mapping.
                    """
                    results = {}
                    for d in elements:
                        c = classifier.classify(d, c_factory)
                        ntools.assert_equal(
                            c.max_label(), label,
                            "Incorrect %s label: %s :: %s" %
                            (label, d.vector(), c.get_classification()))
                        results[d] = c
                    return results

                d_p1_sync = classify_sync(d_p1, P1_LABEL)
                d_p2_sync = classify_sync(d_p2, P2_LABEL)
                d_p3_sync = classify_sync(d_p3, P3_LABEL)

                # test that async classify produces the same results
                # -- p1
                async_p1 = classifier.classify_async(d_p1, c_factory)
                ntools.assert_equal(
                    async_p1, d_p1_sync,
                    "Async computation of p1 set did not yield "
                    "the same results as synchronous computation.")
                # -- p2
                async_p2 = classifier.classify_async(d_p2, c_factory)
                ntools.assert_equal(
                    async_p2, d_p2_sync,
                    "Async computation of p2 set did not yield "
                    "the same results as synchronous computation.")
                # -- p3
                async_p3 = classifier.classify_async(d_p3, c_factory)
                ntools.assert_equal(
                    async_p3, d_p3_sync,
                    "Async computation of p3 set did not yield "
                    "the same results as synchronous computation.")
                # -- combined -- threaded
                sync_combined = dict(d_p1_sync.items())
                sync_combined.update(d_p2_sync)
                sync_combined.update(d_p3_sync)
                async_combined = classifier.classify_async(
                    d_p1 + d_p2 + d_p3, c_factory, use_multiprocessing=False)
                ntools.assert_equal(
                    async_combined, sync_combined,
                    "Async computation of all test descriptors "
                    "did not yield the same results as "
                    "synchronous classification.")
                # -- combined -- multiprocess
                async_combined = classifier.classify_async(
                    d_p1 + d_p2 + d_p3, c_factory, use_multiprocessing=True)
                ntools.assert_equal(
                    async_combined, sync_combined,
                    "Async computation of all test descriptors "
                    "(mixed order) did not yield the same results "
                    "as synchronous classification.")
            finally:
                # Release the thread pool even when an assertion above fails.
                p.close()
                p.join()
示例#49
0
def classify_files(config, label, file_globs):
    """
    Classify files and print the paths of those whose top label is ``label``.

    :param config: Configuration dictionary with "classifier",
        "descriptor_factory", "descriptor_generator" and
        "classification_factory" sections.
    :type config: dict
    :param label: Classifier label to filter on. When ``None``, the available
        labels are logged and nothing is classified. When the label is not
        known to the classifier, an error and the available labels are logged
        and nothing is classified.
    :type label: str | None
    :param file_globs: File paths or glob patterns naming the input files.
    :type file_globs: collections.Iterable[str]

    :raises RuntimeError: No files were matched by ``file_globs``.
    """
    log = logging.getLogger(__name__)

    #: :type: smqtk.algorithms.Classifier
    classifier = \
        plugin.from_plugin_config(config['classifier'],
                                  get_classifier_impls())

    def log_available_labels():
        log.info("Available classifier labels:")
        for l in classifier.get_labels():
            log.info("- %s", l)

    if label is None:
        log_available_labels()
        return
    elif label not in classifier.get_labels():
        log.error("Invalid classification label provided to compute and filter "
                  "on: '%s'", label)
        log_available_labels()
        return

    log.info("Collecting files from globs")
    # Keep each data element paired with the path it was loaded from. A
    # UUID -> path dict would report the wrong path (and drop paths) if two
    # files ever share a UUID.
    # NOTE(review): DataFileElement UUIDs are presumably content-derived, so
    # identical files would collide; confirm.
    #: :type: list[(DataFileElement, str)]
    elements_and_paths = []
    for g in file_globs:
        if os.path.isfile(g):
            paths = [g]
        else:
            log.debug("expanding glob: %s", g)
            paths = glob.iglob(g)
        for fp in paths:
            elements_and_paths.append((DataFileElement(fp), fp))
    if not elements_and_paths:
        raise RuntimeError("No files provided for classification.")
    #: :type: list[DataFileElement]
    data_elements = [d for d, _ in elements_and_paths]

    log.info("Computing descriptors")
    descriptor_factory = \
        DescriptorElementFactory.from_config(config['descriptor_factory'])
    #: :type: smqtk.algorithms.DescriptorGenerator
    descriptor_generator = \
        plugin.from_plugin_config(config['descriptor_generator'],
                                  get_descriptor_generator_impls())
    descr_map = descriptor_generator\
        .compute_descriptor_async(data_elements, descriptor_factory)

    log.info("Classifying descriptors")
    classification_factory = ClassificationElementFactory \
        .from_config(config['classification_factory'])
    classification_map = classifier\
        .classify_async(list(descr_map.values()), classification_factory)

    log.info("Printing input file paths that classified as the given label.")
    # Map of UUID to classification element. Classification elements are
    # matched back to input data via the data element's UUID.
    uuid2c = {c.uuid: c for c in six.itervalues(classification_map)}
    for data, filepath in elements_and_paths:
        c_elem = uuid2c[data.uuid()]
        log.debug("'%s' classification map: %s",
                  filepath, c_elem.get_classification())
        if c_elem.max_label() == label:
            print(filepath)
示例#50
0
    def __init__(self, json_config):
        """
        Initialize application based on supplied JSON configuration

        Sets up the descriptor element factory, the per-label descriptor
        generator configurations and a cache of generator instances (with an
        accompanying RLock). It then registers the Flask routes served by this
        application.

        :param json_config: JSON configuration dictionary
        :type json_config: dict

        """
        super(DescriptorServiceServer, self).__init__(json_config)

        # Descriptor factory setup
        self._log.info("Initializing DescriptorElementFactory")
        self.descr_elem_factory = DescriptorElementFactory.from_config(
            self.json_config['descriptor_factory'])

        # Descriptor generator configuration labels
        #: :type: dict[str, dict]
        self.generator_label_configs = self.json_config[
            'descriptor_generators']

        # Cache of DescriptorGenerator instances so we don't have to
        # continuously initialize them as we get requests.
        self.descriptor_cache = {}
        self.descriptor_cache_lock = multiprocessing.RLock()

        @self.route("/")
        def list_ingest_labels():
            """Return the sorted list of configured generator labels."""
            return flask.jsonify(
                {"labels": sorted(self.generator_label_configs)})

        @self.route("/all/content_types")
        def all_content_types():
            """
            Of available descriptors, what content types are processable, and
            what types are associated to which available descriptor generator.
            """
            all_types = set()
            # Mapping of configuration label to content types that generator
            # can handle
            r = {}
            for l in self.generator_label_configs:
                d = self.get_descriptor_inst(l)
                all_types.update(d.valid_content_types())
                r[l] = sorted(d.valid_content_types())

            return flask.jsonify({"all": sorted(all_types), "by-label": r})

        @self.route("/all/compute/<path:uri>")
        def all_compute(uri):
            """
            Compute descriptors over the specified content for all generators
            that function over the data's content type.

            JSON Return format::

                {
                    "success": <bool>

                    "content_type": <str|None>

                    "message": <str>

                    "descriptors": {  "<label>":  <list[float]>, ... } | None

                    "reference_uri": <str>
                }

            ``content_type`` is None when the URI could not be resolved.

            """
            message = "execution nominal"

            data_elem = None
            try:
                data_elem = self.resolve_data_element(uri)
            except ValueError as ex:
                message = "Failed URI resolution: %s" % str(ex)

            # Stays None when the URI could not be resolved, so the response
            # below never dereferences a missing data element.
            content_type = None
            descriptors = {}
            finished_loop = False
            if data_elem:
                content_type = data_elem.content_type()
                for l in self.generator_label_configs:
                    if content_type in \
                            self.get_descriptor_inst(l).valid_content_types():
                        d = None
                        try:
                            d = self.generate_descriptor(data_elem, l)
                        except RuntimeError as ex:
                            message = "Descriptor extraction failure: %s" \
                                      % str(ex)
                        except ValueError as ex:
                            message = "Data content type issue: %s" % str(ex)

                        # None when extraction failed for this generator.
                        descriptors[l] = d and d.vector().tolist()
                if not descriptors:
                    message = "No descriptors can handle URI content type: %s" \
                              % content_type
                else:
                    finished_loop = True

            return flask.jsonify({
                "success": finished_loop,
                "content_type": content_type,
                "message": message,
                "descriptors": descriptors,
                "reference_uri": uri
            })

        @self.route("/<string:descriptor_label>/<path:uri>")
        def compute_descriptor(descriptor_label, uri):
            """

            Data modes for upload/use::

                - local filepath
                - base64
                - http/s URL

            The following sub-sections detail how different URI's can be used.

            Local Filepath
            --------------

            The URI string must be prefixed with ``file://``, followed by the
            full path to the data file to describe.

            Base 64 data
            ------------

            The URI string must be prefixed with "base64://", followed by the
            base64 encoded string. This mode also requires an additional
            ``?content_type=`` to provide data content type information. This
            mode saves the encoded data to temporary file for processing.

            HTTP/S address
            --------------

            This is the default mode when the URI prefix is none of the above.
            This uses the requests module to locally download a data file
            for processing.

            JSON Return format::

                {
                    "success": <bool>

                    "message": <str>

                    "descriptor": <None|list[float]>

                    "reference_uri": <str>
                }

            :type descriptor_label: str
            :type uri: str

            """
            message = "execution nominal"
            descriptor = None

            de = None
            try:
                de = self.resolve_data_element(uri)
            except ValueError as ex:
                message = "URI resolution issue: %s" % str(ex)

            if de:
                try:
                    descriptor = self.generate_descriptor(de, descriptor_label)
                except RuntimeError as ex:
                    message = "Descriptor extraction failure: %s" % str(ex)
                except ValueError as ex:
                    message = "Data content type issue: %s" % str(ex)

            vector = None
            if descriptor is not None:
                vector = descriptor.vector().tolist()

            return flask.jsonify({
                "success": descriptor is not None,
                "message": message,
                # An empty vector is reported as None, as before.
                "descriptor": vector or None,
                "reference_uri": uri,
            })
示例#51
0
    def __init__(self, json_config):
        """
        Initialize application based on supplied JSON configuration

        Sets up the descriptor element factory, the per-label descriptor
        generator configurations and a cache of generator instances (with an
        accompanying RLock). It then registers the Flask routes served by this
        application.

        :param json_config: JSON configuration dictionary
        :type json_config: dict

        """
        super(DescriptorServiceServer, self).__init__(json_config)

        # Descriptor factory setup
        self.log.info("Initializing DescriptorElementFactory")
        self.descr_elem_factory = DescriptorElementFactory.from_config(
            self.json_config['descriptor_factory']
        )

        # Descriptor generator configuration labels
        #: :type: dict[str, dict]
        self.generator_label_configs = self.json_config['descriptor_generators']

        # Cache of DescriptorGenerator instances so we don't have to continuously
        # initialize them as we get requests.
        self.descriptor_cache = {}
        self.descriptor_cache_lock = multiprocessing.RLock()

        @self.route("/")
        def list_ingest_labels():
            """Return the sorted list of configured generator labels."""
            # Iterating a dict yields its keys on both Python 2 and 3
            # (``dict.iterkeys`` only exists on Python 2).
            return flask.jsonify({
                "labels": sorted(self.generator_label_configs)
            })

        @self.route("/all/content_types")
        def all_content_types():
            """
            Of available descriptors, what content types are processable, and
            what types are associated to which available descriptor generator.
            """
            all_types = set()
            # Mapping of configuration label to content types that generator
            # can handle
            r = {}
            for l in self.generator_label_configs:
                d = self.get_descriptor_inst(l)
                all_types.update(d.valid_content_types())
                r[l] = sorted(d.valid_content_types())

            return flask.jsonify({
                "all": sorted(all_types),
                "by-label": r
            })

        @self.route("/all/compute/<path:uri>")
        def all_compute(uri):
            """
            Compute descriptors over the specified content for all generators
            that function over the data's content type.

            JSON Return format::

                {
                    "success": <bool>

                    "content_type": <str>

                    "message": <str>

                    "descriptors": {  "<label>":  <list[float]>, ... } | None

                    "reference_uri": <str>
                }

            """
            message = "execution nominal"

            data_elem = None
            try:
                data_elem = self.resolve_data_element(uri)
            except ValueError as ex:
                message = "Failed URI resolution: %s" % str(ex)

            descriptors = {}
            finished_loop = False
            if data_elem:
                for l in self.generator_label_configs:
                    if data_elem.content_type() \
                            in self.get_descriptor_inst(l).valid_content_types():
                        d = None
                        try:
                            d = self.generate_descriptor(data_elem, l)
                        except RuntimeError as ex:
                            message = "Descriptor extraction failure: %s" \
                                      % str(ex)
                        except ValueError as ex:
                            message = "Data content type issue: %s" % str(ex)

                        # None when extraction failed for this generator.
                        descriptors[l] = d and d.vector().tolist()
            # NOTE(review): as shown here, all_compute ends without building or
            # returning a response (and ``message``/``finished_loop`` are unused).
            # This snippet appears truncated; restore the remainder from the
            # full source.
示例#52
0
    def __init__(self, json_config):
        """
        Initialize application based on supplied JSON configuration

        Builds the descriptor element factory, the nearest-neighbors index and
        the descriptor generator from ``json_config``. It then registers the
        ``/nn/...`` Flask routes.

        :param json_config: JSON configuration dictionary
        :type json_config: dict

        """
        super(NearestNeighborServiceServer, self).__init__(json_config)

        # Descriptor factory setup
        self.log.info("Initializing DescriptorElementFactory")
        self.descr_elem_factory = DescriptorElementFactory.from_config(
            self.json_config['descriptor_factory']
        )

        # Descriptor generator plugin configuration
        #: :type: dict
        self.generator_config = self.json_config['descriptor_generator']

        #: :type: smqtk.algorithms.NearestNeighborsIndex
        self.nn_index = plugin.from_plugin_config(
            json_config['nn_index'],
            get_nn_index_impls
        )

        #: :type: smqtk.algorithms.DescriptorGenerator
        self.descriptor_generator_inst = plugin.from_plugin_config(
            self.generator_config,
            get_descriptor_generator_impls
        )

        @self.route("/nn/<path:uri>")
        @self.route("/nn/n=<int:n>/<path:uri>")
        @self.route("/nn/n=<int:n>/<int:start_i>:<int:end_i>/<path:uri>")
        def compute_nearest_neighbors(uri, n=10, start_i=None, end_i=None):
            """
            Data modes for upload/use::

                - local filepath
                - base64
                - http/s URL

            The following sub-sections detail how different URI's can be used.

            Local Filepath
            --------------

            The URI string must be prefixed with ``file://``, followed by the
            full path to the data file to describe.

            Base 64 data
            ------------

            The URI string must be prefixed with "base64://", followed by the
            base64 encoded string. This mode also requires an additional
            ``?content_type=`` to provide data content type information. This
            mode saves the encoded data to temporary file for processing.

            HTTP/S address
            --------------

            This is the default mode when the URI prefix is none of the above.
            This uses the requests module to locally download a data file
            for processing.

            JSON Return format::

                {
                    "success": <bool>

                    "message": <str>

                    "neighbors": <None|list[float]>

                    "reference_uri": <str>
                }

            :type uri: str

            """
            message = "execution nominal"
            descriptor = None

            de = None
            try:
                self.log.debug("Received URI: %s", uri)
                de = self.resolve_data_element(uri)
            except ValueError as ex:
                message = "URI resolution issue: %s" % str(ex)

            if de:
                try:
                    descriptor = self.descriptor_generator_inst.\
                        compute_descriptor(de, self.descr_elem_factory)
                except RuntimeError as ex:
                    message = "Descriptor extraction failure: %s" % str(ex)
                except ValueError as ex:
                    message = "Data content type issue: %s" % str(ex)
            # NOTE(review): as shown here, this route ends without querying
            # ``self.nn_index`` or returning a response (``n``, ``start_i``,
            # ``end_i``, ``message`` and ``descriptor`` are unused). This
            # snippet appears truncated; restore the remainder from the full
            # source.
示例#53
0
def default_config():
    """
    Build the default configuration dictionary for this script.

    :return: Dictionary with a "descriptor_generator" plugin configuration
        (built by ``make_config`` over the available descriptor generator
        implementations) and a "descriptor_factory" configuration (the
        ``DescriptorElementFactory`` default).
    :rtype: dict
    """
    generator_cfg = make_config(get_descriptor_generator_impls)
    factory_cfg = DescriptorElementFactory.get_default_config()
    return dict(descriptor_generator=generator_cfg,
                descriptor_factory=factory_cfg)
示例#54
0
def main():
    """
    Command-line entry point: compute a descriptor vector for one input file.

    Reads a JSON configuration (``-c``), builds the configured descriptor
    element factory and descriptor generator, and computes a descriptor for
    the first positional argument. The vector is written to ``-o`` (numpy
    ``.npy`` format) or printed to standard out as space-separated floats.

    Exits with status 1 when:

    - the configuration is missing or invalid;
    - no input path is given;
    - no descriptor vector could be generated.
    """
    usage = "%prog [OPTIONS] INPUT_FILE"
    description = """\
Compute a descriptor vector for a given data file, outputting the generated
feature vector to standard out, or to an output file if one was specified (in
numpy format).
"""
    parser = bin_utils.SMQTKOptParser(usage, description=description)

    group_labels = optparse.OptionGroup(parser, "Configuration")
    group_labels.add_option('-c', '--config',
                            default=None,
                            help='Path to the JSON configuration file.')
    group_labels.add_option('--output-config',
                            default=None,
                            help='Optional path to output default JSON '
                                 'configuration to.')
    parser.add_option_group(group_labels)

    group_optional = optparse.OptionGroup(parser, "Optional Parameters")
    group_optional.add_option('--overwrite',
                              action='store_true', default=False,
                              help="Force descriptor computation even if an "
                                   "existing descriptor vector was discovered "
                                   "based on the given content descriptor type "
                                   "and data combination.")
    group_optional.add_option('-o', '--output-filepath',
                              help='Optional path to a file to output feature '
                                   'vector to. Otherwise the feature vector is '
                                   'printed to standard out. Output is saved '
                                   'in numpy binary format (.npy suffix '
                                   'recommended).')
    group_optional.add_option('-v', '--verbose',
                              action='store_true', default=False,
                              help='Print additional debugging messages. All '
                                   'logging goes to standard error.')
    parser.add_option_group(group_optional)

    opts, args = parser.parse_args()

    output_filepath = opts.output_filepath
    overwrite = opts.overwrite
    verbose = opts.verbose

    llevel = logging.DEBUG if verbose else logging.INFO
    bin_utils.initialize_logging(logging.getLogger(), llevel)
    log = logging.getLogger("main")

    bin_utils.output_config(opts.output_config, default_config(), log)

    if not opts.config:
        log.error("No configuration provided")
        exit(1)
    elif not os.path.isfile(opts.config):
        log.error("Configuration file path not valid.")
        exit(1)

    if len(args) == 0:
        log.error("Failed to provide an input file path")
        exit(1)
    if len(args) > 1:
        log.warning("More than one filepath provided as an argument. Only "
                    "computing for the first one.")

    with open(opts.config, 'r') as config_file:
        config = json.load(config_file)

    input_filepath = args[0]
    data_element = DataFileElement(input_filepath)

    factory = DescriptorElementFactory.from_config(config['descriptor_factory'])
    # ``default_config()`` (and therefore ``--output-config``) names the
    # generator section "descriptor_generator". The older "content_descriptor"
    # key is still accepted so existing configuration files keep working.
    generator_config = config.get('descriptor_generator',
                                  config.get('content_descriptor'))
    if generator_config is None:
        log.error("Configuration has no 'descriptor_generator' section.")
        exit(1)
    #: :type: smqtk.descriptor_generator.DescriptorGenerator
    cd = plugin.from_plugin_config(generator_config,
                                   get_descriptor_generator_impls)
    descr_elem = cd.compute_descriptor(data_element, factory, overwrite)
    vec = descr_elem.vector()

    if vec is None:
        log.error("Failed to generate a descriptor vector for the input data!")
        # Nothing to save or print; continuing would fail on ``None``.
        exit(1)

    if output_filepath:
        numpy.save(output_filepath, vec)
    else:
        # Format each element explicitly instead of relying on numpy's array
        # string representation.
        print(' '.join('%15f' % v for v in vec))
示例#55
0
文件: __init__.py 项目: dhandeo/SMQTK
    def __init__(self, json_config):
        """
        Initialize application based of supplied JSON configuration

        :param json_config: JSON configuration dictionary
        :type json_config: dict

        """
        super(NearestNeighborServiceServer, self).__init__(json_config)

        self.update_index = json_config['update_descriptor_index']

        # Descriptor factory setup
        self._log.info("Initializing DescriptorElementFactory")
        self.descr_elem_factory = DescriptorElementFactory.from_config(
            self.json_config['descriptor_factory']
        )

        #: :type: smqtk.representation.DescriptorIndex | None
        self.descr_index = None
        if self.update_index:
            self._log.info("Initializing DescriptorIndex to update")
            #: :type: smqtk.representation.DescriptorIndex | None
            self.descr_index = plugin.from_plugin_config(
                json_config['descriptor_index'],
                get_descriptor_index_impls()
            )

        #: :type: smqtk.algorithms.NearestNeighborsIndex
        self.nn_index = plugin.from_plugin_config(
            json_config['nn_index'],
            get_nn_index_impls()
        )

        #: :type: smqtk.algorithms.DescriptorGenerator
        self.descriptor_generator_inst = plugin.from_plugin_config(
            self.json_config['descriptor_generator'],
            get_descriptor_generator_impls()
        )

        @self.route("/count", methods=['GET'])
        def count():
            """
            Return the number of elements represented in this index.
            """
            return flask.jsonify(**{
                "count": self.nn_index.count(),
            })

        @self.route("/compute/<path:uri>", methods=["POST"])
        def compute(uri):
            """
            Compute the descriptor for a URI specified data element using the
            configured descriptor generator.

            If the a descriptor index was configured and update was turned on,
            we add the computed descriptor to the index.

            JSON Return format::
                {
                    "success": <bool>

                    "message": <str>

                    "descriptor": <None|list[float]>

                    "reference_uri": <str>
                }

            :param uri: URI data specification.

            """
            descriptor = None
            try:
                _, descriptor = self.generate_descriptor_for_uri(uri)
                message = "Descriptor generated"
                descriptor = map(float, descriptor.vector())
            except ValueError, ex:
                message = "Input value issue: %s" % str(ex)
            except RuntimeError, ex:
                message = "Descriptor extraction failure: %s" % str(ex)