Пример #1
0
def main():
    """
    Command-line entry point: configure logging, load the JSON configuration
    file and hand ``config``, ``label`` and ``file_globs`` to
    ``classify_files``.

    Exits with code 101 when no configuration file could be loaded.
    """
    import sys

    log = logging.getLogger(__name__)
    parser = get_cli_parser()
    args = parser.parse_args()

    config_path = args.config
    generate_config = args.generate_config
    config_overwrite = args.overwrite
    is_debug = args.verbose

    label = args.label
    file_globs = args.file_globs

    # Same level for this script's namespace and for the smqtk package.
    llevel = logging.DEBUG if is_debug else logging.INFO
    initialize_logging(logging.getLogger('__main__'), llevel)
    initialize_logging(logging.getLogger('smqtk'), llevel)
    log.debug("Showing debug messages.")

    # Start from defaults and overlay the user's JSON file when it exists.
    config = get_default_config()
    config_loaded = False
    if config_path and os.path.isfile(config_path):
        with open(config_path) as f:
            log.info("Loading configuration: %s", config_path)
            config.update(json.load(f))
        config_loaded = True
    # NOTE(review): ``output_config`` presumably writes ``config`` to
    # ``generate_config`` and exits with code 100 when requested -- confirm.
    output_config(generate_config, config, log, config_overwrite, 100)

    if not config_loaded:
        log.error("No configuration provided")
        # ``sys.exit`` rather than the site-injected ``exit`` builtin, which
        # is unavailable under ``python -S`` or in frozen executables.
        sys.exit(101)

    classify_files(config, label, file_globs)
Пример #2
0
def cli_group(verbose):
    """
    Tool for building a nearest neighbors index from an input descriptor set.

    The index is built, not updated. The configured index must not be
    read-only, and any persisted index, if one already exists, may be
    overwritten.
    """
    # Each verbosity step lowers the threshold by one standard level:
    # 0 -> WARNING, 1 -> INFO, 2 -> DEBUG. Larger counts drop to NOTSET (0)
    # or below, which still lets every record through.
    llevel = logging.WARNING - (10 * verbose)
    # Attempting just setting the root logger. If this becomes too verbose,
    # initialize relevant namespaces manually.
    initialize_logging(logging.getLogger(), llevel)
    LOG.info("Displaying informational logging.")
    LOG.debug("Displaying debug logging.")
Пример #3
0
def main():
    """
    Parse CLI arguments, set up ``smqtk`` logging, then pass the scan
    interval, base directory and expiry (all taken from the CLI) to
    ``interval_scan`` with ``remove_file_action`` as the action callback.
    """
    opts = cli_parser().parse_args()

    # DEBUG when verbose output was requested, INFO otherwise.
    initialize_logging(logging.getLogger("smqtk"),
                       logging.DEBUG if opts.verbose else logging.INFO)

    interval_scan(opts.interval, opts.base_dir, opts.expiry,
                  remove_file_action)
def main():
    """
    Fan descriptor ``(type, uuid)`` pairs out to worker processes.

    Starts one ``proc_transfer`` process per CPU, all sharing ``in_queue``.
    Loads a pickled list of descriptor file names, parses each with
    ``fname_re`` into ``(type_str, uuid_str)`` and puts the pair on the queue.
    Afterwards, one ``None`` sentinel per worker is queued.

    If anything fails while queueing, the workers are terminated and the
    exception is logged and re-raised. Workers are always joined before
    returning.

    :raises ValueError: If a file name does not match ``fname_re``.
    """
    cli.initialize_logging(logging.getLogger(), logging.DEBUG)
    log = logging.getLogger(__name__)

    # For each file in descriptor vector file tree, load from file
    # [type, uuid, vector] and insert into PSQL element.

    log.info("Setting up parallel environment")
    in_queue = multiprocessing.Queue()
    workers = []
    for _ in xrange(multiprocessing.cpu_count()):
        p = multiprocessing.Process(target=proc_transfer, args=(in_queue, ))
        workers.append(p)
        p.start()

    try:
        log.info("Loading filename list")
        # Pickle data is binary; text mode can corrupt it on some platforms.
        # NOTE: unpickling can execute arbitrary code -- trusted input only.
        with open("descriptor_file_names.5.3mil.pickle", 'rb') as f:
            fname_list = cPickle.load(f)

        log.info("Running through filename list")
        for n in fname_list:
            m = fname_re.match(n)
            # Explicit check instead of ``assert`` so it survives ``-O``.
            if not m:
                raise ValueError("Unexpected descriptor file name: %r" % (n,))

            type_str = m.group(1)
            uuid_str = m.group(2)
            in_queue.put((type_str, uuid_str))

        log.info("Sending worker terminal packets")
        # One sentinel per worker so every worker sees an end-of-input marker.
        for _ in workers:
            in_queue.put(None)

    except BaseException:
        # A bare ``except:`` used to swallow the error here (including
        # KeyboardInterrupt), making a failed run look successful. Log it,
        # stop the workers, and propagate after ``finally`` joins them.
        log.exception("Terminating workers")
        for w in workers:
            w.terminate()
        raise

    finally:
        log.info("Waiting for workers to complete")
        for w in workers:
            w.join()
        log.info("Workers joined")
Пример #5
0
def main():
    """
    Check the image files listed in a file and print the selected ones.

    Reads one path per line from ``args.file_list`` (surrounding whitespace
    stripped) and checks every path in parallel. For each existing file it
    prints ``<filepath>,<uuid>`` when the file is valid, or, with
    ``args.invert``, when it is invalid. Non-existent paths are warned about
    and never printed.

    Exits with code 1 (after printing help) when run without arguments, and
    with code 103 when the file list is not given or does not exist.
    """
    parser = get_cli_parser()

    # Print help and exit if no arguments were passed
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    args = parser.parse_args()
    llevel = logging.DEBUG if args.verbose else logging.INFO
    initialize_logging(logging.getLogger('smqtk'), llevel)
    initialize_logging(logging.getLogger('__main__'), llevel)

    log = logging.getLogger(__name__)
    log.debug('Showing debug messages.')

    # A missing file list used to fall through to ``open(None)`` below and
    # crash with an unhelpful TypeError; reject it up front instead.
    if args.file_list is None:
        log.error('No file list path provided')
        sys.exit(103)
    if not os.path.exists(args.file_list):
        log.error('Invalid file list path: %s', args.file_list)
        sys.exit(103)

    def check_image(image_path):
        """
        Return ``(is_valid, DataFileElement)`` for an existing path, or
        ``(False, False)`` when the path does not exist.
        """
        if not os.path.exists(image_path):
            log.warning('Invalid image path given (does not exist): %s',
                        image_path)
            return False, False
        d = DataFileElement(image_path)
        return is_valid_element(d, check_image=True), d

    with open(args.file_list) as infile:
        checked_images = parallel.parallel_map(check_image,
                                               map(str.strip, infile),
                                               name='check-image-validity',
                                               use_multiprocessing=True)

        for is_valid, dfe in checked_images:
            if not dfe:  # in the case of a non-existent file
                continue
            # XOR: print valid elements normally, invalid ones when inverted.
            if bool(is_valid) != bool(args.invert):
                # We know the callback above is creating DataFileElement
                # instances.
                # noinspection PyProtectedMember
                print('%s,%s' % (dfe._filepath, dfe.uuid()))
Пример #6
0
import csv
import json
import logging

from matplotlib import pyplot as plt
import numpy
from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve

from smqtk.algorithms import get_classifier_impls
from smqtk.representation import ClassificationElementFactory
from smqtk.representation.classification_element.memory import MemoryClassificationElement
from smqtk.representation.descriptor_index.memory import MemoryDescriptorIndex
from smqtk.utils.cli import initialize_logging
from smqtk.utils.plugin import from_plugin_config

# At import time: attach CLI logging to the root logger at INFO.
initialize_logging(logging.getLogger(''), logging.INFO)
log = logging.getLogger(name=__name__)

###############################################################################
# Parameters
#
# File paths used by this evaluation script (presumably relative to the
# current working directory).
PHONE_SHA1_JSON = "eval.map.phone2shas.json"
DESCRIPTOR_INDEX_FILE_CACHE = "eval.images.descriptors.alexnet_fc7.index"

CLASSIFIER_TRAINING_CONFIG_JSON = 'ad-images.final.cmv.train.json'

# Per-phone score CSV produced by the script.
PHONE2SCORE_OUTPUT_FILEPATH = "eval.results.full_model.phone2score.csv"

# Only needed for ROC generation: that step reads PHONE2SCORE_OUTPUT_FILEPATH
# as input and writes plots.
PHONE2TRUTH = 'eval.source.phone2truth.json'
Пример #7
0
CAFFE_LABELS = "labels.txt"

# CSV whose rows relate [cluster_id, ad_id, image_sha1].
EVAL_CLUSTERS_ADS_IMAGES_CSV = "eval.CP1_clusters_ads_images.csv"
# JSON-lines file of clusters absent from the CSV above. Every line holds at
# least: {"cluster_id": <str>, ... }
EVAL_MISSING_CLUSTERS = "eval.cluster_scores.missing_clusters.jl"

# Output file paths.
OUTPUT_DESCR_PROB_INDEX = "cp1_img_prob_descriptors.pickle"
OUTPUT_MAX_JL = "cp1_scores_max.jl"
OUTPUT_AVG_JL = "cp1_scores_avg.jl"

###############################################################################

# Compute classification scores
initialize_logging(logging.getLogger(name='smqtk'), logging.DEBUG)

eval_data_set = DataMemorySet(EVAL_DATASET)
img_prob_descr_index = MemoryDescriptorIndex(OUTPUT_DESCR_PROB_INDEX)

# 'prob' is presumably the network layer whose output becomes the descriptor.
img_prob_gen = CaffeDescriptorGenerator(
    CAFFE_DEPLOY, CAFFE_MODEL, CAFFE_IMG_MEAN, 'prob',
    use_gpu=True,
    batch_size=1000,
    load_truncated_images=True,
)

img_c_mem_factory = ClassificationElementFactory(
    MemoryClassificationElement, {}
)
img_prob_classifier = IndexLabelClassifier(CAFFE_LABELS)
Пример #8
0
def main():
    """
    Build the IQR demo models for one UI configuration tab.

    Steps:

    1. Load the UI and IQR service configuration files named by
       ``args.config``.
    2. Instantiate the configured data set, descriptor element factory,
       descriptor generator and nearest-neighbors index for ``args.tab``.
    3. Add every file matched by ``args.input_files`` to the data set.
    4. Optionally generate a descriptor-generator model.
    5. Compute descriptors, optionally fit an LSH functor, and build the
       index.

    :raises RuntimeError: If either configuration file fails to load.

    Exits with code 1 if no tab is given or the tab is not configured.
    """
    import sys

    args = cli_parser().parse_args()

    ui_config_filepath, iqr_config_filepath = args.config
    llevel = logging.DEBUG if args.verbose else logging.INFO
    tab = args.tab
    input_files_globs = args.input_files

    # Not using `cli.utility_main_helper` due to deviating from single-
    # config-with-default usage.
    cli.initialize_logging(logging.getLogger('smqtk'), llevel)
    cli.initialize_logging(logging.getLogger('__main__'), llevel)
    log = logging.getLogger(__name__)

    log.info("Loading UI config: '{}'".format(ui_config_filepath))
    ui_config, ui_config_loaded = cli.load_config(ui_config_filepath)
    log.info("Loading IQR config: '{}'".format(iqr_config_filepath))
    iqr_config, iqr_config_loaded = cli.load_config(iqr_config_filepath)
    if not (ui_config_loaded and iqr_config_loaded):
        raise RuntimeError("One or both configuration files failed to load.")

    # Ensure the given "tab" exists in UI configuration.
    if tab is None:
        log.error("No configuration tab provided to drive model generation.")
        sys.exit(1)
    if tab not in ui_config["iqr_tabs"]:
        log.error("Invalid tab provided: '{}'. Available tabs: {}".format(
            tab, list(ui_config["iqr_tabs"])))
        sys.exit(1)

    #
    # Gather Configurations
    #
    log.info("Extracting plugin configurations")

    ui_tab_config = ui_config["iqr_tabs"][tab]
    iqr_plugins_config = iqr_config['iqr_service']['plugins']

    # Configure DataSet implementation and parameters
    data_set_config = ui_tab_config['data_set']

    # Configure DescriptorElementFactory instance, which defines what
    # implementation of DescriptorElement to use for storing generated
    # descriptor vectors below.
    descriptor_elem_factory_config = iqr_plugins_config['descriptor_factory']

    # Configure DescriptorGenerator algorithm implementation, parameters and
    # persistent model component locations (if implementation has any).
    descriptor_generator_config = iqr_plugins_config['descriptor_generator']

    # Configure NearestNeighborIndex algorithm implementation, parameters and
    # persistent model component locations (if implementation has any).
    nn_index_config = iqr_plugins_config['neighbor_index']

    #
    # Initialize data/algorithms
    #
    # Constructing appropriate data structures and algorithms, needed for the
    # IQR demo application, in preparation for model training.
    #
    log.info("Instantiating plugins")
    #: :type: representation.DataSet
    data_set = \
        from_config_dict(data_set_config, representation.DataSet.get_impls())
    descriptor_elem_factory = \
        representation.DescriptorElementFactory \
        .from_config(descriptor_elem_factory_config)
    #: :type: algorithms.DescriptorGenerator
    descriptor_generator = \
        from_config_dict(descriptor_generator_config,
                         algorithms.DescriptorGenerator.get_impls())

    #: :type: algorithms.NearestNeighborsIndex
    nn_index = \
        from_config_dict(nn_index_config,
                         algorithms.NearestNeighborsIndex.get_impls())

    #
    # Build models
    #
    log.info("Adding files to dataset '{}'".format(data_set))
    for g in input_files_globs:
        g = osp.expanduser(g)
        if osp.isfile(g):
            data_set.add_data(DataFileElement(g, readonly=True))
        else:
            log.debug("Expanding glob: %s", g)
            for fp in glob.iglob(g):
                data_set.add_data(DataFileElement(fp, readonly=True))

    # Generate a model if the generator defines a known generation method.
    # Probe for the method rather than wrapping the call in
    # ``except AttributeError``: that also swallowed AttributeErrors raised
    # *inside* ``generate_model``, hiding real failures.
    generate_model = getattr(descriptor_generator, 'generate_model', None)
    if generate_model is not None:
        log.debug("Descriptor generator has a model to generate")
        generate_model(data_set)
    else:
        log.debug("Descriptor generator has no model to generate")

    # Generate descriptors of data for building NN index.
    log.info("Computing descriptors for data set with {}".format(
        descriptor_generator))
    data2descriptor = descriptor_generator.compute_descriptor_async(
        data_set, descriptor_elem_factory)

    # Possible additional support step before building NNIndex: fit the LSH
    # index functor when the index has one. Same reasoning as above -- only
    # the attribute lookups are allowed to be "missing"; errors raised by
    # ``fit`` itself propagate.
    lsh_functor = getattr(nn_index, 'lsh_functor', None)
    lsh_fit = getattr(lsh_functor, 'fit', None)
    if lsh_fit is not None:
        log.debug("Fitting LSH functor")
        lsh_fit(six.itervalues(data2descriptor))
    else:
        log.debug("No LSH functor to fit")

    log.info("Building nearest neighbors index {}".format(nn_index))
    nn_index.build_index(six.itervalues(data2descriptor))
Пример #9
0
import hashlib
import json
import logging
import mimetypes
import os
import requests
import StringIO
import uuid

from tika import detector as tika_detector

from smqtk.utils.cli import initialize_logging
from smqtk.utils.file import safe_create_dir
from smqtk.utils.parallel import parallel_map

# At import time: INFO-level logging for this script and for SMQTK.
initialize_logging(logging.getLogger(name='__main__'), logging.INFO)
initialize_logging(logging.getLogger(name='smqtk'), logging.INFO)
log = logging.getLogger(name=__name__)

# Drop the '.jfif' and '.jpe' extension mappings (presumably so extension
# guessing for JPEG content picks a more common suffix). ``pop`` with a
# default is a no-op when the key is already absent.
mimetypes.types_map.pop('.jfif', None)
mimetypes.types_map.pop('.jpe', None)


def dl_ad_image(url, output_dir):
    """
    Handle the ad image at ``url`` for ``output_dir`` (presumably download it
    there, judging by the name -- the body is not fully visible here).

    :param url: Image URL (assumed to be a string -- confirm).
    :param output_dir: Target directory (assumed to be a path -- confirm).

    Returns (None, None, None) if failed, otherwise (url, filepath, sha1)
    """
    # NOTE(review): only this first statement of the body is visible in this
    # excerpt; the return contract above cannot be verified from here.
    log = logging.getLogger(__name__)
Пример #10
0
def main():
    """
    Run one of the registered SMQTK web applications.

    Logging setup:

    * The root logger gets a DEBUG-level handler, then its own level is set
      to WARNING.
    * ``smqtk`` is set to INFO.
    * DEBUG is enabled on every requested namespace.

    With ``args.list``, logs the available ``SmqtkWebApp`` implementations
    and exits. Otherwise it builds the selected application from its
    configuration and runs it.

    Exits with code 0 after listing. Exits with code 1 when no application
    name, or an unknown one, is given.
    """
    import sys

    parser = cli_parser()
    args = parser.parse_args()

    debug_smqtk = args.debug_smqtk or args.verbose
    debug_server = args.debug_server or args.verbose
    debug_app = args.debug_app or args.verbose

    # Work on a copy: appending below must neither mutate ``args`` (or a
    # shared argparse default list) nor fail when no namespaces were given.
    debug_ns_list = list(args.debug_ns or [])
    if debug_smqtk:
        debug_ns_list.append('smqtk')
    if debug_server:
        debug_ns_list.append('werkzeug')

    # Create a single stream handler on the root, the level passed being
    # applied to the handler, and then set tuned levels on specific namespace
    # levels under root, which is reset to warning.
    cli.initialize_logging(logging.getLogger(), logging.DEBUG)
    logging.getLogger().setLevel(logging.WARNING)
    log = logging.getLogger(__name__)
    # SMQTK level always at least INFO level for standard internals reporting.
    logging.getLogger("smqtk").setLevel(logging.INFO)
    # Enable DEBUG level on applicable namespaces available to us at this time.
    for ns in debug_ns_list:
        log.info("Enabling debug logging on '{}' namespace".format(ns))
        logging.getLogger(ns).setLevel(logging.DEBUG)

    webapp_types = smqtk.web.SmqtkWebApp.get_impls()
    web_applications = {t.__name__: t for t in webapp_types}

    if args.list:
        log.info("")
        log.info("Available applications:")
        log.info("")
        for app_name, cls in six.iteritems(web_applications):
            log.info("\t" + app_name)
            if debug_smqtk:
                # ``__doc__`` is None for classes without a docstring; guard
                # against a TypeError in the concatenation.
                log.info('\t' + ('^' * len(app_name)) + '\n' +
                         (cls.__doc__ or '') + '\n' +
                         ('*' * 80) + '\n')
        log.info("")
        sys.exit(0)

    application_name = args.application

    if application_name is None:
        log.error("No application name given!")
        sys.exit(1)
    elif application_name not in web_applications:
        log.error("Invalid application label '%s'", application_name)
        sys.exit(1)

    #: :type: smqtk.web.SmqtkWebApp
    app_class = web_applications[application_name]

    # If the application class's logger does not already report as having INFO/
    # DEBUG level logging (due to being a child of an above handled namespace)
    # then set the app namespace's logger level appropriately
    app_logger = app_class.get_logger()
    app_class_target_level = logging.INFO - (10 * debug_app)
    if app_logger.getEffectiveLevel() > app_class_target_level:
        level_name = \
            "DEBUG" if app_class_target_level == logging.DEBUG else "INFO"
        log.info("Enabling '{}' logging for '{}' logger namespace.".format(
            level_name,
            app_logger.name))
        app_logger.setLevel(app_class_target_level)

    config = cli.utility_main_helper(app_class.get_default_config,
                                     args,
                                     skip_logging_init=True)

    host = args.host
    port = args.port and int(args.port)
    use_reloader = args.reload
    use_threading = args.threaded
    use_basic_auth = args.use_basic_auth
    use_simple_cors = args.use_simple_cors

    # noinspection PyUnresolvedReferences
    #: :type: smqtk.web.SmqtkWebApp
    app = app_class.from_config(config)
    if use_basic_auth:
        app.config["BASIC_AUTH_FORCE"] = True
        BasicAuth(app)
    if use_simple_cors:
        log.debug("Enabling CORS for all domains on all routes.")
        CORS(app)
    app.config['DEBUG'] = debug_server

    log.info("Starting application")
    app.run(host=host,
            port=port,
            debug=debug_server,
            use_reloader=use_reloader,
            threaded=use_threading)