Example #1
def main():
    parser = cli_parser()
    args = parser.parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    output_filepath = args.output_filepath
    overwrite = args.overwrite

    if not args.input_file:
        log.error("Failed to provide an input file path")
        exit(1)
    elif not os.path.isfile(args.input_file):
        log.error("Given path does not point to a file.")
        exit(1)

    input_filepath = args.input_file
    data_element = DataFileElement(input_filepath)

    factory = DescriptorElementFactory.from_config(config['descriptor_factory'])
    #: :type: smqtk.algorithms.descriptor_generator.DescriptorGenerator
    cd = from_config_dict(config['content_descriptor'],
                          DescriptorGenerator.get_impls())

    vec = generate_vector(log, cd, data_element, factory, overwrite)

    if output_filepath:
        numpy.save(output_filepath, vec)
    else:
        # Manually build the output string; numpy's default printing
        # abbreviates long vectors.
        s = []
        # noinspection PyTypeChecker
        for f in vec:
            s.append('%15f' % f)
        print(' '.join(s))
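
All of these utilities share the same scaffolding: a module-level cli_parser() that builds the argument parser and a default_config() that returns the configuration dict consumed by cli.utility_main_helper. Neither is shown in these excerpts. Below is a minimal sketch of that pairing for Example #1; the basic_cli_parser helper, the import path, and the argparse options are assumptions, while the config keys mirror what main() reads above.

from smqtk.utils import cli  # assumed import path


def default_config():
    # Keys mirror what main() above reads from `config`; the nested blocks
    # would normally be generated from plugin implementations.
    return {
        'descriptor_factory': {},
        'content_descriptor': {},
    }


def cli_parser():
    # basic_cli_parser is an assumption; it would supply the common
    # config/logging options that utility_main_helper expects.
    parser = cli.basic_cli_parser(__doc__)
    parser.add_argument('-i', '--input-file',
                        help="Path of the file to describe.")
    parser.add_argument('-o', '--output-filepath',
                        help="Optional path to save the descriptor vector to "
                             "as a numpy .npy file.")
    parser.add_argument('--overwrite', action='store_true', default=False,
                        help="Recompute the descriptor even if one is "
                             "already cached for this file.")
    return parser
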
Example #2
def main():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    paths_file = args.paths_file
    after_time = args.after_time
    before_time = args.before_time

    #
    # Check dir/file locations
    #
    if paths_file is None:
        raise ValueError("Need a file path to to output transferred file "
                         "paths!")

    safe_create_dir(os.path.dirname(paths_file))

    #
    # Start collection
    #
    remote_paths = solr_image_paths(config['solr_address'], after_time or '*',
                                    before_time or '*',
                                    config['solr_username'],
                                    config['solr_password'],
                                    config['batch_size'])

    log.info("Writing file paths")
    with open(paths_file, 'w') as of:
        pr = cli.ProgressReporter(log.info, 1.0).start()
        for rp in remote_paths:
            of.write(rp + '\n')
            pr.increment_report()
        pr.report()
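
The configuration this tool consumes is not shown; judging from the keys main() reads, a default_config() sketch would look like the following. The values are illustrative placeholders, not the tool's real defaults.

def default_config():
    # Inferred from the config reads in main() above; placeholder values.
    return {
        'solr_address': 'http://localhost:8983/solr/imagery',
        'solr_username': None,
        'solr_password': None,
        'batch_size': 100,
    }
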
Example #3
def main():
    args = cli_parser().parse_args()
    config = utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    completed_files_fp = args.completed_files
    filelist_fp = args.file_list
    batch_size = args.batch_size
    check_image = args.check_image

    # Input checking
    if not filelist_fp:
        log.error("No file-list file specified")
        exit(102)
    elif not os.path.isfile(filelist_fp):
        log.error("Invalid file list path: %s", filelist_fp)
        exit(103)

    if not completed_files_fp:
        log.error("No complete files output specified")
        exit(104)

    if batch_size < 0:
        log.error("Batch size must be >= 0.")
        exit(105)

    run_file_list(config, filelist_fp, completed_files_fp, batch_size,
                  check_image)
Example #4
def main():
    args = cli_parser().parse_args()
    config = utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    output_filepath = args.output_map
    if not output_filepath:
        raise ValueError("No path given for output map file (pickle).")

    #: :type: smqtk.representation.DescriptorIndex
    index = from_config_dict(config['descriptor_index'],
                             DescriptorIndex.get_impls())
    mbkm = MiniBatchKMeans(verbose=args.verbose,
                           compute_labels=False,
                           **config['minibatch_kmeans_params'])
    initial_fit_size = int(config['initial_fit_size'])

    d_classes = mb_kmeans_build_apply(index, mbkm, initial_fit_size)

    log.info("Saving KMeans centroids to: %s",
             config['centroids_output_filepath_npy'])
    numpy.save(config['centroids_output_filepath_npy'], mbkm.cluster_centers_)

    log.info("Saving result classification map to: %s", output_filepath)
    safe_create_dir(os.path.dirname(output_filepath))
    # Pickle protocol -1 selects the highest (binary) protocol, so the
    # output file must be opened in binary mode.
    with open(output_filepath, 'wb') as f:
        cPickle.dump(d_classes, f, -1)

    log.info("Done")
Example #5
def main():
    args = get_cli_parser().parse_args()
    config = utility_main_helper(get_default_config, args)

    log = logging.getLogger(__name__)
    log.debug("Showing debug messages.")

    iqr_state_fp = args.iqr_state

    if not os.path.isfile(iqr_state_fp):
        log.error("IQR Session info JSON filepath was invalid")
        exit(102)

    train_classifier_iqr(config, iqr_state_fp)
Example #6
def main():
    parser = cli_parser()
    args = parser.parse_args()

    # Default config options for this util are technically valid for running;
    # it's just a bad authkey.
    config = cli.utility_main_helper(default_config, args,
                                     default_config_valid=True)

    port = int(config['port'])
    # bytes() on a str requires an encoding under Python 3; encode explicitly.
    authkey = config['authkey'].encode()

    mgr = ProxyManager(('', port), authkey)
    mgr.get_server().serve_forever()
Example #7
def main():
    # Print help and exit if no arguments were passed
    if len(sys.argv) == 1:
        get_cli_parser().print_help()
        sys.exit(1)

    args = get_cli_parser().parse_args()
    config = utility_main_helper(get_default_config, args)

    log = logging.getLogger(__name__)
    log.debug('Showing debug messages.')

    #: :type: smqtk.representation.DescriptorIndex
    descriptor_set = from_config_dict(config['plugins']['descriptor_set'],
                                      DescriptorIndex.get_impls())
    #: :type: smqtk.algorithms.NearestNeighborsIndex
    nearest_neighbor_index = from_config_dict(
        config['plugins']['nn_index'], NearestNeighborsIndex.get_impls())

    # noinspection PyShadowingNames
    def nearest_neighbors(descriptor, n):
        if n == 0:
            n = len(nearest_neighbor_index)

        uuids, dists = nearest_neighbor_index.nn(descriptor, n)
        # Strip the first result (the query itself) and create a list of
        # (uuid, distance) pairs.
        return list(zip([x.uuid() for x in uuids[1:]], dists[1:]))

    if args.uuid_list is not None and not os.path.exists(args.uuid_list):
        log.error('Invalid file list path: %s', args.uuid_list)
        exit(103)
    elif args.num < 0:
        log.error('Number of nearest neighbors must be >= 0')
        exit(105)

    if args.uuid_list is not None:
        with open(args.uuid_list, 'r') as infile:
            for line in infile:
                descriptor = descriptor_set.get_descriptor(line.strip())
                print(descriptor.uuid())
                for neighbor in nearest_neighbors(descriptor, args.num):
                    print('%s,%f' % neighbor)
    else:
        for (uuid, descriptor) in descriptor_set.iteritems():
            print(uuid)
            for neighbor in nearest_neighbors(descriptor, args.num):
                print('%s,%f' % neighbor)
Example #8
def main():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    api_root = config['tool']['girder_api_root']
    api_key = config['tool']['api_key']
    api_query_batch = config['tool']['api_query_batch']
    insert_batch_size = config['tool']['dataset_insert_batch_size']

    # Collect folder/item/file ID references from the command line, plus any
    # listed in the optional list files.
    #: :type: list[str]
    ids_folder = args.folder
    #: :type: list[str]
    ids_item = args.item
    #: :type: list[str]
    ids_file = args.file

    if args.folder_list:
        with open(args.folder_list) as f:
            ids_folder.extend([fid.strip() for fid in f])
    if args.item_list:
        with open(args.item_list) as f:
            ids_item.extend([iid.strip() for iid in f])
    if args.file_list:
        with open(args.file_list) as f:
            ids_file.extend([fid.strip() for fid in f])

    #: :type: smqtk.representation.DataSet
    data_set = from_config_dict(config['plugins']['data_set'],
                                DataSet.get_impls())

    batch = collections.deque()
    pr = cli.ProgressReporter(log.info, 1.0).start()
    for e in find_girder_files(api_root, ids_folder, ids_item, ids_file,
                               api_key, api_query_batch):
        batch.append(e)
        if insert_batch_size and len(batch) >= insert_batch_size:
            data_set.add_data(*batch)
            batch.clear()
        pr.increment_report()
    pr.report()

    if batch:
        data_set.add_data(*batch)

    log.info('Done')
Example #9
def main():
    parser = cli_parser()
    args = parser.parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    log.debug("Script arguments:\n%s" % args)

    def iter_input_elements():
        for f in args.input_files:
            f = osp.expanduser(f)
            if osp.isfile(f):
                yield DataFileElement(f)
            else:
                log.debug("Expanding glob: %s" % f)
                for g in glob.glob(f):
                    yield DataFileElement(g)

    log.info("Adding elements to data set")
    ds = from_config_dict(config['data_set'], DataSet.get_impls())
    ds.add_data(*iter_input_elements())
Example #10
def main():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = from_config_dict(config['plugins']['descriptor_index'],
                                        DescriptorIndex.get_impls())

    labels_filepath = args.f
    output_filepath = args.o

    # Run through labeled UUIDs in input file, getting the descriptor from the
    # configured index, applying the appropriate integer label and then writing
    # the formatted line out to the output file.
    with open(labels_filepath) as lfile:
        input_uuid_labels = list(csv.reader(lfile))

    with open(output_filepath, 'w') as ofile:
        label2int = {}
        next_int = 1
        uuids, labels = list(zip(*input_uuid_labels))

        log.info("Scanning input descriptors and labels")
        for i, (l, d) in enumerate(
                zip(labels, descriptor_index.get_many_descriptors(uuids))):
            log.debug("%d %s", i, d.uuid())
            if l not in label2int:
                label2int[l] = next_int
                next_int += 1
            ofile.write("%d " % label2int[l] + ' '.join([
                "%d:%.12f" % (j + 1, f)
                for j, f in enumerate(d.vector()) if f != 0.0
            ]) + '\n')

    log.info("Integer label association:")
    for i, l in sorted((i, l) for l, i in six.iteritems(label2int)):
        log.info('\t%d :: %s', i, l)
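
The lines written above follow the libSVM/SVMlight sparse format: an integer class label followed by 1-indexed index:value pairs for only the non-zero vector components. For instance, a descriptor mapped to label integer 2 with non-zero components at dimensions 1 and 5 would serialize as a line like (values illustrative):

2 1:0.532110000000 5:0.104200000000
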
Example #11
def main():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    uuids_list_filepath = config['uuids_list_filepath']

    log.info("Initializing ITQ functor")
    #: :type: smqtk.algorithms.nn_index.lsh.functors.itq.ItqFunctor
    functor = ItqFunctor.from_config(config['itq_config'])

    log.info("Initializing DescriptorSet [type=%s]",
             config['descriptor_set']['type'])
    #: :type: smqtk.representation.DescriptorSet
    descriptor_set = from_config_dict(
        config['descriptor_set'],
        DescriptorSet.get_impls(),
    )

    if uuids_list_filepath and os.path.isfile(uuids_list_filepath):

        def uuids_iter():
            with open(uuids_list_filepath) as f:
                for l in f:
                    yield l.strip()

        log.info("Loading UUIDs list from file: %s", uuids_list_filepath)
        d_iter = descriptor_set.get_many_descriptors(uuids_iter())
    else:
        log.info("Using UUIDs from loaded DescriptorSet (count=%d)",
                 len(descriptor_set))
        d_iter = descriptor_set

    log.info("Fitting ITQ model")
    functor.fit(d_iter)
    log.info("Done")
Example #12
def main():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    # - parallel_map UUIDs to load from the configured index
    # - classify iterated descriptors

    uuids_list_filepath = args.uuids_list
    output_csv_filepath = args.csv_data
    output_csv_header_filepath = args.csv_header
    classify_overwrite = config['utility']['classify_overwrite']

    p_use_multiprocessing = \
        config['utility']['parallel']['use_multiprocessing']
    p_index_extraction_cores = \
        config['utility']['parallel']['index_extraction_cores']
    p_classification_cores = \
        config['utility']['parallel']['classification_cores']

    if not uuids_list_filepath:
        raise ValueError("No uuids_list_filepath specified.")
    elif not os.path.isfile(uuids_list_filepath):
        raise ValueError("Given uuids_list_filepath did not point to a file.")
    if output_csv_header_filepath is None:
        raise ValueError("Need a path to save CSV header labels")
    if output_csv_filepath is None:
        raise ValueError("Need a path to save CSV data.")

    #
    # Initialize configured plugins
    #

    log.info("Initializing descriptor index")
    #: :type: smqtk.representation.DescriptorSet
    descriptor_set = from_config_dict(config['plugins']['descriptor_set'],
                                      DescriptorSet.get_impls())

    log.info("Initializing classification factory")
    c_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])

    log.info("Initializing classifier")
    #: :type: smqtk.algorithms.Classifier
    classifier = from_config_dict(config['plugins']['classifier'],
                                  Classifier.get_impls())

    #
    # Setup/Process
    #
    def iter_uuids():
        with open(uuids_list_filepath) as f:
            for l in f:
                yield l.strip()

    def descr_for_uuid(uuid):
        """
        :type uuid: collections.Hashable
        :rtype: smqtk.representation.DescriptorElement
        """
        return descriptor_set.get_descriptor(uuid)

    def classify_descr(d):
        """
        :type d: smqtk.representation.DescriptorElement
        :rtype: smqtk.representation.ClassificationElement
        """
        return classifier.classify_one_element(d, c_factory,
                                               classify_overwrite)

    log.info("Initializing uuid-to-descriptor parallel map")
    #: :type: collections.Iterable[smqtk.representation.DescriptorElement]
    element_iter = parallel.parallel_map(
        descr_for_uuid,
        iter_uuids(),
        use_multiprocessing=p_use_multiprocessing,
        cores=p_index_extraction_cores,
        name="descr_for_uuid",
    )

    log.info("Initializing descriptor-to-classification parallel map")
    #: :type: collections.Iterable[smqtk.representation.ClassificationElement]
    classification_iter = parallel.parallel_map(
        classify_descr,
        element_iter,
        use_multiprocessing=p_use_multiprocessing,
        cores=p_classification_cores,
        name='classify_descr',
    )

    #
    # Write/Output files
    #

    c_labels = classifier.get_labels()

    def make_row(e):
        """
        :type e: smqtk.representation.ClassificationElement
        """
        c_m = e.get_classification()
        return [e.uuid] + [c_m[l] for l in c_labels]

    # column labels file
    log.info("Writing CSV column header file: %s", output_csv_header_filepath)
    safe_create_dir(os.path.dirname(output_csv_header_filepath))
    # The csv module expects text-mode files under Python 3.
    with open(output_csv_header_filepath, 'w', newline='') as f_csv:
        w = csv.writer(f_csv)
        w.writerow(['uuid'] + [str(cl) for cl in c_labels])

    # CSV file
    log.info("Writing CSV data file: %s", output_csv_filepath)
    safe_create_dir(os.path.dirname(output_csv_filepath))
    pr = cli.ProgressReporter(log.info, 1.0)
    pr.start()
    with open(output_csv_filepath, 'w', newline='') as f_csv:
        w = csv.writer(f_csv)
        for c in classification_iter:
            w.writerow(make_row(c))
            pr.increment_report()
        pr.report()

    log.info("Done")
Example #13
def main():
    parser = cli_parser()
    args = parser.parse_args()

    debug_smqtk = args.debug_smqtk or args.verbose
    debug_server = args.debug_server or args.verbose
    debug_app = args.debug_app or args.verbose

    debug_ns_list = args.debug_ns
    if debug_smqtk:
        debug_ns_list.append('smqtk')
    if debug_server:
        debug_ns_list.append('werkzeug')

    # Create a single stream handler on the root logger with the given level
    # set on the handler, then tune levels on specific namespaces under root,
    # which itself is reset to WARN.
    cli.initialize_logging(logging.getLogger(), logging.DEBUG)
    logging.getLogger().setLevel(logging.WARN)
    log = logging.getLogger(__name__)
    # SMQTK level always at least INFO level for standard internals reporting.
    logging.getLogger("smqtk").setLevel(logging.INFO)
    # Enable DEBUG level on applicable namespaces available to us at this time.
    for ns in debug_ns_list:
        log.info("Enabling debug logging on '{}' namespace".format(ns))
        logging.getLogger(ns).setLevel(logging.DEBUG)

    webapp_types = smqtk.web.SmqtkWebApp.get_impls()
    web_applications = {t.__name__: t for t in webapp_types}

    if args.list:
        log.info("")
        log.info("Available applications:")
        log.info("")
        for l, cls in six.iteritems(web_applications):
            log.info("\t" + l)
            if debug_smqtk:
                log.info('\t' + ('^' * len(l)) + '\n' + cls.__doc__ + '\n' +
                         ('*' * 80) + '\n')
        log.info("")
        exit(0)

    application_name = args.application

    if application_name is None:
        log.error("No application name given!")
        exit(1)
    elif application_name not in web_applications:
        log.error("Invalid application label '%s'", application_name)
        exit(1)

    #: :type: smqtk.web.SmqtkWebApp
    app_class = web_applications[application_name]

    # If the application class's logger does not already report as having INFO/
    # DEBUG level logging (due to being a child of an above handled namespace)
    # then set the app namespace's logger level appropriately
    app_class_logger_level = app_class.get_logger().getEffectiveLevel()
    app_class_target_level = logging.INFO - (10 * debug_app)
    if app_class_logger_level > app_class_target_level:
        level_name = \
            "DEBUG" if app_class_target_level == logging.DEBUG else "INFO"
        log.info("Enabling '{}' logging for '{}' logger namespace.".format(
            level_name,
            app_class.get_logger().name))
        app_class.get_logger().setLevel(app_class_target_level)

    config = cli.utility_main_helper(app_class.get_default_config,
                                     args,
                                     skip_logging_init=True)

    host = args.host
    port = args.port and int(args.port)
    use_reloader = args.reload
    use_threading = args.threaded
    use_basic_auth = args.use_basic_auth
    use_simple_cors = args.use_simple_cors

    # noinspection PyUnresolvedReferences
    #: :type: smqtk.web.SmqtkWebApp
    app = app_class.from_config(config)
    if use_basic_auth:
        app.config["BASIC_AUTH_FORCE"] = True
        BasicAuth(app)
    if use_simple_cors:
        log.debug("Enabling CORS for all domains on all routes.")
        CORS(app)
    app.config['DEBUG'] = debug_server

    log.info("Starting application")
    app.run(host=host,
            port=port,
            debug=debug_server,
            use_reloader=use_reloader,
            threaded=use_threading)
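
The logging.INFO - (10 * debug_app) arithmetic above relies on the logging module's numeric levels: DEBUG is 10 and INFO is 20, so a truthy debug_app flag lowers the target level from INFO to DEBUG. A two-line sanity check:

import logging

assert logging.INFO - (10 * True) == logging.DEBUG   # 20 - 10 == 10
assert logging.INFO - (10 * False) == logging.INFO   # 20 - 0 == 20
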
Example #14
def classifier_kfold_validation():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Load configurations / Setup data
    #
    pr_enabled = config['pr_curves']['enabled']
    pr_output_dir = config['pr_curves']['output_directory']
    pr_file_prefix = config['pr_curves']['file_prefix'] or ''
    pr_show = config['pr_curves']['show']

    roc_enabled = config['roc_curves']['enabled']
    roc_output_dir = config['roc_curves']['output_directory']
    roc_file_prefix = config['roc_curves']['file_prefix'] or ''
    roc_show = config['roc_curves']['show']

    log.info("Initializing DescriptorSet (%s)",
             config['plugins']['descriptor_set']['type'])
    #: :type: smqtk.representation.DescriptorSet
    descriptor_set = from_config_dict(config['plugins']['descriptor_set'],
                                      DescriptorSet.get_impls())
    log.info("Loading classifier configuration")
    #: :type: dict
    classifier_config = config['plugins']['supervised_classifier']

    # Always use in-memory ClassificationElement since we are retraining the
    # classifier and don't want possible element caching
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory(
        MemoryClassificationElement, {})

    log.info("Loading truth data")
    #: :type: list[str]
    uuids = []
    #: :type: list[str]
    truth_labels = []
    with open(config['cross_validation']['truth_labels']) as f:
        f_csv = csv.reader(f)
        for row in f_csv:
            uuids.append(row[0])
            truth_labels.append(row[1])
    #: :type: numpy.ndarray[str]
    uuids = numpy.array(uuids)
    #: :type: numpy.ndarray[str]
    truth_labels = numpy.array(truth_labels)

    #
    # Cross validation
    #
    kfolds = sklearn.model_selection.StratifiedKFold(
        n_splits=config['cross_validation']['num_folds'],
        shuffle=True,
        random_state=config['cross_validation']['random_seed'],
    ).split(numpy.zeros(len(truth_labels)), truth_labels)
    """
    Truth and classification probability results for test data per fold.
    Format:
        {
            0: {
                '<label>':  {
                    "truth": [...],   # Parallel truth and classification
                    "proba": [...],   # probability values
                },
                ...
            },
            ...
        }
    """
    fold_data: Dict[int, Any] = {}

    for i, (train, test) in enumerate(kfolds):
        log.info("Fold %d", i)
        log.info("-- %d training examples", len(train))
        log.info("-- %d test examples", len(test))
        fold_data[i] = {}

        log.info("-- creating classifier")
        classifier = cast(
            SupervisedClassifier,
            from_config_dict(classifier_config,
                             SupervisedClassifier.get_impls()))

        log.info("-- gathering descriptors")
        pos_map: Dict[str, List[DescriptorElement]] = {}
        for idx in train:
            if truth_labels[idx] not in pos_map:
                pos_map[truth_labels[idx]] = []
            pos_map[truth_labels[idx]].append(
                descriptor_set.get_descriptor(uuids[idx]))

        log.info("-- Training classifier")
        classifier.train(pos_map)

        log.info("-- Classifying test set")
        c_iter = classifier.classify_elements(
            (descriptor_set.get_descriptor(uuids[idx]) for idx in test),
            classification_factory,
        )
        uuid2c = dict((c.uuid, c.get_classification()) for c in c_iter)

        log.info("-- Pairing truth and computed probabilities")
        # Only considering positive labels
        for t_label in pos_map:
            fold_data[i][t_label] = {
                "truth": [L == t_label for L in truth_labels[test]],
                "proba": [uuid2c[uuid][t_label] for uuid in uuids[test]]
            }

    #
    # Curve generation
    #
    if pr_enabled:
        make_pr_curves(fold_data, pr_output_dir, pr_file_prefix, pr_show)
    if roc_enabled:
        make_roc_curves(fold_data, roc_output_dir, roc_file_prefix, roc_show)
Example #15
def main():
    args = cli_parser().parse_args()
    config = utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    report_size = args.report_size
    crawled_after = args.crawled_after
    inserted_after = args.inserted_after

    #
    # Check config properties
    #
    m = mimetypes.MimeTypes()
    # non-strict types (see use of ``guess_extension`` above)
    # dict views cannot be concatenated with '+' under Python 3; use set union.
    m_img_types = set(m.types_map_inv[0]) | set(m.types_map_inv[1])
    if not isinstance(config['image_types'], list):
        raise ValueError("The 'image_types' property was not set to a list.")
    for t in config['image_types']:
        if ('image/' + t) not in m_img_types:
            raise ValueError("Image type '%s' is not a valid image MIMETYPE "
                             "sub-type." % t)

    if not report_size and args.output_dir is None:
        raise ValueError("Require an output directory!")
    if not report_size and args.file_list is None:
        raise ValueError("Require an output CSV file path!")

    #
    # Initialize ElasticSearch stuff
    #
    es_auth = None
    if config['elastic_search']['username'] and config['elastic_search'][
            'password']:
        es_auth = (config['elastic_search']['username'],
                   config['elastic_search']['password'])

    es = elasticsearch.Elasticsearch(
        config['elastic_search']['instance_address'],
        http_auth=es_auth,
        use_ssl=True,
        verify_certs=True,
        ca_certs=certifi.where(),
    )

    #
    # Query and Run
    #
    http_auth = None
    if config['stored_http_auth']['name'] and config['stored_http_auth'][
            'pass']:
        http_auth = (config['stored_http_auth']['name'],
                     config['stored_http_auth']['pass'])

    ts_re = re.compile(r'(\d{4})-(\d{2})-(\d{2})T(\d{2}):(\d{2}):(\d{2})Z')
    if crawled_after:
        m = ts_re.match(crawled_after)
        if m is None:
            raise ValueError("Given 'crawled-after' timestamp not in correct "
                             "format: '%s'" % crawled_after)
        crawled_after = datetime.datetime(*[int(e) for e in m.groups()])
    if inserted_after:
        m = ts_re.match(inserted_after)
        if m is None:
            raise ValueError("Given 'inserted-after' timestamp not in correct "
                             "format: '%s'" % inserted_after)
        inserted_after = datetime.datetime(*[int(e) for e in m.groups()])

    q = cdr_images_after(es, config['elastic_search']['index'],
                         config['image_types'], crawled_after, inserted_after)

    log.info("Query Size: %d", q[0:0].execute().hits.total)
    if report_size:
        exit(0)

    fetch_cdr_query_images(q,
                           args.output_dir,
                           args.file_list,
                           cores=int(config['parallel']['cores']),
                           stored_http_auth=http_auth,
                           batch_size=int(
                               config['elastic_search']['batch_size']))
Example #16
def main():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    # Deprecations
    if (config.get('parallelism', {}).get('classification_cores', None)
            is not None):
        warnings.warn(
            "Usage of 'classification_cores' is deprecated. "
            "Classifier parallelism is not defined on a "
            "per-implementation basis. See classifier "
            "implementation parameterization.",
            category=DeprecationWarning)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.SupervisedClassifier
    classifier = from_config_dict(config['plugins']['classifier'],
                                  SupervisedClassifier.get_impls())
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])
    #: :type: smqtk.representation.DescriptorSet
    descriptor_set = from_config_dict(config['plugins']['descriptor_set'],
                                      DescriptorSet.get_impls())

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct a mapping from truth label to the DescriptorElement instances
    # described by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate through UUIDs in specified file """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch descriptors from configured index """
        uuid, truth_label = r
        return truth_label, descriptor_set.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr,
        iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train the classifier if training was enabled; the configured classifier
    # must provide a ``train`` method (i.e. be a supervised implementation).
    if do_train:
        log.info("Training supervised classifier model")
        classifier.train(tlabel2descriptors)
        exit(0)

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    for tlabel, descriptors in six.iteritems(tlabel2descriptors):
        tlabel2classifications[tlabel] = \
            set(classifier.classify_elements(descriptors,
                                             classification_factory))
    log.info("Truth label counts:")
    for l in sorted(tlabel2classifications):
        log.info("  %s :: %d", l, len(tlabel2classifications[l]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # Confusion Matrix of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are UUID
        # predicted labels.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[collections.Hashable, set]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists for JSON output.
            for plabel in uuid_cm[tlabel]:
                # noinspection PyTypeChecker
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr, plot_ci,
                       plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc, plot_ci,
                        plot_ci_alpha)
Example #17
def main():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Load configuration contents
    #
    uuid_list_filepath = args.uuids_list
    report_interval = config['utility']['report_interval']
    use_multiprocessing = config['utility']['use_multiprocessing']

    #
    # Checking input parameters
    #
    if (uuid_list_filepath is not None) and \
            not os.path.isfile(uuid_list_filepath):
        raise ValueError("UUIDs list file does not exist!")

    #
    # Loading stuff
    #
    log.info("Loading descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = from_config_dict(config['plugins']['descriptor_index'],
                                        DescriptorIndex.get_impls())
    log.info("Loading LSH functor")
    #: :type: smqtk.algorithms.LshFunctor
    lsh_functor = from_config_dict(config['plugins']['lsh_functor'],
                                   LshFunctor.get_impls())
    log.info("Loading Key/Value store")
    #: :type: smqtk.representation.KeyValueStore
    hash2uuids_kvstore = from_config_dict(
        config['plugins']['hash2uuid_kvstore'], KeyValueStore.get_impls())

    # Iterate either over what's in the file given, or everything in the
    # configured index.
    def iter_uuids():
        if uuid_list_filepath:
            log.info("Using UUIDs list file")
            with open(uuid_list_filepath) as f:
                for l in f:
                    yield l.strip()
        else:
            log.info("Using all UUIDs resent in descriptor index")
            for k in descriptor_index.keys():
                yield k

    #
    # Compute codes
    #
    log.info("Starting hash code computation")
    kv_update = {}
    for uuid, hash_int in \
            compute_hash_codes(uuids_for_processing(iter_uuids(),
                                                    hash2uuids_kvstore),
                               descriptor_index, lsh_functor,
                               report_interval,
                               use_multiprocessing, True):
        # Get original value in KV-store if not in update dict.
        if hash_int not in kv_update:
            kv_update[hash_int] = hash2uuids_kvstore.get(hash_int, set())
        kv_update[hash_int] |= {uuid}

    if kv_update:
        log.info("Updating KV store... (%d keys)" % len(kv_update))
        hash2uuids_kvstore.add_many(kv_update)

    log.info("Done")