def get_default_config(cls):
    """
    Build the default configuration dictionary for this class.

    The parent class's default configuration is used as the starting point.
    On top of it, plugin configuration blocks are filled in for the
    ``lsh_functor``, ``descriptor_index`` and ``hash_index`` entries, and a
    ``hash_index_comment`` note is added for readers of a generated
    configuration.

    The returned dictionary is meant to hold JSON-compliant values only. It
    is not guaranteed to be directly valid for constructing an instance of
    this class.

    :return: Default configuration dictionary for the class.
    :rtype: dict

    """
    config = super(LSHNearestNeighborIndex, cls).get_default_config()

    config['lsh_functor'] = plugin.make_config(get_lsh_functor_impls())
    config['descriptor_index'] = \
        plugin.make_config(get_descriptor_index_impls())
    config['hash_index'] = plugin.make_config(get_hash_index_impls())
    # Informational note for whoever reads the generated configuration.
    config['hash_index_comment'] = (
        "'hash_index' may also be null to "
        "default to a linear index built at "
        "query time."
    )

    return config
def default_config():
    """
    Return the default configuration dictionary for this utility.

    The dictionary has four sections: ``plugins`` (plugin configuration
    blocks), ``cross_validation``, ``pr_curves`` and ``roc_curves``.

    :return: Default configuration dictionary.
    :rtype: dict

    """
    def curve_section():
        # A fresh dictionary per call so the two curve sections never share
        # state.
        return {
            "enabled": True,
            "show": False,
            "output_directory": None,
            "file_prefix": None,
        }

    plugins_section = {
        "supervised_classifier":
            plugin.make_config(get_supervised_classifier_impls()),
        "descriptor_index":
            plugin.make_config(get_descriptor_index_impls()),
        "classification_factory":
            ClassificationElementFactory.get_default_config(),
    }
    cross_validation_section = {
        "truth_labels": None,
        "num_folds": 6,
        "random_seed": None,
        "classification_use_multiprocessing": True,
    }

    return {
        "plugins": plugins_section,
        "cross_validation": cross_validation_section,
        "pr_curves": curve_section(),
        "roc_curves": curve_section(),
    }
def from_config(cls, config_dict, merge_default=True):
    """
    Instantiate a new instance of this class from a JSON-compliant
    configuration dictionary encapsulating initialization arguments.

    The ``descriptor_set`` entry is a nested plugin configuration. It is
    turned into an instance with ``plugin.from_plugin_config`` before the
    configuration is handed to the parent class's ``from_config``.

    This method should not be called via super unless an instance of the
    class is desired.

    :param config_dict: JSON compliant dictionary encapsulating
        a configuration.
    :type config_dict: dict

    :param merge_default: Merge the given configuration on top of the
        default provided by ``get_default_config``.
    :type merge_default: bool

    :return: Constructed instance from the provided config.
    :rtype: MRPTNearestNeighborsIndex

    """
    if merge_default:
        cfg = cls.get_default_config()
        merge_dict(cfg, config_dict)
    else:
        # Work on a shallow copy: the key replacement below must not leak
        # a constructed instance into the caller's dictionary.
        cfg = dict(config_dict)

    cfg['descriptor_set'] = plugin.from_plugin_config(
        cfg['descriptor_set'], get_descriptor_index_impls()
    )

    # Defaults (if requested) were merged above, so do not merge again.
    return super(MRPTNearestNeighborsIndex, cls).from_config(cfg, False)
def main():
    """
    Fit the ITQ functor model on descriptors from the configured index.

    If the ``uuids_list_filepath`` configuration value names an existing
    file, only the descriptors for the UUIDs listed in it (one per line)
    are used. Otherwise the whole descriptor index is passed to the
    functor's ``fit`` method.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    uuids_list_filepath = config['uuids_list_filepath']

    log.info("Initializing ITQ functor")
    #: :type: smqtk.algorithms.nn_index.lsh.functors.itq.ItqFunctor
    functor = ItqFunctor.from_config(config['itq_config'])

    log.info("Initializing DescriptorIndex [type=%s]",
             config['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['descriptor_index'],
        get_descriptor_index_impls(),
    )

    if uuids_list_filepath and os.path.isfile(uuids_list_filepath):
        def uuids_iter():
            with open(uuids_list_filepath) as f:
                for line in f:
                    uuid = line.strip()
                    # Blank lines would otherwise be passed on as
                    # empty-string UUIDs.
                    if uuid:
                        yield uuid

        log.info("Loading UUIDs list from file: %s", uuids_list_filepath)
        d_iter = descriptor_index.get_many_descriptors(uuids_iter())
    else:
        if uuids_list_filepath:
            # A path was configured but cannot be used; say so instead of
            # silently training on everything.
            log.warning("UUIDs list path is not an existing file, "
                        "ignoring it: %s", uuids_list_filepath)
        log.info("Using UUIDs from loaded DescriptorIndex (count=%d)",
                 len(descriptor_index))
        d_iter = descriptor_index

    log.info("Fitting ITQ model")
    functor.fit(d_iter)
    log.info("Done")
def get_default_config():
    """
    Return the default configuration dictionary for this utility.

    The only section is ``plugins``, holding plugin configuration blocks
    for ``descriptor_set`` and ``nn_index``.

    :return: Default configuration dictionary.
    :rtype: dict

    """
    plugin_defaults = {}
    plugin_defaults['descriptor_set'] = \
        plugin.make_config(get_descriptor_index_impls())
    plugin_defaults['nn_index'] = plugin.make_config(get_nn_index_impls())
    return {'plugins': plugin_defaults}
def get_default_config(cls):
    """
    Build the default configuration dictionary for this class.

    The parent class's default configuration is used as the starting point,
    and the ``descriptor_set`` entry is replaced by a plugin configuration
    block generated from the available descriptor index implementations.

    The returned dictionary is meant to hold JSON-compliant values only. It
    is not guaranteed to be directly valid for constructing an instance of
    this class.

    :return: Default configuration dictionary for the class.
    :rtype: dict

    """
    config = super(MRPTNearestNeighborsIndex, cls).get_default_config()
    config['descriptor_set'] = \
        plugin.make_config(get_descriptor_index_impls())
    return config
def default_config():
    """
    Return the default configuration dictionary for this utility.

    The only section is ``plugins``, holding the plugin configuration block
    for ``descriptor_index``.

    :return: Default configuration dictionary.
    :rtype: dict

    """
    descriptor_index_config = \
        plugin.make_config(get_descriptor_index_impls())
    return {'plugins': {'descriptor_index': descriptor_index_config}}
def default_config():
    """
    Return the default configuration dictionary for this utility.

    Sections: ``plugins`` (plugin configuration blocks), ``utility``
    (training flag, input CSV path and output/plot options) and
    ``parallelism`` (core counts).

    :return: Default configuration dictionary.
    :rtype: dict

    """
    plugins_section = {
        'classifier': plugin.make_config(get_classifier_impls()),
        'classification_factory':
            ClassificationElementFactory.get_default_config(),
        'descriptor_index':
            plugin.make_config(get_descriptor_index_impls()),
    }

    utility_section = {
        'train': False,
        # Placeholder text is reproduced exactly as it has always been
        # emitted, spelling included.
        'csv_filepath': 'CHAMGEME :: PATH :: a csv file',
        'output_plot_pr': None,
        'output_plot_roc': None,
        'output_plot_confusion_matrix': None,
        'output_uuid_confusion_matrix': None,
        'curve_confidence_interval': False,
        'curve_confidence_interval_alpha': 0.4,
    }

    parallelism_section = {
        "descriptor_fetch_cores": 4,
        "classification_cores": None,
    }

    return {
        'plugins': plugins_section,
        'utility': utility_section,
        "parallelism": parallelism_section,
    }
def main():
    """
    Write descriptors with class labels out in libSVM's text format.

    Reads ``uuid,label`` rows from the CSV file given by ``args.f``, fetches
    the descriptor for each UUID from the configured descriptor index, maps
    each distinct label to an integer (starting at 1, in order of first
    appearance) and writes one libSVM-format line per descriptor to the
    file given by ``args.o``. Zero-valued vector components are omitted.
    The label-to-integer association is logged at the end.
    """
    description = """
    Utility script to transform a set of descriptors, specified by UUID, with
    matching class labels, to a test file usable by libSVM utilities for
    train/test experiments.

    The input CSV file is assumed to be of the format:

        uuid,label
        ...

    This is the same as the format requested for other scripts like
    ``classifier_model_validation.py``.

    This is very useful for searching for -c and -g parameter values for a
    training sample of data using the ``tools/grid.py`` script, found in the
    libSVM source tree. For example:

        <smqtk_source>/TPL/libsvm-3.1-custom/tools/grid.py \\
            -log2c -5,15,2 -log2g 3,-15,-2 -v 5 -out libsvm.grid.out \\
            -png libsvm.grid.png -t 0 -w1 3.46713615023 -w2 12.2613240418 \\
            output_of_this_script.txt
    """
    args, config = bin_utils.utility_main_helper(default_config, description,
                                                 extend_parser)
    log = logging.getLogger(__name__)

    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls())

    labels_filepath = args.f
    output_filepath = args.o

    # Read all (uuid, label) rows up front and close the input file before
    # anything is written.
    with open(labels_filepath) as labels_file:
        uuids, labels = zip(*csv.reader(labels_file))

    # Lazy pairing on both Python 2 (itertools.izip) and Python 3 (zip).
    lazy_zip = getattr(itertools, 'izip', zip)

    # Run through labeled UUIDs, getting the descriptor from the configured
    # index, applying the appropriate integer label and then writing the
    # formatted line out to the output file.
    with open(output_filepath, 'w') as ofile:
        label2int = {}
        next_int = 1

        log.info("Scanning input descriptors and labels")
        descriptors = descriptor_index.get_many_descriptors(uuids)
        for i, (label, d) in enumerate(lazy_zip(labels, descriptors)):
            log.debug("%d %s", i, d.uuid())
            if label not in label2int:
                label2int[label] = next_int
                next_int += 1
            # Feature indices are 1-based; zero-valued features are skipped.
            features = ' '.join(
                "%d:%.12f" % (j + 1, f)
                for j, f in enumerate(d.vector())
                if f != 0.0
            )
            ofile.write("%d " % label2int[label] + features + '\n')

    log.info("Integer label association:")
    for i, label in sorted((i, l) for l, i in label2int.items()):
        log.info('\t%d :: %s', i, label)
def get_default_config(cls):
    """
    Build the default configuration dictionary for this class.

    The parent class's default configuration is extended, via
    ``merge_dict``, with default blocks for the descriptor factory,
    descriptor generator, nearest-neighbor index and descriptor index, plus
    the ``update_descriptor_index`` flag (off by default).

    :return: Default configuration dictionary for the class.
    :rtype: dict

    """
    config = super(NearestNeighborServiceServer, cls).get_default_config()

    additions = {}
    additions["descriptor_factory"] = \
        DescriptorElementFactory.get_default_config()
    additions["descriptor_generator"] = \
        plugin.make_config(get_descriptor_generator_impls())
    additions["nn_index"] = plugin.make_config(get_nn_index_impls())
    additions["descriptor_index"] = \
        plugin.make_config(get_descriptor_index_impls())
    additions["update_descriptor_index"] = False

    merge_dict(config, additions)
    return config
def run_file_list(c, filelist_filepath, checkpoint_filepath, batch_size=None,
                  check_image=False):
    """
    Top level function handling configuration and inputs/outputs.

    :param c: Configuration dictionary (JSON)
    :type c: dict

    :param filelist_filepath: Path to a text file that lists paths to data
        files, separated by new lines.
    :type filelist_filepath: str

    :param checkpoint_filepath: Output file to which we write input filepath
        to SHA1 (UUID) relationships.
    :type checkpoint_filepath:

    :param batch_size: Optional batch size (None default) of data elements to
        process / descriptors to compute at a time. This causes files and
        stores to be written to incrementally during processing instead of
        one single batch transaction at a time.
    :type batch_size:

    :param check_image: Enable checking image loading from file before
        queueing that file for processing. If the check fails, the file is
        skipped instead of a halting exception being raised.
    :type check_image: bool

    """
    log = logging.getLogger(__name__)

    # Close the list file as soon as its lines have been read.
    with open(filelist_filepath) as filelist_file:
        file_paths = [line.strip() for line in filelist_file]

    log.info("Making descriptor factory")
    factory = DescriptorElementFactory.from_config(c['descriptor_factory'])

    log.info("Making descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(c['descriptor_index'],
                                                 get_descriptor_index_impls())

    log.info("Making descriptor generator '%s'",
             c['descriptor_generator']['type'])
    #: :type: smqtk.algorithms.DescriptorGenerator
    generator = plugin.from_plugin_config(c['descriptor_generator'],
                                          get_descriptor_generator_impls())

    def test_image_load(dfe):
        """
        Return True if the element's bytes can be opened as an image by
        PIL, otherwise log a warning and return False.
        """
        try:
            PIL.Image.open(io.BytesIO(dfe.get_bytes()))
            return True
        except IOError as ex:
            # noinspection PyProtectedMember
            log.warning(
                "Failed to convert '%s' bytes into an image "
                "(error: %s). Skipping", dfe._filepath, str(ex))
            return False
def default_config():
    """
    Return the default configuration dictionary for this utility.

    Contains default blocks for the descriptor generator, the descriptor
    element factory and the descriptor index.

    :return: Default configuration dictionary.
    :rtype: dict

    """
    generator_config = plugin.make_config(get_descriptor_generator_impls())
    factory_config = DescriptorElementFactory.get_default_config()
    index_config = plugin.make_config(get_descriptor_index_impls())

    return {
        "descriptor_generator": generator_config,
        "descriptor_factory": factory_config,
        "descriptor_index": index_config,
    }
def get_default_config(cls):
    """
    Build the default configuration dictionary for this service.

    The parent class's default configuration is extended with an
    ``iqr_service`` section holding the ``positive_seed_neighbors`` value,
    explanatory ``plugin_notes`` text for each plugin slot, and the
    ``plugins`` configuration blocks themselves. The relevancy index block
    has ``iqr_session.DFLT_REL_INDEX_CONFIG`` merged on top of it.

    :return: Default configuration dictionary for the class.
    :rtype: dict

    """
    config = super(IqrService, cls).get_default_config()

    rel_index_config = plugin.make_config(get_relevancy_index_impls())
    merge_dict(rel_index_config, iqr_session.DFLT_REL_INDEX_CONFIG)

    # Free-text guidance emitted alongside the plugin blocks.
    notes = {
        "relevancy_index_config":
            "The relevancy index config provided should not have "
            "persistent storage configured as it will be used in "
            "such a way that instances are created, built and "
            "destroyed often.",
        "descriptor_index":
            "This is the index from which given positive and "
            "negative example descriptors are retrieved from. "
            "Not used for nearest neighbor querying. "
            "This index must contain all descriptors that could "
            "possibly be used as positive/negative examples and "
            "updated accordingly.",
        "neighbor_index":
            "This is the neighbor index to pull initial near-"
            "positive descriptors from.",
        "classifier_config":
            "The configuration to use for training and using "
            "classifiers for the /classifier endpoint. "
            "When configuring a classifier for use, don't fill "
            "out model persistence values as many classifiers "
            "may be created and thrown away during this service's "
            "operation.",
        "classification_factory":
            "Selection of the backend in which classifications "
            "are stored. The in-memory version is recommended "
            "because normal caching mechanisms will not account "
            "for the variety of classifiers that can potentially "
            "be created via this utility.",
    }

    plugin_blocks = {
        "relevancy_index_config": rel_index_config,
        "descriptor_index":
            plugin.make_config(get_descriptor_index_impls()),
        "neighbor_index": plugin.make_config(get_nn_index_impls()),
        "classifier_config": plugin.make_config(get_classifier_impls()),
        "classification_factory":
            ClassificationElementFactory.get_default_config(),
    }

    service_section = {
        "positive_seed_neighbors": 500,
        "plugin_notes": notes,
        "plugins": plugin_blocks,
    }

    merge_dict(config, {"iqr_service": service_section})
    return config
def get_default_config(cls):
    """
    Build the default configuration dictionary for this service.

    Starting from the parent class's defaults, an ``iqr_service`` section is
    merged in. It carries the ``positive_seed_neighbors`` value, a
    ``plugin_notes`` mapping of explanatory text per plugin slot, and the
    ``plugins`` mapping of configuration blocks. The relevancy index block
    is overlaid with ``iqr_session.DFLT_REL_INDEX_CONFIG``.

    :return: Default configuration dictionary for the class.
    :rtype: dict

    """
    base = super(IqrService, cls).get_default_config()

    relevancy_block = plugin.make_config(get_relevancy_index_impls())
    merge_dict(relevancy_block, iqr_session.DFLT_REL_INDEX_CONFIG)

    # Explanatory text shown next to the plugin blocks in a generated file.
    plugin_notes = {}
    plugin_notes["relevancy_index_config"] = (
        "The relevancy index config provided should not have "
        "persistent storage configured as it will be used in "
        "such a way that instances are created, built and "
        "destroyed often."
    )
    plugin_notes["descriptor_index"] = (
        "This is the index from which given positive and "
        "negative example descriptors are retrieved from. "
        "Not used for nearest neighbor querying. "
        "This index must contain all descriptors that could "
        "possibly be used as positive/negative examples and "
        "updated accordingly."
    )
    plugin_notes["neighbor_index"] = (
        "This is the neighbor index to pull initial near-"
        "positive descriptors from."
    )
    plugin_notes["classifier_config"] = (
        "The configuration to use for training and using "
        "classifiers for the /classifier endpoint. "
        "When configuring a classifier for use, don't fill "
        "out model persistence values as many classifiers "
        "may be created and thrown away during this service's "
        "operation."
    )
    plugin_notes["classification_factory"] = (
        "Selection of the backend in which classifications "
        "are stored. The in-memory version is recommended "
        "because normal caching mechanisms will not account "
        "for the variety of classifiers that can potentially "
        "be created via this utility."
    )

    plugins = {}
    plugins["relevancy_index_config"] = relevancy_block
    plugins["descriptor_index"] = \
        plugin.make_config(get_descriptor_index_impls())
    plugins["neighbor_index"] = plugin.make_config(get_nn_index_impls())
    plugins["classifier_config"] = \
        plugin.make_config(get_classifier_impls())
    plugins["classification_factory"] = \
        ClassificationElementFactory.get_default_config()

    merge_dict(base, {
        "iqr_service": {
            "positive_seed_neighbors": 500,
            "plugin_notes": plugin_notes,
            "plugins": plugins,
        }
    })
    return base
def from_config(cls, config_dict, merge_default=True):
    """
    Instantiate a new instance of this class from a JSON-compliant
    configuration dictionary encapsulating initialization arguments.

    The nested plugin configurations ``lsh_functor``, ``descriptor_index``,
    ``hash_index`` (optional) and ``hash2uuids_kvstore`` are turned into
    instances before the configuration is handed to the parent class's
    ``from_config``. A ``hash_index_comment`` entry, if present, is dropped.

    This method should not be called via super unless an instance of the
    class is desired.

    :param config_dict: JSON compliant dictionary encapsulating
        a configuration.
    :type config_dict: dict

    :param merge_default: Merge the given configuration on top of the
        default provided by ``get_default_config``.
    :type merge_default: bool

    :return: Constructed instance from the provided config.
    :rtype: LSHNearestNeighborIndex

    """
    # Controlling merge here so we can control known comment stripping from
    # default config.
    if merge_default:
        merged = cls.get_default_config()
        merge_dict(merged, config_dict)
    else:
        # Shallow copy so the key replacements and removal below do not
        # alter the caller's dictionary.
        merged = dict(config_dict)

    merged['lsh_functor'] = plugin.from_plugin_config(
        merged['lsh_functor'], get_lsh_functor_impls()
    )
    merged['descriptor_index'] = plugin.from_plugin_config(
        merged['descriptor_index'], get_descriptor_index_impls()
    )

    # Hash index may be None for a default at-query-time linear indexing
    if merged['hash_index'] and merged['hash_index']['type']:
        merged['hash_index'] = plugin.from_plugin_config(
            merged['hash_index'], get_hash_index_impls()
        )
    else:
        cls.get_logger().debug(
            "No HashIndex impl given. Passing ``None``.")
        merged['hash_index'] = None

    # remove possible comment added by default generator
    merged.pop('hash_index_comment', None)

    merged['hash2uuids_kvstore'] = plugin.from_plugin_config(
        merged['hash2uuids_kvstore'], get_key_value_store_impls()
    )

    # Defaults (if requested) were merged above, so do not merge again.
    return super(LSHNearestNeighborIndex, cls).from_config(merged, False)
def default_config():
    """
    Return the default configuration dictionary for this utility.

    Sections: ``utility`` (report interval, multiprocessing flag and pickle
    protocol) and ``plugins`` (descriptor index and LSH functor blocks).

    :return: Default configuration dictionary.
    :rtype: dict

    """
    utility_section = {
        "report_interval": 1.0,
        "use_multiprocessing": False,
        "pickle_protocol": -1,
    }

    plugins_section = {}
    plugins_section["descriptor_index"] = \
        plugin.make_config(get_descriptor_index_impls())
    plugins_section["lsh_functor"] = \
        plugin.make_config(get_lsh_functor_impls())

    return {"utility": utility_section, "plugins": plugins_section}
def from_config(cls, config_dict, merge_default=True):
    """
    Instantiate a new instance of this class from a JSON-compliant
    configuration dictionary encapsulating initialization arguments.

    The nested plugin configurations ``descriptor_set``, ``uid2idx_kvs`` and
    ``idx2uid_kvs`` are turned into instances. ``index_element`` and
    ``index_param_element`` are optional: each is turned into a data element
    instance when configured with a type, and set to ``None`` otherwise.
    The result is handed to the parent class's ``from_config``.

    This method should not be called via super unless an instance of the
    class is desired.

    :param config_dict: JSON compliant dictionary encapsulating
        a configuration.
    :type config_dict: dict

    :param merge_default: Merge the given configuration on top of the
        default provided by ``get_default_config``.
    :type merge_default: bool

    :return: Constructed instance from the provided config.
    :rtype: FaissNearestNeighborsIndex

    """
    if merge_default:
        cfg = cls.get_default_config()
        merge_dict(cfg, config_dict)
    else:
        # Shallow copy so the key replacements below do not alter the
        # caller's dictionary.
        cfg = dict(config_dict)

    cfg['descriptor_set'] = plugin.from_plugin_config(
        cfg['descriptor_set'], get_descriptor_index_impls()
    )
    cfg['uid2idx_kvs'] = plugin.from_plugin_config(
        cfg['uid2idx_kvs'], get_key_value_store_impls()
    )
    cfg['idx2uid_kvs'] = plugin.from_plugin_config(
        cfg['idx2uid_kvs'], get_key_value_store_impls()
    )

    # Optional data elements: a missing block or a null type means "none".
    for key in ('index_element', 'index_param_element'):
        element_cfg = cfg[key]
        if element_cfg and element_cfg['type']:
            cfg[key] = plugin.from_plugin_config(
                element_cfg, get_data_element_impls()
            )
        else:
            cfg[key] = None

    # Defaults (if requested) were merged above, so do not merge again.
    return super(FaissNearestNeighborsIndex, cls).from_config(cfg, False)
def from_config(cls, config_dict, merge_default=True):
    """
    Instantiate a new instance of this class from a JSON-compliant
    configuration dictionary encapsulating initialization arguments.

    The nested plugin configurations ``lsh_functor``, ``descriptor_index``,
    ``hash_index`` (optional) and ``hash2uuids_kvstore`` are turned into
    instances before the configuration is handed to the parent class's
    ``from_config``. A ``hash_index_comment`` entry, if present, is dropped.

    This method should not be called via super unless an instance of the
    class is desired.

    :param config_dict: JSON compliant dictionary encapsulating
        a configuration.
    :type config_dict: dict

    :param merge_default: Merge the given configuration on top of the
        default provided by ``get_default_config``.
    :type merge_default: bool

    :return: Constructed instance from the provided config.
    :rtype: LSHNearestNeighborIndex

    """
    # Controlling merge here so we can control known comment stripping from
    # default config.
    if merge_default:
        merged = cls.get_default_config()
        merge_dict(merged, config_dict)
    else:
        # Shallow copy so the key replacements and removal below do not
        # alter the caller's dictionary.
        merged = dict(config_dict)

    merged['lsh_functor'] = plugin.from_plugin_config(
        merged['lsh_functor'], get_lsh_functor_impls()
    )
    merged['descriptor_index'] = plugin.from_plugin_config(
        merged['descriptor_index'], get_descriptor_index_impls()
    )

    # Hash index may be None for a default at-query-time linear indexing
    hash_index_cfg = merged['hash_index']
    if hash_index_cfg and hash_index_cfg['type']:
        merged['hash_index'] = plugin.from_plugin_config(
            hash_index_cfg, get_hash_index_impls()
        )
    else:
        cls.get_logger().debug("No HashIndex impl given. Passing ``None``.")
        merged['hash_index'] = None

    # remove possible comment added by default generator
    merged.pop('hash_index_comment', None)

    merged['hash2uuids_kvstore'] = plugin.from_plugin_config(
        merged['hash2uuids_kvstore'], get_key_value_store_impls()
    )

    # Defaults (if requested) were merged above, so do not merge again.
    return super(LSHNearestNeighborIndex, cls).from_config(merged, False)
def __init__(self, json_config):
    """
    Initialize the service from its JSON configuration.

    Reads the ``iqr_service`` section of the configuration: session control
    settings from ``session_control`` and plugin blocks from ``plugins``.
    The descriptor index, neighbor index and classification factory are
    instantiated here; the classifier and relevancy index configurations
    are only stored.

    :param json_config: Service configuration dictionary.
    :type json_config: dict

    """
    super(IqrService, self).__init__(json_config)
    sc_config = json_config['iqr_service']['session_control']

    # Initialize from config
    self.positive_seed_neighbors = sc_config['positive_seed_neighbors']
    self.classifier_config = \
        json_config['iqr_service']['plugins']['classifier_config']
    self.classification_factory = \
        ClassificationElementFactory.from_config(
            json_config['iqr_service']['plugins']['classification_factory']
        )

    #: :type: smqtk.representation.DescriptorIndex
    self.descriptor_index = plugin.from_plugin_config(
        json_config['iqr_service']['plugins']['descriptor_index'],
        get_descriptor_index_impls(),
    )
    #: :type: smqtk.algorithms.NearestNeighborsIndex
    self.neighbor_index = plugin.from_plugin_config(
        json_config['iqr_service']['plugins']['neighbor_index'],
        get_nn_index_impls(),
    )
    self.rel_index_config = \
        json_config['iqr_service']['plugins']['relevancy_index_config']

    # Record of trained classifiers for a session. Session classifier
    # modifications locked under the parent session's global lock.
    #: :type: dict[collections.Hashable, smqtk.algorithms.SupervisedClassifier | None]
    self.session_classifiers = {}
    # Control for knowing when a new classifier should be trained for a
    # session (True == train new classifier). Modification for specific
    # sessions under parent session's lock.
    #: :type: dict[collections.Hashable, bool]
    self.session_classifier_dirty = {}

    def session_expire_callback(session):
        """
        Drop the classifier bookkeeping kept for an expiring session.

        :type session: smqtk.iqr.IqrSession
        """
        with session:
            self._log.debug("Removing session %s classifier", session.uuid)
            # ``pop`` with a default: a session without an entry must not
            # raise here and leave the other map uncleaned.
            self.session_classifiers.pop(session.uuid, None)
            self.session_classifier_dirty.pop(session.uuid, None)

    self.controller = iqr_controller.IqrController(
        sc_config['session_expiration']['enabled'],
        sc_config['session_expiration']['check_interval_seconds'],
        session_expire_callback)
    self.session_timeout = \
        sc_config['session_expiration']['session_timeout']
def run_file_list(c, filelist_filepath, checkpoint_filepath, batch_size=None,
                  check_image=False):
    """
    Top level function handling configuration and inputs/outputs.

    :param c: Configuration dictionary (JSON)
    :type c: dict

    :param filelist_filepath: Path to a text file that lists paths to data
        files, separated by new lines.
    :type filelist_filepath: str

    :param checkpoint_filepath: Output file to which we write input filepath
        to SHA1 (UUID) relationships.
    :type checkpoint_filepath:

    :param batch_size: Optional batch size (None default) of data elements to
        process / descriptors to compute at a time. This causes files and
        stores to be written to incrementally during processing instead of
        one single batch transaction at a time.
    :type batch_size:

    :param check_image: Presumably enables the image-load check defined
        below before a file is processed (TODO confirm against the code
        that consumes this flag).
    :type check_image: bool

    """
    log = logging.getLogger(__name__)

    # Close the list file as soon as its lines have been read.
    with open(filelist_filepath) as filelist_file:
        file_paths = [line.strip() for line in filelist_file]

    log.info("Making descriptor factory")
    factory = DescriptorElementFactory.from_config(c['descriptor_factory'])

    log.info("Making descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(c['descriptor_index'],
                                                 get_descriptor_index_impls())

    log.info("Making descriptor generator '%s'",
             c['descriptor_generator']['type'])
    #: :type: smqtk.algorithms.DescriptorGenerator
    generator = plugin.from_plugin_config(c['descriptor_generator'],
                                          get_descriptor_generator_impls())

    def test_image_load(dfe):
        """
        Return True if the element's bytes can be opened as an image by
        PIL, otherwise log a warning and return False.
        """
        try:
            PIL.Image.open(io.BytesIO(dfe.get_bytes()))
            return True
        except IOError as ex:
            # noinspection PyProtectedMember
            log.warning("Failed to convert '%s' bytes into an image "
                        "(error: %s). Skipping", dfe._filepath, str(ex))
            return False
def default_config():
    """
    Return the default configuration dictionary for this utility.

    Sections: ``utility`` (report interval and multiprocessing flag) and
    ``plugins`` (descriptor index, LSH functor and hash-to-UUID key-value
    store blocks).

    :return: Default configuration dictionary.
    :rtype: dict

    """
    utility_section = {
        "report_interval": 1.0,
        "use_multiprocessing": False,
    }

    plugins_section = {}
    plugins_section["descriptor_index"] = \
        plugin.make_config(get_descriptor_index_impls())
    plugins_section["lsh_functor"] = \
        plugin.make_config(get_lsh_functor_impls())
    plugins_section["hash2uuid_kvstore"] = \
        plugin.make_config(get_key_value_store_impls())

    return {"utility": utility_section, "plugins": plugins_section}
def main():
    """
    Print nearest neighbors for descriptors in the configured set.

    For each descriptor (either those whose UUIDs are listed in the file
    given by ``args.uuid_list``, or every descriptor in the set), the
    descriptor's UUID is printed followed by one ``uuid,distance`` line per
    neighbor returned by the configured nearest-neighbors index. The first
    result of every query is dropped as being the query descriptor itself.
    An ``args.num`` of 0 queries for as many neighbors as the index holds.

    Exits with status 1 when run without arguments (after printing help),
    103 when the UUID list path does not exist and 105 when ``args.num`` is
    negative.
    """
    parser = get_cli_parser()

    # Print help and exit if no arguments were passed
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args()
    config = utility_main_helper(get_default_config, args)

    log = logging.getLogger(__name__)
    log.debug('Showing debug messages.')

    #: :type: smqtk.representation.DescriptorIndex
    descriptor_set = plugin.from_plugin_config(
        config['plugins']['descriptor_set'], get_descriptor_index_impls()
    )
    #: :type: smqtk.algorithms.NearestNeighborsIndex
    nearest_neighbor_index = plugin.from_plugin_config(
        config['plugins']['nn_index'], get_nn_index_impls()
    )

    # noinspection PyShadowingNames
    def nearest_neighbors(descriptor, n):
        """ Return a list of (uuid, distance) pairs for ``descriptor``. """
        if n == 0:
            n = len(nearest_neighbor_index)

        neighbors, distances = nearest_neighbor_index.nn(descriptor, n)
        # Strip first result (itself) and create list of (uuid, distance)
        return list(zip([x.uuid() for x in neighbors[1:]], distances[1:]))

    def print_neighbors(descriptor):
        for neighbor in nearest_neighbors(descriptor, args.num):
            print('%s,%f' % neighbor)

    # ``sys.exit`` rather than the ``exit`` helper, which is only installed
    # by the ``site`` module.
    if args.uuid_list is not None and not os.path.exists(args.uuid_list):
        log.error('Invalid file list path: %s', args.uuid_list)
        sys.exit(103)
    elif args.num < 0:
        log.error('Number of nearest neighbors must be >= 0')
        sys.exit(105)

    if args.uuid_list is not None:
        with open(args.uuid_list, 'r') as infile:
            for line in infile:
                uuid = line.strip()
                # Blank lines would otherwise be looked up as empty UUIDs.
                if not uuid:
                    continue
                descriptor = descriptor_set.get_descriptor(uuid)
                print(descriptor.uuid())
                print_neighbors(descriptor)
    else:
        for (uuid, descriptor) in descriptor_set.iteritems():
            print(uuid)
            print_neighbors(descriptor)
def main():
    """
    Fit the ITQ functor model on descriptors from the configured index.

    If the ``uuids_list_filepath`` configuration value names an existing
    file, only the descriptors for the UUIDs listed in it (one per line)
    are used. Otherwise the whole descriptor index is passed to the
    functor's ``fit`` method.
    """
    description = """
    Tool for training the ITQ functor algorithm's model on descriptors in an
    index.

    By default, we use all descriptors in the configured index
    (``uuids_list_filepath`` is not given a value).

    The ``uuids_list_filepath`` configuration property is optional and should
    be used to specify a sub-set of descriptors in the configured index to
    train on. This only works if the stored descriptors' UUID is a type of
    string.
    """
    args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    uuids_list_filepath = config['uuids_list_filepath']

    log.info("Initializing ITQ functor")
    #: :type: smqtk.algorithms.nn_index.lsh.functors.itq.ItqFunctor
    functor = ItqFunctor.from_config(config['itq_config'])

    log.info("Initializing DescriptorIndex [type=%s]",
             config['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['descriptor_index'],
        get_descriptor_index_impls(),
    )

    if uuids_list_filepath and os.path.isfile(uuids_list_filepath):
        def uuids_iter():
            with open(uuids_list_filepath) as f:
                for line in f:
                    uuid = line.strip()
                    # Blank lines would otherwise be passed on as
                    # empty-string UUIDs.
                    if uuid:
                        yield uuid

        log.info("Loading UUIDs list from file: %s", uuids_list_filepath)
        d_iter = descriptor_index.get_many_descriptors(uuids_iter())
    else:
        if uuids_list_filepath:
            # A path was configured but cannot be used; say so instead of
            # silently training on everything.
            log.warning("UUIDs list path is not an existing file, "
                        "ignoring it: %s", uuids_list_filepath)
        log.info("Using UUIDs from loaded DescriptorIndex (count=%d)",
                 len(descriptor_index))
        d_iter = descriptor_index

    log.info("Fitting ITQ model")
    functor.fit(d_iter)
    log.info("Done")
def default_config():
    """
    Return the default configuration dictionary for this utility.

    Sections: ``utility`` (overwrite flag and ``parallel`` settings) and
    ``plugins`` (classifier, classification element and descriptor index
    blocks).

    :return: Default configuration dictionary.
    :rtype: dict

    """
    parallel_settings = {
        "use_multiprocessing": False,
        "index_extraction_cores": None,
        "classification_cores": None,
    }
    utility_section = {
        "classify_overwrite": False,
        "parallel": parallel_settings,
    }

    plugins_section = {}
    plugins_section["classifier"] = \
        plugin.make_config(get_classifier_impls())
    plugins_section["classification_factory"] = \
        plugin.make_config(get_classification_element_impls())
    plugins_section["descriptor_index"] = \
        plugin.make_config(get_descriptor_index_impls())

    return {"utility": utility_section, "plugins": plugins_section}
def main():
    """
    Write descriptors with class labels out in libSVM's text format.

    Reads ``uuid,label`` rows from the CSV file given by ``args.f``, fetches
    the descriptor for each UUID from the configured descriptor index, maps
    each distinct label to an integer (starting at 1, in order of first
    appearance) and writes one libSVM-format line per descriptor to the
    file given by ``args.o``. Zero-valued vector components are omitted.
    The label-to-integer association is logged at the end.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    labels_filepath = args.f
    output_filepath = args.o

    # Read all (uuid, label) rows up front and close the input file before
    # anything is written.
    with open(labels_filepath) as labels_file:
        uuids, labels = zip(*csv.reader(labels_file))

    # Run through labeled UUIDs, getting the descriptor from the configured
    # index, applying the appropriate integer label and then writing the
    # formatted line out to the output file.
    with open(output_filepath, 'w') as ofile:
        label2int = {}
        next_int = 1

        log.info("Scanning input descriptors and labels")
        descriptors = descriptor_index.get_many_descriptors(uuids)
        for i, (label, d) in enumerate(zip(labels, descriptors)):
            log.debug("%d %s", i, d.uuid())
            if label not in label2int:
                label2int[label] = next_int
                next_int += 1
            # Feature indices are 1-based; zero-valued features are skipped.
            features = ' '.join(
                "%d:%.12f" % (j + 1, f)
                for j, f in enumerate(d.vector())
                if f != 0.0
            )
            ofile.write("%d " % label2int[label] + features + '\n')

    log.info("Integer label association:")
    for i, label in sorted((i, l) for l, i in six.iteritems(label2int)):
        log.info('\t%d :: %s', i, label)
def default_config():
    """
    Return the default configuration dictionary for this utility.

    The ``utility`` section holds the overwrite flag and the ``parallel``
    settings. The ``plugins`` section holds configuration blocks for the
    classifier, the classification element and the descriptor index.

    :return: Default configuration dictionary.
    :rtype: dict

    """
    utility_section = {
        "classify_overwrite": False,
        "parallel": {
            "use_multiprocessing": False,
            "index_extraction_cores": None,
            "classification_cores": None,
        },
    }

    classifier_block = plugin.make_config(get_classifier_impls())
    classification_block = \
        plugin.make_config(get_classification_element_impls())
    descriptor_index_block = \
        plugin.make_config(get_descriptor_index_impls())

    plugins_section = {
        "classifier": classifier_block,
        "classification_factory": classification_block,
        "descriptor_index": descriptor_index_block,
    }

    return {"utility": utility_section, "plugins": plugins_section}
def get_default_config(cls):
    """
    Build the default configuration dictionary for this class.

    Starting from the parent class's defaults, ``merge_dict`` is used to add
    default blocks for the descriptor factory, descriptor generator,
    nearest-neighbor index and descriptor index, together with the
    ``update_descriptor_index`` flag (off by default).

    :return: Default configuration dictionary for the class.
    :rtype: dict

    """
    base = super(NearestNeighborServiceServer, cls).get_default_config()

    factory_block = DescriptorElementFactory.get_default_config()
    generator_block = plugin.make_config(get_descriptor_generator_impls())
    nn_index_block = plugin.make_config(get_nn_index_impls())
    descriptor_index_block = \
        plugin.make_config(get_descriptor_index_impls())

    service_defaults = {
        "descriptor_factory": factory_block,
        "descriptor_generator": generator_block,
        "nn_index": nn_index_block,
        "descriptor_index": descriptor_index_block,
        "update_descriptor_index": False,
    }
    merge_dict(base, service_defaults)

    return base
def main():
    """
    Write descriptors with class labels out in libSVM's text format.

    Reads ``uuid,label`` rows from the CSV file given by ``args.f``, fetches
    the descriptor for each UUID from the configured descriptor index, maps
    each distinct label to an integer (starting at 1, in order of first
    appearance) and writes one libSVM-format line per descriptor to the
    file given by ``args.o``. Zero-valued vector components are omitted.
    The label-to-integer association is logged at the end.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls())

    labels_filepath = args.f
    output_filepath = args.o

    # Read all (uuid, label) rows up front and close the input file before
    # anything is written.
    with open(labels_filepath) as labels_file:
        uuids, labels = zip(*csv.reader(labels_file))

    # Lazy pairing on both Python 2 (itertools.izip) and Python 3 (zip).
    lazy_zip = getattr(itertools, 'izip', zip)

    # Run through labeled UUIDs, getting the descriptor from the configured
    # index, applying the appropriate integer label and then writing the
    # formatted line out to the output file.
    with open(output_filepath, 'w') as ofile:
        label2int = {}
        next_int = 1

        log.info("Scanning input descriptors and labels")
        descriptors = descriptor_index.get_many_descriptors(uuids)
        for i, (label, d) in enumerate(lazy_zip(labels, descriptors)):
            log.debug("%d %s", i, d.uuid())
            if label not in label2int:
                label2int[label] = next_int
                next_int += 1
            # Feature indices are 1-based; zero-valued features are skipped.
            features = ' '.join(
                "%d:%.12f" % (j + 1, f)
                for j, f in enumerate(d.vector())
                if f != 0.0
            )
            ofile.write("%d " % label2int[label] + features + '\n')

    log.info("Integer label association:")
    for i, label in sorted((i, l) for l, i in label2int.items()):
        log.info('\t%d :: %s', i, label)
def __init__(self, json_config):
    """
    Initialize the service from its JSON configuration.

    Everything is read from the ``iqr_service`` section: the
    ``positive_seed_neighbors`` value and the blocks under ``plugins``. The
    classification factory, descriptor index and neighbor index are
    instantiated here; the classifier and relevancy index configurations
    are only stored. A session controller and empty per-session classifier
    bookkeeping maps are created.

    :param json_config: Service configuration dictionary.
    :type json_config: dict

    """
    super(IqrService, self).__init__(json_config)

    service_cfg = json_config['iqr_service']

    # Initialize from config
    self.positive_seed_neighbors = service_cfg['positive_seed_neighbors']

    plugin_cfgs = service_cfg['plugins']
    self.classifier_config = plugin_cfgs['classifier_config']
    self.classification_factory = ClassificationElementFactory.from_config(
        plugin_cfgs['classification_factory']
    )

    #: :type: smqtk.representation.DescriptorIndex
    self.descriptor_index = plugin.from_plugin_config(
        plugin_cfgs['descriptor_index'],
        get_descriptor_index_impls(),
    )
    #: :type: smqtk.algorithms.NearestNeighborsIndex
    self.neighbor_index = plugin.from_plugin_config(
        plugin_cfgs['neighbor_index'],
        get_nn_index_impls(),
    )
    self.rel_index_config = plugin_cfgs['relevancy_index_config']

    self.controller = iqr_controller.IqrController()

    # Record of trained classifiers for a session. Session classifier
    # modifications locked under the parent session's global lock.
    #: :type: dict[collections.Hashable, smqtk.algorithms.SupervisedClassifier | None]
    self.session_classifiers = dict()
    # Control for knowing when a new classifier should be trained for a
    # session (True == train new classifier). Modification for specific
    # sessions under parent session's lock.
    #: :type: dict[collections.Hashable, bool]
    self.session_classifier_dirty = dict()
def main():
    """
    Validate (or train) the configured classifier against labeled descriptors.

    Reads ``uuid,label`` rows from the configured CSV file, fetches the
    matching descriptors from the configured descriptor index and groups them
    by truth label. When ``train`` is enabled the classifier is trained on
    that grouping and the process exits; otherwise the descriptors are
    classified and a confusion matrix, an optional UUID confusion matrix JSON
    file, and optional PR / ROC plots are produced.

    :raises ValueError: Training was requested but the configured classifier
        is not a ``SupervisedClassifier``.
    """
    description = """
    Utility for validating a given classifier implementation's model against
    some labeled testing data, outputting PR and ROC curve plots with
    area-under-curve score values.

    This utility can optionally be used train a supervised classifier model if
    the given classifier model configuration does not exist and a second CSV
    file listing labeled training data is provided. Training will be attempted
    if ``train`` is set to true. If training is performed, we exit after
    training completes. A ``SupervisedClassifier`` sub-classing implementation
    must be configured

    We expect the test and train CSV files in the column format:

        ...
        <UUID>,<label>
        ...

    The UUID is of the descriptor to which the label applies. The label may be
    any arbitrary string value, but all labels must be consistent in
    application.

    Some metrics presented assume the highest confidence class as the single
    predicted class for an element:

        - confusion matrix

    The output UUID confusion matrix is a JSON dictionary where the top-level
    keys are the true labels, and the inner dictionary is the mapping of
    predicted labels to the UUIDs of the classifications/descriptors that
    yielded the prediction. Again, this is based on the maximum probability
    label for a classification result (T=0.5).
    """
    args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        config['plugins']['classifier'],
        get_classifier_impls()
    )
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory']
    )
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of label to the DescriptorElement instances for that
    # described by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate through UUIDs in specified file """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch descriptors from configured index """
        uuid, truth_label = r
        return truth_label, descriptor_index.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr, iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train classifier if the one given has a ``train`` method and training
    # was turned enabled.
    if do_train:
        if isinstance(classifier, SupervisedClassifier):
            log.info("Training classifier model")
            classifier.train(tlabel2descriptors)
            exit(0)
        else:
            # The exception used to be constructed but never raised, which
            # let execution silently fall through into validation.
            raise ValueError("Configured classifier is not a "
                             "SupervisedClassifier type and does not support "
                             "training.")

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    for tlabel, descriptors in tlabel2descriptors.iteritems():
        tlabel2classifications[tlabel] = \
            set(classifier.classify_async(
                descriptors, classification_factory,
                use_multiprocessing=True,
                procs=config['parallelism']['classification_cores'],
                ri=1.0,
            ).values())
    log.info("Truth label counts:")
    for l in sorted(tlabel2classifications):
        log.info("  %s :: %d", l, len(tlabel2classifications[l]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # CM of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are UUID
        # predicted labels.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[str, set | list]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists (sets are not JSON serializable)
            for plabel in uuid_cm[tlabel]:
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr,
                       plot_ci, plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc,
                        plot_ci, plot_ci_alpha)
def main():
    """
    Compute LSH hash codes for descriptors and save a hash-to-UUIDs pickle.

    UUIDs come from the file given by ``args.uuids_list`` (one per line) or,
    when no file is given, from every key in the configured descriptor index.
    An existing mapping is loaded from ``args.input_hash2uuids`` when that
    file exists, otherwise an empty one is started. The updated mapping is
    written to a temporary ``.WRITING`` file and then renamed over
    ``args.output_hash2uuids``.

    :raises ValueError: No output file path was provided.
    """
    description = """
    Compute LSH hash codes based on the provided functor on specific
    descriptors from the configured index given a file-list of UUIDs.

    When using an input file-list of UUIDs, we require that the UUIDs of
    indexed descriptors be strings, or equality comparable to the UUIDs' string
    representation.

    This script can be used to live update the ``hash2uuid_cache_filepath``
    model file for the ``LSHNearestNeighborIndex`` algorithm as output
    dictionary format is the same as used by that implementation.
    """
    args, config = bin_utils.utility_main_helper(default_config, description,
                                                 extend_parser)
    log = logging.getLogger(__name__)

    #
    # Load configuration contents
    #
    uuid_list_filepath = args.uuids_list
    hash2uuids_input_filepath = args.input_hash2uuids
    hash2uuids_output_filepath = args.output_hash2uuids
    report_interval = config['utility']['report_interval']
    use_multiprocessing = config['utility']['use_multiprocessing']
    pickle_protocol = config['utility']['pickle_protocol']

    #
    # Checking parameters
    #
    if not hash2uuids_output_filepath:
        raise ValueError("No hash2uuids map output file provided!")

    #
    # Loading stuff
    #
    log.info("Loading descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )
    log.info("Loading LSH functor")
    #: :type: smqtk.algorithms.LshFunctor
    lsh_functor = plugin.from_plugin_config(
        config['plugins']['lsh_functor'],
        get_lsh_functor_impls()
    )

    def iter_uuids():
        """ Yield UUIDs from the list file, or all keys of the index. """
        if uuid_list_filepath:
            log.info("Using UUIDs list file")
            with open(uuid_list_filepath) as f:
                for l in f:
                    yield l.strip()
        else:
            log.info("Using all UUIDs resent in descriptor index")
            for k in descriptor_index.iterkeys():
                yield k

    # load map if it exists, else start with empty dictionary
    if hash2uuids_input_filepath and os.path.isfile(hash2uuids_input_filepath):
        log.info("Loading hash2uuids mapping")
        # Binary mode: the file is written with 'wb' and a configurable
        # pickle protocol; binary protocols are not safe to read in text mode.
        # NOTE(review): unpickling executes arbitrary code -- only point this
        # at trusted files.
        with open(hash2uuids_input_filepath, 'rb') as f:
            hash2uuids = cPickle.load(f)
    else:
        log.info("Creating new hash2uuids mapping for output")
        hash2uuids = {}

    #
    # Compute codes
    #
    log.info("Starting hash code computation")
    compute_hash_codes(
        uuids_for_processing(iter_uuids(), hash2uuids),
        descriptor_index, lsh_functor,
        hash2uuids,
        report_interval=report_interval,
        use_mp=use_multiprocessing,
    )

    #
    # Output results
    #
    # Write to a temporary sibling file first and rename afterwards so the
    # destination is only replaced once the dump has completed.
    tmp_output_filepath = hash2uuids_output_filepath + '.WRITING'
    log.info("Writing hash-to-uuids map to disk: %s", tmp_output_filepath)
    file_utils.safe_create_dir(os.path.dirname(hash2uuids_output_filepath))
    with open(tmp_output_filepath, 'wb') as f:
        cPickle.dump(hash2uuids, f, pickle_protocol)
    log.info("Moving on top of input: %s", hash2uuids_output_filepath)
    os.rename(tmp_output_filepath, hash2uuids_output_filepath)
    log.info("Done")
def main():
    """
    Validate (or train) the configured classifier against labeled descriptors.

    Reads ``uuid,label`` rows from the configured CSV file, fetches the
    matching descriptors from the configured descriptor index and groups them
    by truth label. When ``train`` is enabled the classifier is trained on
    that grouping and the process exits; otherwise the descriptors are
    classified and a confusion matrix, an optional UUID confusion matrix JSON
    file, and optional PR / ROC plots are produced.

    :raises ValueError: Training was requested but the configured classifier
        is not a ``SupervisedClassifier``.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(config['plugins']['classifier'],
                                           get_classifier_impls())
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls())

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of label to the DescriptorElement instances for that
    # described by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate through UUIDs in specified file """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch descriptors from configured index """
        uuid, truth_label = r
        return truth_label, descriptor_index.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr, iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train classifier if the one given has a ``train`` method and training
    # was turned enabled.
    if do_train:
        if isinstance(classifier, SupervisedClassifier):
            log.info("Training classifier model")
            classifier.train(tlabel2descriptors)
            exit(0)
        else:
            # The exception used to be constructed but never raised, which
            # let execution silently fall through into validation.
            raise ValueError("Configured classifier is not a "
                             "SupervisedClassifier type and does not support "
                             "training.")

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    for tlabel, descriptors in tlabel2descriptors.items():
        tlabel2classifications[tlabel] = \
            set(classifier.classify_async(
                descriptors, classification_factory,
                use_multiprocessing=True,
                procs=config['parallelism']['classification_cores'],
                ri=1.0,
            ).values())
    log.info("Truth label counts:")
    for l in sorted(tlabel2classifications):
        log.info("  %s :: %d", l, len(tlabel2classifications[l]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # CM of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are UUID
        # predicted labels.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[str, set | list]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists (sets are not JSON serializable)
            for plabel in uuid_cm[tlabel]:
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr,
                       plot_ci, plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc,
                        plot_ci, plot_ci_alpha)
def main():
    """
    Validate (or train) the configured classifier against labeled descriptors.

    Reads ``uuid,label`` rows from the configured CSV file, fetches the
    matching descriptors from the configured descriptor index and groups them
    by truth label. When ``train`` is enabled the classifier is trained on
    that grouping and the process exits; otherwise the descriptors are
    classified and a confusion matrix, an optional UUID confusion matrix JSON
    file, and optional PR / ROC plots are produced.

    :raises ValueError: Training was requested but the configured classifier
        is not a ``SupervisedClassifier``.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        config['plugins']['classifier'],
        get_classifier_impls()
    )
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory']
    )
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of label to the DescriptorElement instances for that
    # described by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate through UUIDs in specified file """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch descriptors from configured index """
        uuid, truth_label = r
        return truth_label, descriptor_index.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr, iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train classifier if the one given has a ``train`` method and training
    # was turned enabled.
    if do_train:
        if isinstance(classifier, SupervisedClassifier):
            log.info("Training classifier model")
            classifier.train(tlabel2descriptors)
            exit(0)
        else:
            # The exception used to be constructed but never raised, which
            # let execution silently fall through into validation.
            raise ValueError("Configured classifier is not a "
                             "SupervisedClassifier type and does not support "
                             "training.")

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    for tlabel, descriptors in six.iteritems(tlabel2descriptors):
        tlabel2classifications[tlabel] = \
            set(classifier.classify_async(
                descriptors, classification_factory,
                use_multiprocessing=True,
                procs=config['parallelism']['classification_cores'],
                ri=1.0,
            ).values())
    log.info("Truth label counts:")
    for l in sorted(tlabel2classifications):
        log.info("  %s :: %d", l, len(tlabel2classifications[l]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # Confusion Matrix of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are UUID
        # predicted labels.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[collections.Hashable, set]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists (sets are not JSON serializable)
            for plabel in uuid_cm[tlabel]:
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr,
                       plot_ci, plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc,
                        plot_ci, plot_ci_alpha)
def main():
    """
    Compute LSH hash codes for descriptors and merge them into a key-value
    store mapping hash code to UUIDs.

    UUIDs come from the file given by ``args.uuids_list`` (one per line) or,
    when no file is given, from every key in the configured descriptor index.
    Computed ``(uuid, hash)`` pairs are accumulated per hash, starting from
    the value already in the store (or an empty set), and pushed back with a
    single ``add_many`` call when there is anything to update.

    :raises ValueError: A UUID list file path was given but is not a file.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Configuration values
    #
    uuid_list_filepath = args.uuids_list
    report_interval = config['utility']['report_interval']
    use_multiprocessing = config['utility']['use_multiprocessing']

    #
    # Input sanity check
    #
    list_file_given = uuid_list_filepath is not None
    if list_file_given and not os.path.isfile(uuid_list_filepath):
        raise ValueError("UUIDs list file does not exist!")

    #
    # Plugin construction
    #
    log.info("Loading descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls(),
    )

    log.info("Loading LSH functor")
    #: :type: smqtk.algorithms.LshFunctor
    lsh_functor = plugin.from_plugin_config(
        config['plugins']['lsh_functor'],
        get_lsh_functor_impls(),
    )

    log.info("Loading Key/Value store")
    #: :type: smqtk.representation.KeyValueStore
    hash2uuids_kvstore = plugin.from_plugin_config(
        config['plugins']['hash2uuid_kvstore'],
        get_key_value_store_impls(),
    )

    def iter_uuids():
        """ Yield UUIDs from the list file, else every key of the index. """
        if not uuid_list_filepath:
            log.info("Using all UUIDs resent in descriptor index")
            for index_key in descriptor_index.keys():
                yield index_key
            return
        log.info("Using UUIDs list file")
        with open(uuid_list_filepath) as uuid_file:
            for line in uuid_file:
                yield line.strip()

    #
    # Compute codes
    #
    log.info("Starting hash code computation")
    kv_update = {}
    uuid_hash_pairs = compute_hash_codes(
        uuids_for_processing(iter_uuids(), hash2uuids_kvstore),
        descriptor_index,
        lsh_functor,
        report_interval,
        use_multiprocessing,
        True,
    )
    for uuid, hash_int in uuid_hash_pairs:
        # Seed from the store's current value the first time a hash is seen.
        if hash_int not in kv_update:
            kv_update[hash_int] = hash2uuids_kvstore.get(hash_int, set())
        kv_update[hash_int] |= {uuid}

    if kv_update:
        log.info("Updating KV store... (%d keys)" % len(kv_update))
        hash2uuids_kvstore.add_many(kv_update)

    log.info("Done")
def classifier_kfold_validation():
    """
    Cross validate a supervised classifier configuration with stratified
    k-folds, then generate PR and/or ROC curves from the per-fold results.

    Truth data is read from the configured CSV file (UUID in the first
    column, label in the second). For each fold a new classifier is built
    from the configured plugin, trained on the fold's training descriptors
    and applied to its test descriptors.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Load configurations / Setup data
    #
    cv_cfg = config['cross_validation']
    use_mp = cv_cfg['classification_use_multiprocessing']

    pr_cfg = config['pr_curves']
    pr_enabled = pr_cfg['enabled']
    pr_output_dir = pr_cfg['output_directory']
    pr_file_prefix = pr_cfg['file_prefix'] or ''
    pr_show = pr_cfg['show']

    roc_cfg = config['roc_curves']
    roc_enabled = roc_cfg['enabled']
    roc_output_dir = roc_cfg['output_directory']
    roc_file_prefix = roc_cfg['file_prefix'] or ''
    roc_show = roc_cfg['show']

    log.info("Initializing DescriptorIndex (%s)",
             config['plugins']['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls(),
    )

    log.info("Loading classifier configuration")
    #: :type: dict
    classifier_config = config['plugins']['supervised_classifier']
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'],
    )

    log.info("Loading truth data")
    #: :type: list[str]
    uuids = []
    #: :type: list[str]
    truth_labels = []
    with open(config['cross_validation']['truth_labels']) as truth_file:
        for row in csv.reader(truth_file):
            uuids.append(row[0])
            truth_labels.append(row[1])
    # Arrays allow selecting a fold's elements by index array below.
    #: :type: numpy.ndarray[str]
    uuids = numpy.array(uuids)
    #: :type: numpy.ndarray[str]
    truth_labels = numpy.array(truth_labels)

    #
    # Cross validation
    #
    kfolds = sklearn.cross_validation.StratifiedKFold(
        truth_labels,
        config['cross_validation']['num_folds'],
        random_state=config['cross_validation']['random_seed'],
    )

    # Truth and classification probability results for test data per fold.
    # Format:
    #   {
    #       0: {
    #           '<label>': {
    #               "truth": [...],   # Parallel truth and classification
    #               "proba": [...],   # probability values
    #           },
    #           ...
    #       },
    #       ...
    #   }
    fold_data = {}

    for fold_num, (train, test) in enumerate(kfolds):
        log.info("Fold %d", fold_num)
        log.info("-- %d training examples", len(train))
        log.info("-- %d test examples", len(test))
        fold_data[fold_num] = {}

        log.info("-- creating classifier")
        #: :type: SupervisedClassifier
        classifier = plugin.from_plugin_config(
            classifier_config,
            get_supervised_classifier_impls(),
        )

        log.info("-- gathering descriptors")
        #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
        pos_map = {}
        for train_idx in train:
            pos_map.setdefault(truth_labels[train_idx], []).append(
                descriptor_index.get_descriptor(uuids[train_idx])
            )

        log.info("-- Training classifier")
        classifier.train(pos_map)

        log.info("-- Classifying test set")
        test_descriptors = (descriptor_index.get_descriptor(uuids[idx])
                            for idx in test)
        classification_map = classifier.classify_async(
            test_descriptors,
            classification_factory,
            use_multiprocessing=use_mp,
            ri=1.0,
        )
        uuid2c = {
            descr.uuid(): c_elem.get_classification()
            for descr, c_elem in classification_map.iteritems()
        }

        log.info("-- Pairing truth and computed probabilities")
        # Only considering positive labels
        test_truth = truth_labels[test]
        test_uuids = uuids[test]
        for t_label in pos_map:
            fold_data[fold_num][t_label] = {
                "truth": [true_label == t_label for true_label in test_truth],
                "proba": [uuid2c[test_uuid][t_label]
                          for test_uuid in test_uuids],
            }

    #
    # Curve generation
    #
    if pr_enabled:
        make_pr_curves(fold_data, pr_output_dir, pr_file_prefix, pr_show)
    if roc_enabled:
        make_roc_curves(fold_data, roc_output_dir, roc_file_prefix, roc_show)
def __init__(self, json_config):
    """
    Build the application from the given JSON configuration and register its
    ``/count``, ``/compute`` and ``/nn`` routes.

    :param json_config: JSON configuration dictionary
    :type json_config: dict
    """
    super(NearestNeighborServiceServer, self).__init__(json_config)

    self.update_index = json_config['update_descriptor_index']

    # Descriptor factory setup
    self._log.info("Initializing DescriptorElementFactory")
    self.descr_elem_factory = DescriptorElementFactory.from_config(
        self.json_config['descriptor_factory']
    )

    # The descriptor index is only constructed when updating is turned on.
    #: :type: smqtk.representation.DescriptorIndex | None
    self.descr_index = None
    if self.update_index:
        self._log.info("Initializing DescriptorIndex to update")
        #: :type: smqtk.representation.DescriptorIndex | None
        self.descr_index = plugin.from_plugin_config(
            json_config['descriptor_index'],
            get_descriptor_index_impls(),
        )

    #: :type: smqtk.algorithms.NearestNeighborsIndex
    self.nn_index = plugin.from_plugin_config(
        json_config['nn_index'],
        get_nn_index_impls(),
    )

    #: :type: smqtk.algorithms.DescriptorGenerator
    self.descriptor_generator_inst = plugin.from_plugin_config(
        self.json_config['descriptor_generator'],
        get_descriptor_generator_impls(),
    )

    @self.route("/count", methods=['GET'])
    def count():
        """
        Report how many elements the nearest-neighbor index holds.
        """
        return flask.jsonify(count=self.nn_index.count())

    @self.route("/compute/<path:uri>", methods=["POST"])
    def compute(uri):
        """
        Generate the descriptor for the data element the URI refers to,
        using the configured descriptor generator.

        Accepted URI forms are described in the
        ``compute_nearest_neighbors`` docstring.

        JSON Return format::
            {
                "success": <bool>

                "message": <str>

                "descriptor": <None|list[float]>

                "reference_uri": <str>
            }

        :param uri: URI data specification.

        """
        descriptor = None
        try:
            descriptor = self.generate_descriptor_for_uri(uri)
            message = "Descriptor generated"
            # Convert to plain floats for the JSON response.
            descriptor = list(map(float, descriptor.vector()))
        except ValueError as ex:
            message = "Input value issue: %s" % str(ex)
        except RuntimeError as ex:
            message = "Descriptor extraction failure: %s" % str(ex)

        succeeded = descriptor is not None
        return flask.jsonify(
            success=succeeded,
            message=message,
            descriptor=descriptor,
            reference_uri=uri,
        )

    @self.route("/nn/<path:uri>")
    @self.route("/nn/n=<int:n>/<path:uri>")
    @self.route("/nn/n=<int:n>/<int:start_i>:<int:end_i>/<path:uri>")
    def compute_nearest_neighbors(uri, n=10, start_i=None, end_i=None):
        """
        Query the nearest-neighbor index with the descriptor of the data
        the URI refers to.

        Data modes for upload/use:

            - local filepath
            - base64
            - http/s URL
            - existing data/descriptor UUID

        Local Filepath
        --------------
        The URI string must be prefixed with ``file://``, followed by the
        full path to the data file to describe.

        Base 64 data
        ------------
        The URI string must be prefixed with "base64://", followed by the
        base64 encoded string. This mode also requires an additional
        ``?content_type=`` to provide data content type information. This
        mode saves the encoded data to temporary file for processing.

        HTTP/S address
        --------------
        This is the default mode when the URI prefix is none of the above.
        This uses the requests module to locally download a data file for
        processing.

        Existing Data/Descriptor by UUID
        --------------------------------
        When given a uri prefixed with "uuid://", we interpret the remainder
        of the uri as the UUID of a descriptor already present in the
        configured descriptor index. If the given UUID is not present in the
        index, a KeyError is raised.

        JSON Return format
        ------------------
            {
                "success": <bool>

                "message": <str>

                "neighbors": <list of neighbor UUIDs>

                "distances": <slice of the queried distances>

                "reference_uri": <str>
            }

        :param n: Number of neighbors to query for
        :param start_i: The starting index of the neighbor vectors to slice
            into for return.
        :param end_i: The ending index of the neighbor vectors to slice into
            for return.
        :type uri: str

        """
        descriptor = None
        try:
            descriptor = self.generate_descriptor_for_uri(uri)
            message = "descriptor computed"
        except ValueError as ex:
            message = "Input data issue: %s" % str(ex)
        except RuntimeError as ex:
            message = "Descriptor generation failure: %s" % str(ex)

        # Pagination window: use the given start/end indices when provided,
        # otherwise span from the beginning through the ``n`` queried.
        page_slice = slice(start_i or 0, end_i or n)
        neighbors = []
        dists = []
        if descriptor is not None:
            try:
                neighbors, dists = self.nn_index.nn(descriptor, n)
            except ValueError as ex:
                message = "Descriptor or index related issue: %s" % str(ex)

        # TODO: Return the optional descriptor vectors for the neighbors
        neighbor_uuids = [nbr.uuid() for nbr in neighbors[page_slice]]
        # noinspection PyTypeChecker
        return flask.jsonify({
            "success": descriptor is not None,
            "message": message,
            "neighbors": neighbor_uuids,
            "distances": dists[page_slice],
            "reference_uri": uri,
        })
def classifier_kfold_validation():
    """
    Cross validate a supervised classifier configuration with stratified
    k-folds, then generate PR and/or ROC curves from the per-fold results.

    Truth data is read from the configured CSV file (UUID in the first
    column, label in the second). For each fold a new classifier is built
    from the configured plugin, trained on the fold's training descriptors
    and applied to its test descriptors.
    """
    description = """
    Helper utility for cross validating a supervised classifier configuration.
    The classifier used should NOT be configured to save its model since this
    process requires us to train the classifier multiple times.

    Configuration
    -------------
    - plugins
        - supervised_classifier
            Supervised Classifier implementation configuration to use. This
            should not be set to use a persistent model if able.

        - descriptor_index
            Index to draw descriptors to classify from.

    - cross_validation
        - truth_labels
            Path to a CSV file containing descriptor UUID the truth label
            associations. This defines what descriptors are used from the given
            index. We error if any descriptor UUIDs listed here are not
            available in the given descriptor index. This file should be in
            [uuid, label] column format.

        - num_folds
            Number of folds to make for cross validation.

        - random_seed
            Optional fixed seed for the

        - classification_use_multiprocessing
            If we should use multiprocessing (vs threading) when classifying
            elements.

    - pr_curves
        - enabled
            If Precision/Recall plots should be generated.

        - show
            If we should attempt to show the graph after it has been generated
            (matplotlib).

        - output_directory
            Directory to save generated plots to. If None, we will not save
            plots. Otherwise we will create the directory (and required parent
            directories) if it does not exist.

        - file_prefix
            String prefix to prepend to standard plot file names.

    - roc_curves
        - enabled
            If ROC curves should be generated

        - show
            If we should attempt to show the plot after it has been generated
            (matplotlib).

        - output_directory
            Directory to save generated plots to. If None, we will not save
            plots. Otherwise we will create the directory (and required parent
            directories) if it does not exist.

        - file_prefix
            String prefix to prepend to standard plot file names.

    """
    args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    #
    # Load configurations / Setup data
    #
    cv_cfg = config['cross_validation']
    use_mp = cv_cfg['classification_use_multiprocessing']

    pr_cfg = config['pr_curves']
    pr_enabled = pr_cfg['enabled']
    pr_output_dir = pr_cfg['output_directory']
    pr_file_prefix = pr_cfg['file_prefix'] or ''
    pr_show = pr_cfg['show']

    roc_cfg = config['roc_curves']
    roc_enabled = roc_cfg['enabled']
    roc_output_dir = roc_cfg['output_directory']
    roc_file_prefix = roc_cfg['file_prefix'] or ''
    roc_show = roc_cfg['show']

    log.info("Initializing DescriptorIndex (%s)",
             config['plugins']['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls(),
    )

    log.info("Loading classifier configuration")
    #: :type: dict
    classifier_config = config['plugins']['supervised_classifier']
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'],
    )

    log.info("Loading truth data")
    #: :type: list[str]
    uuids = []
    #: :type: list[str]
    truth_labels = []
    with open(config['cross_validation']['truth_labels']) as truth_file:
        for row in csv.reader(truth_file):
            uuids.append(row[0])
            truth_labels.append(row[1])
    # Arrays allow selecting a fold's elements by index array below.
    #: :type: numpy.ndarray[str]
    uuids = numpy.array(uuids)
    #: :type: numpy.ndarray[str]
    truth_labels = numpy.array(truth_labels)

    #
    # Cross validation
    #
    kfolds = sklearn.cross_validation.StratifiedKFold(
        truth_labels,
        config['cross_validation']['num_folds'],
        random_state=config['cross_validation']['random_seed'],
    )

    # Truth and classification probability results for test data per fold.
    # Format:
    #   {
    #       0: {
    #           '<label>': {
    #               "truth": [...],   # Parallel truth and classification
    #               "proba": [...],   # probability values
    #           },
    #           ...
    #       },
    #       ...
    #   }
    fold_data = {}

    for fold_num, (train, test) in enumerate(kfolds):
        log.info("Fold %d", fold_num)
        log.info("-- %d training examples", len(train))
        log.info("-- %d test examples", len(test))
        fold_data[fold_num] = {}

        log.info("-- creating classifier")
        #: :type: SupervisedClassifier
        classifier = plugin.from_plugin_config(
            classifier_config,
            get_supervised_classifier_impls(),
        )

        log.info("-- gathering descriptors")
        #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
        pos_map = {}
        for train_idx in train:
            pos_map.setdefault(truth_labels[train_idx], []).append(
                descriptor_index.get_descriptor(uuids[train_idx])
            )

        log.info("-- Training classifier")
        classifier.train(pos_map)

        log.info("-- Classifying test set")
        test_descriptors = (descriptor_index.get_descriptor(uuids[idx])
                            for idx in test)
        classification_map = classifier.classify_async(
            test_descriptors,
            classification_factory,
            use_multiprocessing=use_mp,
            ri=1.0,
        )
        uuid2c = {
            descr.uuid(): c_elem.get_classification()
            for descr, c_elem in classification_map.iteritems()
        }

        log.info("-- Pairing truth and computed probabilities")
        # Only considering positive labels
        test_truth = truth_labels[test]
        test_uuids = uuids[test]
        for t_label in pos_map:
            fold_data[fold_num][t_label] = {
                "truth": [true_label == t_label for true_label in test_truth],
                "proba": [uuid2c[test_uuid][t_label]
                          for test_uuid in test_uuids],
            }

    #
    # Curve generation
    #
    if pr_enabled:
        make_pr_curves(fold_data, pr_output_dir, pr_file_prefix, pr_show)
    if roc_enabled:
        make_roc_curves(fold_data, roc_output_dir, roc_file_prefix, roc_show)
def main():
    """
    Classify the descriptors named in a UUID list file and write the results
    as CSV.

    The UUID list, CSV data and CSV header paths come from the parsed command
    line arguments; plugin and parallelism settings come from the utility
    configuration. Descriptors are fetched from the configured descriptor
    index and classified through two chained ``parallel_map`` stages. One
    output file receives the column labels (``uuid`` followed by the
    classifier's labels) and the other receives one row per classification.

    :raises ValueError: The UUID list path is missing or is not a file, or
        either CSV output path was not given.

    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    # - parallel_map UUIDs to load from the configured index
    # - classify iterated descriptors

    uuids_list_filepath = args.uuids_list
    output_csv_filepath = args.csv_data
    output_csv_header_filepath = args.csv_header
    classify_overwrite = config['utility']['classify_overwrite']

    p_use_multiprocessing = \
        config['utility']['parallel']['use_multiprocessing']
    p_index_extraction_cores = \
        config['utility']['parallel']['index_extraction_cores']
    p_classification_cores = \
        config['utility']['parallel']['classification_cores']

    if not uuids_list_filepath:
        raise ValueError("No uuids_list_filepath specified.")
    elif not os.path.isfile(uuids_list_filepath):
        raise ValueError("Given uuids_list_filepath did not point to a file.")
    if output_csv_header_filepath is None:
        raise ValueError("Need a path to save CSV header labels")
    if output_csv_filepath is None:
        raise ValueError("Need a path to save CSV data.")

    #
    # Initialize configured plugins
    #

    log.info("Initializing descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls())

    log.info("Initializing classification factory")
    c_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])

    log.info("Initializing classifier")
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(config['plugins']['classifier'],
                                           get_classifier_impls())

    #
    # Setup/Process
    #
    def iter_uuids():
        # One UUID per line; surrounding whitespace is stripped.
        with open(uuids_list_filepath) as f:
            for line in f:
                yield line.strip()

    def descr_for_uuid(uuid):
        """
        :type uuid: collections.Hashable
        :rtype: smqtk.representation.DescriptorElement
        """
        return descriptor_index.get_descriptor(uuid)

    def classify_descr(d):
        """
        :type d: smqtk.representation.DescriptorElement
        :rtype: smqtk.representation.ClassificationElement
        """
        return classifier.classify(d, c_factory, classify_overwrite)

    log.info("Initializing uuid-to-descriptor parallel map")
    #: :type: collections.Iterable[smqtk.representation.DescriptorElement]
    element_iter = parallel.parallel_map(
        descr_for_uuid, iter_uuids(),
        use_multiprocessing=p_use_multiprocessing,
        cores=p_index_extraction_cores,
        name="descr_for_uuid",
    )

    log.info("Initializing descriptor-to-classification parallel map")
    #: :type: collections.Iterable[smqtk.representation.ClassificationElement]
    classification_iter = parallel.parallel_map(
        classify_descr, element_iter,
        use_multiprocessing=p_use_multiprocessing,
        cores=p_classification_cores,
        name='classify_descr',
    )

    #
    # Write/Output files
    #

    c_labels = classifier.get_labels()

    def make_row(c):
        """
        Build one CSV row: the element's UUID followed by its value for each
        classifier label, in ``c_labels`` order.

        :type c: smqtk.representation.ClassificationElement
        """
        c_m = c.get_classification()
        return [c.uuid] + [c_m[label] for label in c_labels]

    # column labels file
    log.info("Writing CSV column header file: %s", output_csv_header_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_header_filepath))
    with open(output_csv_header_filepath, 'wb') as f_csv:
        w = csv.writer(f_csv)
        # Build a fresh list of stringified labels so this works for any
        # iterable returned by ``get_labels`` (list concatenation with a
        # tuple or other sequence type would raise TypeError).
        w.writerow(['uuid'] + [str(cl) for cl in c_labels])

    # CSV file
    log.info("Writing CSV data file: %s", output_csv_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_filepath))
    r_state = [0] * 7
    with open(output_csv_filepath, 'wb') as f_csv:
        w = csv.writer(f_csv)
        for c in classification_iter:
            w.writerow(make_row(c))
            bin_utils.report_progress(log.info, r_state, 1.0)

    # Final report
    # NOTE(review): adjusting r_state[1] before a zero-interval call looks
    # intended to force one last progress line -- confirm against
    # bin_utils.report_progress.
    r_state[1] -= 1
    bin_utils.report_progress(log.info, r_state, 0)

    log.info("Done")
def main():
    """
    Command line entry point: classify descriptors listed by UUID and dump
    the per-label classification values to CSV.

    Paths are taken from the parsed arguments, everything else from the
    utility configuration. The column header (``uuid`` plus the classifier's
    labels) goes to one file, the data rows go to another.

    :raises ValueError: The UUID list path is empty or not a file, or one of
        the two CSV output paths is ``None``.

    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    # Pipeline:
    #   UUID list file -> descriptor lookup -> classification -> CSV rows

    uuid_list_path = args.uuids_list
    data_path = args.csv_data
    header_path = args.csv_header

    utility_cfg = config['utility']
    overwrite = utility_cfg['classify_overwrite']
    parallel_cfg = utility_cfg['parallel']
    use_mp = parallel_cfg['use_multiprocessing']
    extraction_cores = parallel_cfg['index_extraction_cores']
    classification_cores = parallel_cfg['classification_cores']

    # Guard clauses for required inputs.
    if not uuid_list_path:
        raise ValueError("No uuids_list_filepath specified.")
    if not os.path.isfile(uuid_list_path):
        raise ValueError("Given uuids_list_filepath did not point to a file.")
    if header_path is None:
        raise ValueError("Need a path to save CSV header labels")
    if data_path is None:
        raise ValueError("Need a path to save CSV data.")

    #
    # Plugin construction
    #

    log.info("Initializing descriptor index")
    index_cfg = config['plugins']['descriptor_index']
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        index_cfg, get_descriptor_index_impls()
    )

    log.info("Initializing classification factory")
    factory_cfg = config['plugins']['classification_factory']
    c_factory = ClassificationElementFactory.from_config(factory_cfg)

    log.info("Initializing classifier")
    classifier_cfg = config['plugins']['classifier']
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        classifier_cfg, get_classifier_impls()
    )

    #
    # Processing stages
    #

    def iter_uuids():
        # Yield each line of the UUID list file with whitespace stripped.
        with open(uuid_list_path) as uuid_file:
            for raw_line in uuid_file:
                yield raw_line.strip()

    def descr_for_uuid(uuid):
        """
        :type uuid: collections.Hashable
        :rtype: smqtk.representation.DescriptorElement
        """
        return descriptor_index.get_descriptor(uuid)

    def classify_descr(d):
        """
        :type d: smqtk.representation.DescriptorElement
        :rtype: smqtk.representation.ClassificationElement
        """
        return classifier.classify(d, c_factory, overwrite)

    log.info("Initializing uuid-to-descriptor parallel map")
    #: :type: collections.Iterable[smqtk.representation.DescriptorElement]
    element_iter = parallel.parallel_map(
        descr_for_uuid, iter_uuids(),
        use_multiprocessing=use_mp,
        cores=extraction_cores,
        name="descr_for_uuid",
    )

    log.info("Initializing descriptor-to-classification parallel map")
    #: :type: collections.Iterable[smqtk.representation.ClassificationElement]
    classification_iter = parallel.parallel_map(
        classify_descr, element_iter,
        use_multiprocessing=use_mp,
        cores=classification_cores,
        name='classify_descr',
    )

    #
    # Output
    #

    c_labels = classifier.get_labels()

    # Column labels file: 'uuid' followed by each label as a string.
    log.info("Writing CSV column header file: %s", header_path)
    file_utils.safe_create_dir(os.path.dirname(header_path))
    header_row = ['uuid']
    header_row.extend(str(cl) for cl in c_labels)
    with open(header_path, 'wb') as header_file:
        csv.writer(header_file).writerow(header_row)

    # Data file: one row per classification element, columns in label order.
    log.info("Writing CSV data file: %s", data_path)
    file_utils.safe_create_dir(os.path.dirname(data_path))
    progress = [0] * 7
    with open(data_path, 'wb') as data_file:
        row_writer = csv.writer(data_file)
        for element in classification_iter:
            label_values = element.get_classification()
            row_writer.writerow(
                [element.uuid] + [label_values[lbl] for lbl in c_labels]
            )
            bin_utils.report_progress(log.info, progress, 1.0)

    # Final report
    progress[1] -= 1
    bin_utils.report_progress(log.info, progress, 0)

    log.info("Done")
def __init__(self, json_config):
    """
    Initialize the application based on the supplied JSON configuration.

    Builds the descriptor element factory, the nearest-neighbors index and
    the descriptor generator from their configuration sections, optionally
    builds a descriptor index to update (only when
    ``update_descriptor_index`` is true), and registers the ``/count`` and
    ``/compute/<path:uri>`` routes.

    :param json_config: JSON configuration dictionary
    :type json_config: dict

    """
    super(NearestNeighborServiceServer, self).__init__(json_config)

    # Whether computed descriptors should also be added to a descriptor
    # index (see ``self.descr_index`` below).
    self.update_index = json_config['update_descriptor_index']

    # Descriptor factory setup
    # NOTE(review): ``self.json_config`` is presumably set by the parent
    # constructor -- confirm.
    self._log.info("Initializing DescriptorElementFactory")
    self.descr_elem_factory = DescriptorElementFactory.from_config(
        self.json_config['descriptor_factory']
    )

    # The descriptor index stays None unless updating was requested.
    #: :type: smqtk.representation.DescriptorIndex | None
    self.descr_index = None
    if self.update_index:
        self._log.info("Initializing DescriptorIndex to update")
        #: :type: smqtk.representation.DescriptorIndex | None
        self.descr_index = plugin.from_plugin_config(
            json_config['descriptor_index'],
            get_descriptor_index_impls()
        )

    #: :type: smqtk.algorithms.NearestNeighborsIndex
    self.nn_index = plugin.from_plugin_config(
        json_config['nn_index'],
        get_nn_index_impls()
    )

    #: :type: smqtk.algorithms.DescriptorGenerator
    self.descriptor_generator_inst = plugin.from_plugin_config(
        self.json_config['descriptor_generator'],
        get_descriptor_generator_impls()
    )

    @self.route("/count", methods=['GET'])
    def count():
        """
        Return the number of elements represented in this index.
        """
        return flask.jsonify(**{
            "count": self.nn_index.count(),
        })

    @self.route("/compute/<path:uri>", methods=["POST"])
    def compute(uri):
        """
        Compute the descriptor for a URI specified data element using the
        configured descriptor generator.

        If a descriptor index was configured and update was turned on, we
        add the computed descriptor to the index.

        JSON Return format::
            {
                "success": <bool>

                "message": <str>

                "descriptor": <None|list[float]>

                "reference_uri": <str>
            }

        :param uri: URI data specification.

        """
        descriptor = None
        try:
            # The helper's result is unpacked as a 2-tuple; only the second
            # item (the descriptor) is used here.
            _, descriptor = self.generate_descriptor_for_uri(uri)
            message = "Descriptor generated"
            # Replace the element with its vector as plain floats.
            descriptor = map(float, descriptor.vector())
        # NOTE(review): ``except X, ex`` is Python 2-only syntax.
        except ValueError, ex:
            message = "Input value issue: %s" % str(ex)
        except RuntimeError, ex:
            message = "Descriptor extraction failure: %s" % str(ex)
        # NOTE(review): no return statement follows, although the docstring
        # describes a JSON response -- confirm whether the response
        # construction is missing here.
def run_file_list(c, filelist_filepath, checkpoint_filepath, batch_size=None,
                  check_image=False):
    """
    Top level function handling configuration and inputs/outputs.

    Descriptors are computed for every valid file listed in
    ``filelist_filepath``. Valid data elements are additionally added to the
    optional data set when one is configured. Each computed descriptor is
    recorded in the checkpoint CSV as ``[<input>, <descriptor UUID>]``.

    :param c: Configuration dictionary (JSON)
    :type c: dict

    :param filelist_filepath: Path to a text file that lists paths to data
        files, separated by new lines.
    :type filelist_filepath: str

    :param checkpoint_filepath: Output file to which we write input filepath
        to SHA1 (UUID) relationships.
    :type checkpoint_filepath: str

    :param batch_size: Optional batch size (None default) of data elements to
        process / descriptors to compute at a time. This causes files and
        stores to be written to incrementally during processing instead of
        one single batch transaction at a time.
    :type batch_size: int | None

    :param check_image: Enable checking image loading from file before
        queueing that file for processing. If the check fails, the file is
        skipped instead of a halting exception being raised.
    :type check_image: bool

    """
    log = logging.getLogger(__name__)

    # Read the whole list up front and close the handle right away.
    with open(filelist_filepath) as filelist:
        file_paths = [line.strip() for line in filelist]

    log.info("Making descriptor factory")
    factory = DescriptorElementFactory.from_config(c['descriptor_factory'])

    log.info("Making descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(c['descriptor_index'],
                                                 get_descriptor_index_impls())

    data_set = None
    if c['optional_data_set']['type'] is None:
        log.info("Not saving loaded data elements to data set")
    else:
        log.info("Initializing data set to append to")
        #: :type: smqtk.representation.DataSet
        data_set = plugin.from_plugin_config(c['optional_data_set'],
                                             get_data_set_impls())

    log.info("Making descriptor generator '%s'",
             c['descriptor_generator']['type'])
    #: :type: smqtk.algorithms.DescriptorGenerator
    generator = plugin.from_plugin_config(c['descriptor_generator'],
                                          get_descriptor_generator_impls())

    def iter_valid_elements():
        """
        Yield a data element for each listed file that passes validation,
        adding yielded elements to the data set (in batches when
        ``batch_size`` is set) if one is configured.
        """
        def is_valid(file_path):
            # Returns the element itself when valid, False otherwise, so the
            # consumer can filter on truthiness.
            dfe = DataFileElement(file_path)

            if is_valid_element(
                    dfe,
                    valid_content_types=generator.valid_content_types(),
                    check_image=check_image):
                return dfe
            else:
                return False

        data_elements = collections.deque()
        valid_files_filter = parallel.parallel_map(is_valid,
                                                   file_paths,
                                                   name="check-file-type",
                                                   use_multiprocessing=True)
        for dfe in valid_files_filter:
            if dfe:
                yield dfe
                if data_set is not None:
                    data_elements.append(dfe)
                    if batch_size and len(data_elements) == batch_size:
                        log.debug(
                            "Adding data element batch to set (size: %d)",
                            len(data_elements))
                        data_set.add_data(*data_elements)
                        data_elements.clear()
        # elements only collected if we have a data-set configured, so add any
        # still in the deque to the set
        if data_elements:
            log.debug("Adding data elements to set (size: %d)",
                      len(data_elements))
            data_set.add_data(*data_elements)

    log.info("Computing descriptors")
    m = compute_many_descriptors(
        iter_valid_elements(),
        generator,
        factory,
        descriptor_index,
        batch_size=batch_size,
    )

    # Recording computed file paths and associated file UUIDs (SHA1)
    # The context manager closes the checkpoint file even if iteration over
    # the descriptor stream raises.
    with open(checkpoint_filepath, 'w') as cf:
        cf_writer = csv.writer(cf)
        rps = [0] * 7
        for fp, descr in m:
            cf_writer.writerow([fp, descr.uuid()])
            report_progress(log.debug, rps, 1.)

    log.info("Done")
def __init__(self, json_config):
    """
    Initialize the application based on the supplied JSON configuration.

    Builds the descriptor element factory, the nearest-neighbors index and
    the descriptor generator from their configuration sections, optionally
    builds a descriptor index to update (only when
    ``update_descriptor_index`` is true), and registers the ``/count`` and
    ``/compute/<path:uri>`` routes.

    :param json_config: JSON configuration dictionary
    :type json_config: dict

    """
    super(NearestNeighborServiceServer, self).__init__(json_config)

    # Whether computed descriptors should also be added to a descriptor
    # index (see ``self.descr_index`` below).
    self.update_index = json_config['update_descriptor_index']

    # Descriptor factory setup
    # NOTE(review): ``self.json_config`` is presumably set by the parent
    # constructor -- confirm.
    self._log.info("Initializing DescriptorElementFactory")
    self.descr_elem_factory = DescriptorElementFactory.from_config(
        self.json_config['descriptor_factory'])

    # The descriptor index stays None unless updating was requested.
    #: :type: smqtk.representation.DescriptorIndex | None
    self.descr_index = None
    if self.update_index:
        self._log.info("Initializing DescriptorIndex to update")
        #: :type: smqtk.representation.DescriptorIndex | None
        self.descr_index = plugin.from_plugin_config(
            json_config['descriptor_index'],
            get_descriptor_index_impls())

    #: :type: smqtk.algorithms.NearestNeighborsIndex
    self.nn_index = plugin.from_plugin_config(json_config['nn_index'],
                                              get_nn_index_impls())

    #: :type: smqtk.algorithms.DescriptorGenerator
    self.descriptor_generator_inst = plugin.from_plugin_config(
        self.json_config['descriptor_generator'],
        get_descriptor_generator_impls())

    @self.route("/count", methods=['GET'])
    def count():
        """
        Return the number of elements represented in this index.
        """
        return flask.jsonify(**{
            "count": self.nn_index.count(),
        })

    @self.route("/compute/<path:uri>", methods=["POST"])
    def compute(uri):
        """
        Compute the descriptor for a URI specified data element using the
        configured descriptor generator.

        If a descriptor index was configured and update was turned on, we
        add the computed descriptor to the index.

        JSON Return format::
            {
                "success": <bool>

                "message": <str>

                "descriptor": <None|list[float]>

                "reference_uri": <str>
            }

        :param uri: URI data specification.

        """
        descriptor = None
        try:
            # The helper's result is unpacked as a 2-tuple; only the second
            # item (the descriptor) is used here.
            _, descriptor = self.generate_descriptor_for_uri(uri)
            message = "Descriptor generated"
            # Replace the element with its vector as plain floats.
            descriptor = map(float, descriptor.vector())
        # NOTE(review): ``except X, ex`` is Python 2-only syntax.
        except ValueError, ex:
            message = "Input value issue: %s" % str(ex)
        except RuntimeError, ex:
            message = "Descriptor extraction failure: %s" % str(ex)
        # NOTE(review): no return statement follows, although the docstring
        # describes a JSON response -- confirm whether the response
        # construction is missing here.
def __init__(self, json_config):
    """
    Initialize the application based on the supplied JSON configuration.

    Builds the descriptor element factory, the nearest-neighbors index and
    the descriptor generator from their configuration sections, optionally
    builds a descriptor index to update (only when
    ``update_descriptor_index`` is true), and registers the ``/count``,
    ``/compute/...`` and ``/nn/...`` routes.

    :param json_config: JSON configuration dictionary
    :type json_config: dict

    """
    super(NearestNeighborServiceServer, self).__init__(json_config)

    self.update_index = json_config['update_descriptor_index']

    # Descriptor factory setup
    self._log.info("Initializing DescriptorElementFactory")
    self.descr_elem_factory = DescriptorElementFactory.from_config(
        self.json_config['descriptor_factory']
    )

    # The descriptor index stays None unless updating was requested.
    #: :type: smqtk.representation.DescriptorIndex | None
    self.descr_index = None
    if self.update_index:
        self._log.info("Initializing DescriptorIndex to update")
        #: :type: smqtk.representation.DescriptorIndex | None
        self.descr_index = plugin.from_plugin_config(
            json_config['descriptor_index'],
            get_descriptor_index_impls()
        )

    #: :type: smqtk.algorithms.NearestNeighborsIndex
    self.nn_index = plugin.from_plugin_config(
        json_config['nn_index'],
        get_nn_index_impls()
    )

    #: :type: smqtk.algorithms.DescriptorGenerator
    self.descriptor_generator_inst = plugin.from_plugin_config(
        self.json_config['descriptor_generator'],
        get_descriptor_generator_impls()
    )

    @self.route("/count", methods=['GET'])
    def count():
        """
        Return the number of elements represented in this index.
        """
        return flask.jsonify(**{
            "count": self.nn_index.count(),
        })

    @self.route("/compute/<path:uri>", methods=["POST"])
    def compute(uri):
        """
        Compute the descriptor for a URI specified data element using the
        configured descriptor generator.

        See ``compute_nearest_neighbors`` method docstring for URI
        specifications accepted.

        If a descriptor index was configured and update was turned on, we
        add the computed descriptor to the index.

        JSON Return format::
            {
                "success": <bool>

                "message": <str>

                "descriptor": <None|list[float]>

                "reference_uri": <str>
            }

        :param uri: URI data specification.

        """
        # Keep the element and its float vector in separate names. If the
        # vector conversion raises one of the handled exceptions, the
        # response then reports failure with a null descriptor, rather than
        # carrying a non-serializable element with ``success`` true.
        vector = None
        try:
            element = self.generate_descriptor_for_uri(uri)
            vector = list(map(float, element.vector()))
            message = "Descriptor generated"
        except ValueError as ex:
            message = "Input value issue: %s" % str(ex)
        except RuntimeError as ex:
            message = "Descriptor extraction failure: %s" % str(ex)

        return flask.jsonify(
            success=vector is not None,
            message=message,
            descriptor=vector,
            reference_uri=uri,
        )

    @self.route("/nn/<path:uri>")
    @self.route("/nn/n=<int:n>/<path:uri>")
    @self.route("/nn/n=<int:n>/<int:start_i>:<int:end_i>/<path:uri>")
    def compute_nearest_neighbors(uri, n=10, start_i=None, end_i=None):
        """
        Data modes for upload/use:

            - local filepath
            - base64
            - http/s URL
            - existing data/descriptor UUID

        The following sub-sections detail how different URI's can be used.

        Local Filepath
        --------------

        The URI string must be prefixed with ``file://``, followed by the
        full path to the data file to describe.

        Base 64 data
        ------------

        The URI string must be prefixed with "base64://", followed by the
        base64 encoded string. This mode also requires an additional
        ``?content_type=`` to provide data content type information. This
        mode saves the encoded data to temporary file for processing.

        HTTP/S address
        --------------

        This is the default mode when the URI prefix is none of the above.
        This uses the requests module to locally download a data file for
        processing.

        Existing Data/Descriptor by UUID
        --------------------------------

        When given a uri prefixed with "uuid://", we interpret the remainder
        of the uri as the UUID of a descriptor already present in the
        configured descriptor index. If the given UUID is not present in the
        index, a KeyError is raised.

        JSON Return format
        ------------------
            {
                "success": <bool>

                "message": <str>

                "neighbors": <list of neighbor descriptor UUIDs>

                "distances": <distances parallel to "neighbors">

                "reference_uri": <str>
            }

        :param n: Number of neighbors to query for

        :param start_i: The starting index of the neighbor vectors to slice
            into for return.

        :param end_i: The ending index of the neighbor vectors to slice into
            for return.

        :type uri: str

        """
        descriptor = None
        try:
            descriptor = self.generate_descriptor_for_uri(uri)
            message = "descriptor computed"
        except ValueError as ex:
            message = "Input data issue: %s" % str(ex)
        except RuntimeError as ex:
            message = "Descriptor generation failure: %s" % str(ex)

        # Base pagination slicing based on provided start and end indices,
        # otherwise clamp to beginning/ending of queried neighbor sequence.
        page_slice = slice(start_i or 0, end_i or n)
        neighbors = []
        dists = []
        if descriptor is not None:
            try:
                neighbors, dists = \
                    self.nn_index.nn(descriptor, n)
            except ValueError as ex:
                # NOTE(review): "success" below still reports True on this
                # path because it only reflects descriptor generation --
                # confirm that is intended.
                message = "Descriptor or index related issue: %s" % str(ex)

        # TODO: Return the optional descriptor vectors for the neighbors
        # The loop variable is deliberately not called ``n`` so it cannot
        # shadow (or, on Python 2, rebind) the ``n`` parameter.
        # noinspection PyTypeChecker
        d = {
            "success": descriptor is not None,
            "message": message,
            "neighbors": [neighbor.uuid()
                          for neighbor in neighbors[page_slice]],
            "distances": dists[page_slice],
            "reference_uri": uri
        }
        return flask.jsonify(d)
def default_config():
    """
    Build the default configuration dictionary for this utility.

    :return: Dictionary with the ITQ functor defaults, an unset UUID list
        file path and the descriptor index plugin configuration.
    :rtype: dict

    """
    config = {}
    config["itq_config"] = ItqFunctor.get_default_config()
    config["uuids_list_filepath"] = None
    config["descriptor_index"] = \
        plugin.make_config(get_descriptor_index_impls())
    return config
def main():
    """
    Compute LSH hash codes for indexed descriptors and record them in the
    configured key/value store.

    UUIDs come from the optional UUID list file given on the command line,
    or from every key of the configured descriptor index when no file is
    given. For each ``(uuid, hash)`` pair produced by ``compute_hash_codes``
    the UUID is merged into the set stored under that hash, and all touched
    keys are written back with a single ``add_many`` call.

    :raises ValueError: A UUID list path was given but is not a file.

    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Load configuration contents
    #
    uuid_list_filepath = args.uuids_list
    report_interval = config['utility']['report_interval']
    use_multiprocessing = config['utility']['use_multiprocessing']

    #
    # Checking input parameters
    #
    if (uuid_list_filepath is not None) and \
            not os.path.isfile(uuid_list_filepath):
        raise ValueError("UUIDs list file does not exist!")

    #
    # Loading stuff
    #
    log.info("Loading descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )
    log.info("Loading LSH functor")
    #: :type: smqtk.algorithms.LshFunctor
    lsh_functor = plugin.from_plugin_config(
        config['plugins']['lsh_functor'],
        get_lsh_functor_impls()
    )
    log.info("Loading Key/Value store")
    #: :type: smqtk.representation.KeyValueStore
    hash2uuids_kvstore = plugin.from_plugin_config(
        config['plugins']['hash2uuid_kvstore'],
        get_key_value_store_impls()
    )

    # Iterate either over what's in the file given, or everything in the
    # configured index.
    def iter_uuids():
        if uuid_list_filepath:
            log.info("Using UUIDs list file")
            with open(uuid_list_filepath) as f:
                for line in f:
                    yield line.strip()
        else:
            log.info("Using all UUIDs present in descriptor index")
            for k in descriptor_index.keys():
                yield k

    #
    # Compute codes
    #
    log.info("Starting hash code computation")
    kv_update = {}
    for uuid, hash_int in \
            compute_hash_codes(uuids_for_processing(iter_uuids(),
                                                    hash2uuids_kvstore),
                               descriptor_index, lsh_functor,
                               report_interval,
                               use_multiprocessing, True):
        # Get original value in KV-store if not in update dict.
        if hash_int not in kv_update:
            kv_update[hash_int] = hash2uuids_kvstore.get(hash_int, set())
        kv_update[hash_int] |= {uuid}

    if kv_update:
        # Lazy logging arguments: formatting only happens if the record is
        # actually emitted.
        log.info("Updating KV store... (%d keys)", len(kv_update))
        hash2uuids_kvstore.add_many(kv_update)

    log.info("Done")
def main():
    """
    Classify the descriptors named in a UUID list file and write the results
    as CSV.

    Arguments and configuration are obtained through
    ``bin_utils.utility_main_helper``. Descriptors are fetched from the
    configured descriptor index and classified through two chained
    ``parallel_map`` stages. One output file receives the column labels
    (``uuid`` followed by the classifier's labels) and the other receives
    one row per classification.

    :raises ValueError: The UUID list path is missing or is not a file, or
        either CSV output path was not given.

    """
    description = """
    Script for asynchronously computing classifications for DescriptorElements
    in a DescriptorIndex specified via a list of UUIDs. Results are output to a
    CSV file in the format:

        uuid, label1_confidence, label2_confidence, ...

    CSV columns labels are output to the given CSV header file path. Label
    columns will be in the order as reported by the classifier implementations
    ``get_labels`` method.

    Due to using an input file-list of UUIDs, we require that the UUIDs of
    indexed descriptors be strings, or equality comparable to the UUIDs' string
    representation.
    """
    args, config = bin_utils.utility_main_helper(
        default_config,
        description,
        extend_parser,
    )
    log = logging.getLogger(__name__)

    # - parallel_map UUIDs to load from the configured index
    # - classify iterated descriptors

    uuids_list_filepath = args.uuids_list
    output_csv_filepath = args.csv_data
    output_csv_header_filepath = args.csv_header
    classify_overwrite = config['utility']['classify_overwrite']

    p_use_multiprocessing = \
        config['utility']['parallel']['use_multiprocessing']
    p_index_extraction_cores = \
        config['utility']['parallel']['index_extraction_cores']
    p_classification_cores = \
        config['utility']['parallel']['classification_cores']

    if not uuids_list_filepath:
        raise ValueError("No uuids_list_filepath specified.")
    elif not os.path.isfile(uuids_list_filepath):
        raise ValueError("Given uuids_list_filepath did not point to a file.")
    if output_csv_header_filepath is None:
        raise ValueError("Need a path to save CSV header labels")
    if output_csv_filepath is None:
        raise ValueError("Need a path to save CSV data.")

    #
    # Initialize configured plugins
    #

    log.info("Initializing descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    log.info("Initializing classification factory")
    c_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory']
    )

    log.info("Initializing classifier")
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        config['plugins']['classifier'],
        get_classifier_impls()
    )

    #
    # Setup/Process
    #
    def iter_uuids():
        # One UUID per line; surrounding whitespace is stripped.
        with open(uuids_list_filepath) as f:
            for line in f:
                yield line.strip()

    def descr_for_uuid(uuid):
        """
        :type uuid: collections.Hashable
        :rtype: smqtk.representation.DescriptorElement
        """
        return descriptor_index.get_descriptor(uuid)

    def classify_descr(d):
        """
        :type d: smqtk.representation.DescriptorElement
        :rtype: smqtk.representation.ClassificationElement
        """
        return classifier.classify(d, c_factory, classify_overwrite)

    log.info("Initializing uuid-to-descriptor parallel map")
    #: :type: collections.Iterable[smqtk.representation.DescriptorElement]
    element_iter = parallel.parallel_map(
        descr_for_uuid, iter_uuids(),
        use_multiprocessing=p_use_multiprocessing,
        cores=p_index_extraction_cores,
        name="descr_for_uuid",
    )

    log.info("Initializing descriptor-to-classification parallel map")
    #: :type: collections.Iterable[smqtk.representation.ClassificationElement]
    classification_iter = parallel.parallel_map(
        classify_descr, element_iter,
        use_multiprocessing=p_use_multiprocessing,
        cores=p_classification_cores,
        name='classify_descr',
    )

    #
    # Write/Output files
    #

    c_labels = classifier.get_labels()

    def make_row(c):
        """
        Build one CSV row: the element's UUID followed by its value for each
        classifier label, in ``c_labels`` order.

        :type c: smqtk.representation.ClassificationElement
        """
        c_m = c.get_classification()
        return [c.uuid] + [c_m[label] for label in c_labels]

    # column labels file
    log.info("Writing CSV column header file: %s", output_csv_header_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_header_filepath))
    with open(output_csv_header_filepath, 'wb') as f_csv:
        w = csv.writer(f_csv)
        # Build a fresh list of stringified labels so this works for any
        # iterable returned by ``get_labels`` (list concatenation with a
        # tuple or other sequence type would raise TypeError).
        w.writerow(['uuid'] + [str(cl) for cl in c_labels])

    # CSV file
    log.info("Writing CSV data file: %s", output_csv_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_filepath))
    r_state = [0] * 7
    with open(output_csv_filepath, 'wb') as f_csv:
        w = csv.writer(f_csv)
        for c in classification_iter:
            w.writerow(make_row(c))
            bin_utils.report_progress(log.info, r_state, 1.0)

    # Final report
    # NOTE(review): adjusting r_state[1] before a zero-interval call looks
    # intended to force one last progress line -- confirm against
    # bin_utils.report_progress.
    r_state[1] -= 1
    bin_utils.report_progress(log.info, r_state, 0)

    log.info("Done")