def main():
    parser = cli_parser()
    args = parser.parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    output_filepath = args.output_filepath
    overwrite = args.overwrite

    if not args.input_file:
        log.error("No input file path provided.")
        exit(1)
    elif not os.path.isfile(args.input_file):
        log.error("Given path does not point to a file.")
        exit(1)

    input_filepath = args.input_file
    data_element = DataFileElement(input_filepath)

    factory = DescriptorElementFactory.from_config(
        config['descriptor_factory'])
    #: :type: smqtk.algorithms.descriptor_generator.DescriptorGenerator
    cd = from_config_dict(config['content_descriptor'],
                          DescriptorGenerator.get_impls())
    vec = generate_vector(log, cd, data_element, factory, overwrite)

    if output_filepath:
        numpy.save(output_filepath, vec)
    else:
        # Construct the output string manually, since numpy's default print
        # formatting truncates long vectors.
        s = []
        # noinspection PyTypeChecker
        for f in vec:
            s.append('%15f' % f)
        print(' '.join(s))
def main():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    paths_file = args.paths_file
    after_time = args.after_time
    before_time = args.before_time

    #
    # Check dir/file locations
    #
    if paths_file is None:
        raise ValueError("Need a file path to output transferred file "
                         "paths!")

    safe_create_dir(os.path.dirname(paths_file))

    #
    # Start collection
    #
    remote_paths = solr_image_paths(
        config['solr_address'],
        after_time or '*', before_time or '*',
        config['solr_username'], config['solr_password'],
        config['batch_size'],
    )

    log.info("Writing file paths")
    with open(paths_file, 'w') as of:
        pr = cli.ProgressReporter(log.info, 1.0).start()
        for rp in remote_paths:
            of.write(rp + '\n')
            pr.increment_report()
        pr.report()
def main():
    args = cli_parser().parse_args()
    config = utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    completed_files_fp = args.completed_files
    filelist_fp = args.file_list
    batch_size = args.batch_size
    check_image = args.check_image

    # Input checking
    if not filelist_fp:
        log.error("No file-list file specified")
        exit(102)
    elif not os.path.isfile(filelist_fp):
        log.error("Invalid file list path: %s", filelist_fp)
        exit(103)

    if not completed_files_fp:
        log.error("No complete files output specified")
        exit(104)

    if batch_size < 0:
        log.error("Batch size must be >= 0.")
        exit(105)

    run_file_list(
        config,
        filelist_fp,
        completed_files_fp,
        batch_size,
        check_image,
    )
def main():
    args = cli_parser().parse_args()
    config = utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    output_filepath = args.output_map
    if not output_filepath:
        raise ValueError("No path given for output map file (pickle).")

    #: :type: smqtk.representation.DescriptorIndex
    index = from_config_dict(config['descriptor_index'],
                             DescriptorIndex.get_impls())
    mbkm = MiniBatchKMeans(verbose=args.verbose,
                           compute_labels=False,
                           **config['minibatch_kmeans_params'])
    initial_fit_size = int(config['initial_fit_size'])

    d_classes = mb_kmeans_build_apply(index, mbkm, initial_fit_size)

    log.info("Saving KMeans centroids to: %s",
             config['centroids_output_filepath_npy'])
    numpy.save(config['centroids_output_filepath_npy'],
               mbkm.cluster_centers_)

    log.info("Saving result classification map to: %s", output_filepath)
    safe_create_dir(os.path.dirname(output_filepath))
    # Pickle output must be written in binary mode.
    with open(output_filepath, 'wb') as f:
        cPickle.dump(d_classes, f, -1)

    log.info("Done")
def main():
    args = get_cli_parser().parse_args()
    config = utility_main_helper(get_default_config, args)

    log = logging.getLogger(__name__)
    log.debug("Showing debug messages.")

    iqr_state_fp = args.iqr_state

    if not os.path.isfile(iqr_state_fp):
        log.error("IQR Session info JSON filepath was invalid")
        exit(102)

    train_classifier_iqr(config, iqr_state_fp)
def main():
    parser = cli_parser()
    args = parser.parse_args()
    # Default config options for this util are technically valid for running,
    # it's just a bad authkey.
    config = cli.utility_main_helper(default_config, args,
                                     default_config_valid=True)

    port = int(config['port'])
    authkey = bytes(config['authkey'])

    mgr = ProxyManager(('', port), authkey)
    mgr.get_server().serve_forever()
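# A minimal client-side sketch for the proxy server above. This assumes
# ProxyManager follows the standard multiprocessing.managers.BaseManager
# connection protocol; the host, port and authkey values here are
# illustrative placeholders, not this tool's defaults.
def example_proxy_manager_client():
    mgr = ProxyManager(('localhost', 5000), b'changeme')
    # ``connect`` attaches to an already-running server at the given address
    # instead of serving one, after which objects registered with the manager
    # can be requested through it.
    mgr.connect()
    return mgr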
def main():
    # Print help and exit if no arguments were passed
    if len(sys.argv) == 1:
        get_cli_parser().print_help()
        sys.exit(1)

    args = get_cli_parser().parse_args()
    config = utility_main_helper(get_default_config, args)

    log = logging.getLogger(__name__)
    log.debug('Showing debug messages.')

    #: :type: smqtk.representation.DescriptorIndex
    descriptor_set = from_config_dict(config['plugins']['descriptor_set'],
                                      DescriptorIndex.get_impls())
    #: :type: smqtk.algorithms.NearestNeighborsIndex
    nearest_neighbor_index = from_config_dict(
        config['plugins']['nn_index'], NearestNeighborsIndex.get_impls())

    # noinspection PyShadowingNames
    def nearest_neighbors(descriptor, n):
        if n == 0:
            n = len(nearest_neighbor_index)
        # ``nn`` returns parallel sequences of neighbor elements and their
        # distances from the query descriptor.
        neighbors, dists = nearest_neighbor_index.nn(descriptor, n)
        # Strip the first result (the query itself) and create a list of
        # (uuid, distance) pairs.
        return list(zip([x.uuid() for x in neighbors[1:]], dists[1:]))

    if args.uuid_list is not None and not os.path.exists(args.uuid_list):
        log.error('Invalid file list path: %s', args.uuid_list)
        exit(103)
    elif args.num < 0:
        log.error('Number of nearest neighbors must be >= 0')
        exit(105)

    if args.uuid_list is not None:
        with open(args.uuid_list, 'r') as infile:
            for line in infile:
                descriptor = descriptor_set.get_descriptor(line.strip())
                print(descriptor.uuid())
                for neighbor in nearest_neighbors(descriptor, args.num):
                    print('%s,%f' % neighbor)
    else:
        for (uuid, descriptor) in descriptor_set.iteritems():
            print(uuid)
            for neighbor in nearest_neighbors(descriptor, args.num):
                print('%s,%f' % neighbor)
def main():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    api_root = config['tool']['girder_api_root']
    api_key = config['tool']['api_key']
    api_query_batch = config['tool']['api_query_batch']
    insert_batch_size = config['tool']['dataset_insert_batch_size']

    # Collect folder/item/file references given on the command line and in
    # any referenced list files.
    #: :type: list[str]
    ids_folder = args.folder
    #: :type: list[str]
    ids_item = args.item
    #: :type: list[str]
    ids_file = args.file

    if args.folder_list:
        with open(args.folder_list) as f:
            ids_folder.extend([fid.strip() for fid in f])
    if args.item_list:
        with open(args.item_list) as f:
            ids_item.extend([iid.strip() for iid in f])
    if args.file_list:
        with open(args.file_list) as f:
            ids_file.extend([fid.strip() for fid in f])

    #: :type: smqtk.representation.DataSet
    data_set = from_config_dict(config['plugins']['data_set'],
                                DataSet.get_impls())

    batch = collections.deque()
    pr = cli.ProgressReporter(log.info, 1.0).start()
    for e in find_girder_files(api_root, ids_folder, ids_item, ids_file,
                               api_key, api_query_batch):
        batch.append(e)
        if insert_batch_size and len(batch) >= insert_batch_size:
            data_set.add_data(*batch)
            batch.clear()
        pr.increment_report()
    pr.report()

    if batch:
        data_set.add_data(*batch)

    log.info('Done')
def main():
    parser = cli_parser()
    args = parser.parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)
    log.debug("Script arguments:\n%s" % args)

    def iter_input_elements():
        for f in args.input_files:
            f = osp.expanduser(f)
            if osp.isfile(f):
                yield DataFileElement(f)
            else:
                log.debug("Expanding glob: %s" % f)
                for g in glob.glob(f):
                    yield DataFileElement(g)

    log.info("Adding elements to data set")
    ds = from_config_dict(config['data_set'], DataSet.get_impls())
    ds.add_data(*iter_input_elements())
def main():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = from_config_dict(
        config['plugins']['descriptor_index'], DescriptorIndex.get_impls())

    labels_filepath = args.f
    output_filepath = args.o

    # Run through labeled UUIDs in the input file, getting the descriptor
    # from the configured index, applying the appropriate integer label and
    # then writing the formatted line out to the output file.
    with open(labels_filepath) as lfile:
        input_uuid_labels = list(csv.reader(lfile))

    with open(output_filepath, 'w') as ofile:
        label2int = {}
        next_int = 1
        uuids, labels = list(zip(*input_uuid_labels))

        log.info("Scanning input descriptors and labels")
        for i, (l, d) in enumerate(
                zip(labels, descriptor_index.get_many_descriptors(uuids))):
            log.debug("%d %s", i, d.uuid())

            if l not in label2int:
                label2int[l] = next_int
                next_int += 1

            ofile.write("%d " % label2int[l] + ' '.join([
                "%d:%.12f" % (j + 1, f)
                for j, f in enumerate(d.vector())
                if f != 0.0
            ]) + '\n')

    log.info("Integer label association:")
    for i, l in sorted((i, l) for l, i in six.iteritems(label2int)):
        log.info('\t%d :: %s', i, l)
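# A self-contained illustration of the sparse line format written above: an
# integer label followed by 1-based "index:value" pairs with zero-valued
# components omitted (the libSVM convention). The label and vector here are
# made-up values.
def example_sparse_line_format():
    label_int = 1
    vector = [0.0, 0.25, 0.0, 1.5]
    line = "%d " % label_int + ' '.join(
        "%d:%.12f" % (j + 1, f) for j, f in enumerate(vector) if f != 0.0)
    assert line == "1 2:0.250000000000 4:1.500000000000"
    return line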
def main():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    uuids_list_filepath = config['uuids_list_filepath']

    log.info("Initializing ITQ functor")
    #: :type: smqtk.algorithms.nn_index.lsh.functors.itq.ItqFunctor
    functor = ItqFunctor.from_config(config['itq_config'])

    log.info("Initializing DescriptorSet [type=%s]",
             config['descriptor_set']['type'])
    #: :type: smqtk.representation.DescriptorSet
    descriptor_set = from_config_dict(
        config['descriptor_set'],
        DescriptorSet.get_impls(),
    )

    if uuids_list_filepath and os.path.isfile(uuids_list_filepath):
        def uuids_iter():
            with open(uuids_list_filepath) as f:
                for l in f:
                    yield l.strip()

        log.info("Loading UUIDs list from file: %s", uuids_list_filepath)
        d_iter = descriptor_set.get_many_descriptors(uuids_iter())
    else:
        log.info("Using UUIDs from loaded DescriptorSet (count=%d)",
                 len(descriptor_set))
        d_iter = descriptor_set

    log.info("Fitting ITQ model")
    functor.fit(d_iter)
    log.info("Done")
def main():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    # - parallel_map UUIDs to load from the configured index
    # - classify iterated descriptors

    uuids_list_filepath = args.uuids_list
    output_csv_filepath = args.csv_data
    output_csv_header_filepath = args.csv_header
    classify_overwrite = config['utility']['classify_overwrite']

    p_use_multiprocessing = \
        config['utility']['parallel']['use_multiprocessing']
    p_index_extraction_cores = \
        config['utility']['parallel']['index_extraction_cores']
    p_classification_cores = \
        config['utility']['parallel']['classification_cores']

    if not uuids_list_filepath:
        raise ValueError("No uuids_list_filepath specified.")
    elif not os.path.isfile(uuids_list_filepath):
        raise ValueError("Given uuids_list_filepath did not point to a "
                         "file.")
    if output_csv_header_filepath is None:
        raise ValueError("Need a path to save CSV header labels.")
    if output_csv_filepath is None:
        raise ValueError("Need a path to save CSV data.")

    #
    # Initialize configured plugins
    #

    log.info("Initializing descriptor index")
    #: :type: smqtk.representation.DescriptorSet
    descriptor_set = from_config_dict(config['plugins']['descriptor_set'],
                                      DescriptorSet.get_impls())

    log.info("Initializing classification factory")
    c_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])

    log.info("Initializing classifier")
    #: :type: smqtk.algorithms.Classifier
    classifier = from_config_dict(config['plugins']['classifier'],
                                  Classifier.get_impls())

    #
    # Setup/Process
    #
    def iter_uuids():
        with open(uuids_list_filepath) as f:
            for l in f:
                yield l.strip()

    def descr_for_uuid(uuid):
        """
        :type uuid: collections.Hashable
        :rtype: smqtk.representation.DescriptorElement
        """
        return descriptor_set.get_descriptor(uuid)

    def classify_descr(d):
        """
        :type d: smqtk.representation.DescriptorElement
        :rtype: smqtk.representation.ClassificationElement
        """
        return classifier.classify_one_element(d, c_factory,
                                               classify_overwrite)

    log.info("Initializing uuid-to-descriptor parallel map")
    #: :type: collections.Iterable[smqtk.representation.DescriptorElement]
    element_iter = parallel.parallel_map(
        descr_for_uuid, iter_uuids(),
        use_multiprocessing=p_use_multiprocessing,
        cores=p_index_extraction_cores,
        name="descr_for_uuid",
    )

    log.info("Initializing descriptor-to-classification parallel map")
    #: :type: collections.Iterable[smqtk.representation.ClassificationElement]
    classification_iter = parallel.parallel_map(
        classify_descr, element_iter,
        use_multiprocessing=p_use_multiprocessing,
        cores=p_classification_cores,
        name='classify_descr',
    )

    #
    # Write/Output files
    #

    c_labels = classifier.get_labels()

    def make_row(e):
        """
        :type e: smqtk.representation.ClassificationElement
        """
        c_m = e.get_classification()
        return [e.uuid] + [c_m[l] for l in c_labels]

    # Column labels file.
    log.info("Writing CSV column header file: %s",
             output_csv_header_filepath)
    safe_create_dir(os.path.dirname(output_csv_header_filepath))
    # CSV files should be opened in text mode under Python 3.
    with open(output_csv_header_filepath, 'w') as f_csv:
        w = csv.writer(f_csv)
        w.writerow(['uuid'] + [str(cl) for cl in c_labels])

    # CSV data file.
    log.info("Writing CSV data file: %s", output_csv_filepath)
    safe_create_dir(os.path.dirname(output_csv_filepath))
    pr = cli.ProgressReporter(log.info, 1.0)
    pr.start()
    with open(output_csv_filepath, 'w') as f_csv:
        w = csv.writer(f_csv)
        for c in classification_iter:
            w.writerow(make_row(c))
            pr.increment_report()
        pr.report()

    log.info("Done")
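# A minimal sketch of chaining two smqtk.utils.parallel.parallel_map stages
# as done above, with toy functions standing in for the descriptor-fetch and
# classification steps. Thread mode (use_multiprocessing=False) is used so
# the lambdas need not be picklable.
def example_parallel_map_chain():
    squares = parallel.parallel_map(
        lambda x: x * x, range(10),
        use_multiprocessing=False, cores=2, name='square')
    negated = parallel.parallel_map(
        lambda x: -x, squares,
        use_multiprocessing=False, cores=2, name='negate')
    # Results are yielded lazily as upstream stages produce them, so output
    # ordering may differ from input ordering.
    return sorted(negated)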
def main():
    parser = cli_parser()
    args = parser.parse_args()

    debug_smqtk = args.debug_smqtk or args.verbose
    debug_server = args.debug_server or args.verbose
    debug_app = args.debug_app or args.verbose
    debug_ns_list = args.debug_ns
    if debug_smqtk:
        debug_ns_list.append('smqtk')
    if debug_server:
        debug_ns_list.append('werkzeug')

    # Create a single stream handler on the root, the level passed being
    # applied to the handler, and then set tuned levels on specific namespace
    # levels under root, which is reset to warning.
    cli.initialize_logging(logging.getLogger(), logging.DEBUG)
    logging.getLogger().setLevel(logging.WARN)
    log = logging.getLogger(__name__)
    # SMQTK level is always at least INFO for standard internals reporting.
    logging.getLogger("smqtk").setLevel(logging.INFO)
    # Enable DEBUG level on applicable namespaces available to us at this
    # time.
    for ns in debug_ns_list:
        log.info("Enabling debug logging on '{}' namespace".format(ns))
        logging.getLogger(ns).setLevel(logging.DEBUG)

    webapp_types = smqtk.web.SmqtkWebApp.get_impls()
    web_applications = {t.__name__: t for t in webapp_types}

    if args.list:
        log.info("")
        log.info("Available applications:")
        log.info("")
        for l, cls in six.iteritems(web_applications):
            log.info("\t" + l)
            if debug_smqtk:
                log.info('\t' + ('^' * len(l)) + '\n' + cls.__doc__ + '\n' +
                         ('*' * 80) + '\n')
        log.info("")
        exit(0)

    application_name = args.application
    if application_name is None:
        log.error("No application name given!")
        exit(1)
    elif application_name not in web_applications:
        log.error("Invalid application label '%s'", application_name)
        exit(1)
    #: :type: smqtk.web.SmqtkWebApp
    app_class = web_applications[application_name]

    # If the application class's logger does not already report as having
    # INFO/DEBUG level logging (due to being a child of an above handled
    # namespace), then set the app namespace's logger level appropriately.
    app_class_logger_level = app_class.get_logger().getEffectiveLevel()
    app_class_target_level = logging.INFO - (10 * debug_app)
    if app_class_logger_level > app_class_target_level:
        level_name = \
            "DEBUG" if app_class_target_level == logging.DEBUG else "INFO"
        log.info("Enabling '{}' logging for '{}' logger namespace.".format(
            level_name, app_class.get_logger().name))
        app_class.get_logger().setLevel(app_class_target_level)

    config = cli.utility_main_helper(app_class.get_default_config, args,
                                     skip_logging_init=True)

    host = args.host
    port = args.port and int(args.port)
    use_reloader = args.reload
    use_threading = args.threaded
    use_basic_auth = args.use_basic_auth
    use_simple_cors = args.use_simple_cors

    # noinspection PyUnresolvedReferences
    #: :type: smqtk.web.SmqtkWebApp
    app = app_class.from_config(config)
    if use_basic_auth:
        app.config["BASIC_AUTH_FORCE"] = True
        BasicAuth(app)
    if use_simple_cors:
        log.debug("Enabling CORS for all domains on all routes.")
        CORS(app)
    app.config['DEBUG'] = debug_server

    log.info("Starting application")
    app.run(host=host, port=port, debug=debug_server,
            use_reloader=use_reloader, threaded=use_threading)
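# A compact sketch of the logging arrangement established above: a single
# DEBUG-level stream handler on the root logger, the root level reset to
# WARN, and selected namespaces opted in to more verbose levels.
def example_namespace_logging():
    handler = logging.StreamHandler()
    handler.setLevel(logging.DEBUG)
    root = logging.getLogger()
    root.addHandler(handler)
    root.setLevel(logging.WARN)
    # Child namespaces may be more verbose than the root level.
    logging.getLogger('smqtk').setLevel(logging.INFO)
    logging.getLogger('werkzeug').setLevel(logging.DEBUG)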
def classifier_kfold_validation():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Load configurations / Setup data
    #
    pr_enabled = config['pr_curves']['enabled']
    pr_output_dir = config['pr_curves']['output_directory']
    pr_file_prefix = config['pr_curves']['file_prefix'] or ''
    pr_show = config['pr_curves']['show']

    roc_enabled = config['roc_curves']['enabled']
    roc_output_dir = config['roc_curves']['output_directory']
    roc_file_prefix = config['roc_curves']['file_prefix'] or ''
    roc_show = config['roc_curves']['show']

    log.info("Initializing DescriptorSet (%s)",
             config['plugins']['descriptor_set']['type'])
    #: :type: smqtk.representation.DescriptorSet
    descriptor_set = from_config_dict(config['plugins']['descriptor_set'],
                                      DescriptorSet.get_impls())

    log.info("Loading classifier configuration")
    #: :type: dict
    classifier_config = config['plugins']['supervised_classifier']

    # Always use in-memory ClassificationElement since we are retraining the
    # classifier and don't want possible element caching.
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory(
        MemoryClassificationElement, {})

    log.info("Loading truth data")
    #: :type: list[str]
    uuids = []
    #: :type: list[str]
    truth_labels = []
    with open(config['cross_validation']['truth_labels']) as f:
        f_csv = csv.reader(f)
        for row in f_csv:
            uuids.append(row[0])
            truth_labels.append(row[1])
    #: :type: numpy.ndarray[str]
    uuids = numpy.array(uuids)
    #: :type: numpy.ndarray[str]
    truth_labels = numpy.array(truth_labels)

    #
    # Cross validation
    #
    kfolds = sklearn.model_selection.StratifiedKFold(
        n_splits=config['cross_validation']['num_folds'],
        shuffle=True,
        random_state=config['cross_validation']['random_seed'],
    ).split(numpy.zeros(len(truth_labels)), truth_labels)

    """
    Truth and classification probability results for test data per fold.
    Format:
        {
            0: {
                '<label>': {
                    "truth": [...],  # Parallel truth and classification
                    "proba": [...],  # probability values
                },
                ...
            },
            ...
        }
    """
    fold_data: Dict[int, Any] = {}

    i = 0
    for train, test in kfolds:
        log.info("Fold %d", i)
        log.info("-- %d training examples", len(train))
        log.info("-- %d test examples", len(test))
        fold_data[i] = {}

        log.info("-- creating classifier")
        classifier = cast(
            SupervisedClassifier,
            from_config_dict(classifier_config,
                             SupervisedClassifier.get_impls()))

        log.info("-- gathering descriptors")
        pos_map: Dict[str, List[DescriptorElement]] = {}
        for idx in train:
            if truth_labels[idx] not in pos_map:
                pos_map[truth_labels[idx]] = []
            pos_map[truth_labels[idx]].append(
                descriptor_set.get_descriptor(uuids[idx]))

        log.info("-- Training classifier")
        classifier.train(pos_map)

        log.info("-- Classifying test set")
        c_iter = classifier.classify_elements(
            (descriptor_set.get_descriptor(uuids[idx]) for idx in test),
            classification_factory,
        )
        uuid2c = dict((c.uuid, c.get_classification()) for c in c_iter)

        log.info("-- Pairing truth and computed probabilities")
        # Only considering positive labels.
        for t_label in pos_map:
            fold_data[i][t_label] = {
                "truth": [L == t_label for L in truth_labels[test]],
                "proba": [uuid2c[uuid][t_label] for uuid in uuids[test]],
            }

        i += 1

    #
    # Curve generation
    #
    if pr_enabled:
        make_pr_curves(fold_data, pr_output_dir, pr_file_prefix, pr_show)
    if roc_enabled:
        make_roc_curves(fold_data, roc_output_dir, roc_file_prefix, roc_show)
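# A small demonstration of what the StratifiedKFold splitter above yields:
# per-fold arrays of train/test indices whose class proportions approximate
# those of the full label set. The labels here are toy values.
def example_stratified_kfold():
    y = numpy.array(['a', 'a', 'b', 'b', 'a', 'b'])
    splitter = sklearn.model_selection.StratifiedKFold(
        n_splits=2, shuffle=True, random_state=0)
    for train_idx, test_idx in splitter.split(numpy.zeros(len(y)), y):
        # Each fold's train and test sides contain roughly equal counts of
        # 'a' and 'b' labels.
        print(train_idx, test_idx)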
def main():
    args = cli_parser().parse_args()
    config = utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    report_size = args.report_size
    crawled_after = args.crawled_after
    inserted_after = args.inserted_after

    #
    # Check config properties
    #
    m = mimetypes.MimeTypes()
    # Non-strict types (see use of ``guess_extension`` above).
    m_img_types = set(m.types_map_inv[0]) | set(m.types_map_inv[1])
    if not isinstance(config['image_types'], list):
        raise ValueError("The 'image_types' property was not set to a list.")
    for t in config['image_types']:
        if ('image/' + t) not in m_img_types:
            raise ValueError("Image type '%s' is not a valid image MIMETYPE "
                             "sub-type." % t)

    if not report_size and args.output_dir is None:
        raise ValueError("Require an output directory!")
    if not report_size and args.file_list is None:
        raise ValueError("Require an output CSV file path!")

    #
    # Initialize ElasticSearch stuff
    #
    es_auth = None
    if config['elastic_search']['username'] and \
            config['elastic_search']['password']:
        es_auth = (config['elastic_search']['username'],
                   config['elastic_search']['password'])

    es = elasticsearch.Elasticsearch(
        config['elastic_search']['instance_address'],
        http_auth=es_auth,
        use_ssl=True,
        verify_certs=True,
        ca_certs=certifi.where(),
    )

    #
    # Query and Run
    #
    http_auth = None
    if config['stored_http_auth']['name'] and \
            config['stored_http_auth']['pass']:
        http_auth = (config['stored_http_auth']['name'],
                     config['stored_http_auth']['pass'])

    ts_re = re.compile(r'(\d{4})-(\d{2})-(\d{2})T(\d{2}):(\d{2}):(\d{2})Z')
    if crawled_after:
        m = ts_re.match(crawled_after)
        if m is None:
            raise ValueError("Given 'crawled-after' timestamp not in correct "
                             "format: '%s'" % crawled_after)
        crawled_after = datetime.datetime(*[int(e) for e in m.groups()])
    if inserted_after:
        m = ts_re.match(inserted_after)
        if m is None:
            raise ValueError("Given 'inserted-after' timestamp not in "
                             "correct format: '%s'" % inserted_after)
        inserted_after = datetime.datetime(*[int(e) for e in m.groups()])

    q = cdr_images_after(es, config['elastic_search']['index'],
                         config['image_types'], crawled_after,
                         inserted_after)

    log.info("Query Size: %d", q[0:0].execute().hits.total)
    if report_size:
        exit(0)

    fetch_cdr_query_images(
        q, args.output_dir, args.file_list,
        cores=int(config['parallel']['cores']),
        stored_http_auth=http_auth,
        batch_size=int(config['elastic_search']['batch_size']),
    )
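# An illustration of the UTC timestamp format the regex above accepts; the
# input value is made up.
def example_timestamp_parse():
    ts_re = re.compile(r'(\d{4})-(\d{2})-(\d{2})T(\d{2}):(\d{2}):(\d{2})Z')
    m = ts_re.match('2016-01-02T03:04:05Z')
    assert m is not None
    # The captured groups map directly onto datetime's year..second
    # positional arguments.
    return datetime.datetime(*[int(e) for e in m.groups()])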
def main():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    # Deprecations
    if (config.get('parallelism', {}).get('classification_cores', None)
            is not None):
        warnings.warn("Usage of 'classification_cores' is deprecated. "
                      "Classifier parallelism is not defined on a "
                      "per-implementation basis. See classifier "
                      "implementation parameterization.",
                      category=DeprecationWarning)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = from_config_dict(config['plugins']['classifier'],
                                  SupervisedClassifier.get_impls())
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])
    #: :type: smqtk.representation.DescriptorSet
    descriptor_set = from_config_dict(config['plugins']['descriptor_set'],
                                      DescriptorSet.get_impls())

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct a mapping of truth label to the DescriptorElement instances
    # described by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate through UUIDs in the specified file. """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple
                #       labels per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch descriptors from the configured index. """
        uuid, truth_label = r
        return truth_label, descriptor_set.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr, iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data.
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train the classifier if the one given has a ``train`` method and
    # training was enabled.
    if do_train:
        log.info("Training supervised classifier model")
        classifier.train(tlabel2descriptors)
        exit(0)

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results.
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    for tlabel, descriptors in six.iteritems(tlabel2descriptors):
        tlabel2classifications[tlabel] = \
            set(classifier.classify_elements(descriptors,
                                             classification_factory))
    log.info("Truth label counts:")
    for l in sorted(tlabel2classifications):
        log.info("  %s :: %d", l, len(tlabel2classifications[l]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # Confusion matrix of descriptor UUIDs to output JSON.
    if output_uuid_cm:
        # Top dictionary keys are true labels; inner dictionary keys are
        # predicted labels, mapping to sets of descriptor UUIDs.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[collections.Hashable, set]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # Convert sets to lists for JSON output.
            for plabel in uuid_cm[tlabel]:
                # noinspection PyTypeChecker
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit-learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr,
                       plot_ci, plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc,
                        plot_ci, plot_ci_alpha)
def main():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Load configuration contents
    #
    uuid_list_filepath = args.uuids_list
    report_interval = config['utility']['report_interval']
    use_multiprocessing = config['utility']['use_multiprocessing']

    #
    # Check input parameters
    #
    if (uuid_list_filepath is not None) and \
            not os.path.isfile(uuid_list_filepath):
        raise ValueError("UUIDs list file does not exist!")

    #
    # Load plugins
    #
    log.info("Loading descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = from_config_dict(
        config['plugins']['descriptor_index'], DescriptorIndex.get_impls())

    log.info("Loading LSH functor")
    #: :type: smqtk.algorithms.LshFunctor
    lsh_functor = from_config_dict(config['plugins']['lsh_functor'],
                                   LshFunctor.get_impls())

    log.info("Loading Key/Value store")
    #: :type: smqtk.representation.KeyValueStore
    hash2uuids_kvstore = from_config_dict(
        config['plugins']['hash2uuid_kvstore'], KeyValueStore.get_impls())

    # Iterate either over what's in the file given, or over everything in
    # the configured index.
    def iter_uuids():
        if uuid_list_filepath:
            log.info("Using UUIDs list file")
            with open(uuid_list_filepath) as f:
                for l in f:
                    yield l.strip()
        else:
            log.info("Using all UUIDs present in descriptor index")
            for k in descriptor_index.keys():
                yield k

    #
    # Compute codes
    #
    log.info("Starting hash code computation")
    kv_update = {}
    for uuid, hash_int in \
            compute_hash_codes(uuids_for_processing(iter_uuids(),
                                                    hash2uuids_kvstore),
                               descriptor_index, lsh_functor,
                               report_interval, use_multiprocessing, True):
        # Get the original value from the KV-store if it's not already in
        # the update dict.
        if hash_int not in kv_update:
            kv_update[hash_int] = hash2uuids_kvstore.get(hash_int, set())
        kv_update[hash_int] |= {uuid}

    if kv_update:
        log.info("Updating KV store... (%d keys)" % len(kv_update))
        hash2uuids_kvstore.add_many(kv_update)

    log.info("Done")
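# A toy illustration of the hash->UUIDs merge performed above: newly computed
# UUIDs are unioned into whatever set is already stored under the same hash
# key. All keys and values here are made up.
def example_kv_update_merge():
    existing = {42: {'uuid-a'}}
    kv_update = {}
    for uuid, hash_int in [('uuid-b', 42), ('uuid-c', 7)]:
        if hash_int not in kv_update:
            # Copy so the illustrative "store" is not mutated in place.
            kv_update[hash_int] = set(existing.get(hash_int, set()))
        kv_update[hash_int] |= {uuid}
    assert kv_update == {42: {'uuid-a', 'uuid-b'}, 7: {'uuid-c'}}
    return kv_update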