Example #1
def main():
    parser = cli_parser()
    args = parser.parse_args()

    debug_smqtk = args.debug_smqtk or args.verbose
    debug_server = args.debug_server or args.verbose

    bin_utils.initialize_logging(logging.getLogger("__main__"),
                                 logging.INFO - (10 * debug_smqtk))
    bin_utils.initialize_logging(logging.getLogger("smqtk"),
                                 logging.INFO - (10 * debug_smqtk))
    bin_utils.initialize_logging(logging.getLogger("werkzeug"),
                                 logging.WARN - (20 * debug_server))
    log = logging.getLogger(__name__)

    web_applications = smqtk.web.get_web_applications()

    if args.list:
        log.info("")
        log.info("Available applications:")
        log.info("")
        for l in web_applications:
            log.info("\t" + l)
        log.info("")
        exit(0)

    application_name = args.application

    if application_name is None:
        log.error("No application name given!")
        exit(1)
    elif application_name not in web_applications:
        log.error("Invalid application label '%s'", application_name)
        exit(1)

    app_class = web_applications[application_name]

    config = bin_utils.utility_main_helper(app_class.get_default_config,
                                           args,
                                           skip_logging_init=True)

    host = args.host
    port = args.port and int(args.port)
    use_reloader = args.reload
    use_threading = args.threaded
    use_basic_auth = args.use_basic_auth

    # noinspection PyUnresolvedReferences
    app = app_class.from_config(config)
    if use_basic_auth:
        app.config["BASIC_AUTH_FORCE"] = True
        BasicAuth(app)
    app.config['DEBUG'] = debug_server

    app.run(host=host,
            port=port,
            debug=debug_server,
            use_reloader=use_reloader,
            threaded=use_threading)
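The log-level arithmetic above works because ``bool`` subclasses ``int`` in Python, so a truthy debug flag lowers the effective level by exactly one step (``logging.DEBUG`` is 10, ``INFO`` is 20, ``WARN`` is 30). A minimal standalone check:

import logging

assert logging.INFO - (10 * True) == logging.DEBUG    # debug flag set
assert logging.INFO - (10 * False) == logging.INFO    # debug flag unset
assert logging.WARN - (20 * True) == logging.DEBUG    # server debug set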
Example #2
def main():
    description = """
    Utility script to transform a set of descriptors, specified by UUID, with
    matching class labels, to a test file usable by libSVM utilities for
    train/test experiments.

    The input CSV file is assumed to be of the format:

        uuid,label
        ...

    This is the same as the format requested for other scripts like
    ``classifier_model_validation.py``.

    This is very useful for searching for -c and -g parameter values for a
    training sample of data using the ``tools/grid.py`` script, found in the
    libSVM source tree. For example:

        <smqtk_source>/TPL/libsvm-3.1-custom/tools/grid.py \\
            -log2c -5,15,2 -log2g 3,-15,-2 -v 5 -out libsvm.grid.out \\
            -png libsvm.grid.png -t 0 -w1 3.46713615023 -w2 12.2613240418 \\
            output_of_this_script.txt
    """
    args, config = bin_utils.utility_main_helper(default_config, description,
                                                 extend_parser)
    log = logging.getLogger(__name__)

    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'], get_descriptor_index_impls())

    labels_filepath = args.f
    output_filepath = args.o

    # Run through labeled UUIDs in input file, getting the descriptor from the
    # configured index, applying the appropriate integer label and then writing
    # the formatted line out to the output file.
    input_uuid_labels = csv.reader(open(labels_filepath))

    with open(output_filepath, 'w') as ofile:
        label2int = {}
        next_int = 1
        uuids, labels = zip(*input_uuid_labels)

        log.info("Scanning input descriptors and labels")
        for i, (l, d) in enumerate(
                itertools.izip(labels,
                               descriptor_index.get_many_descriptors(uuids))):
            log.debug("%d %s", i, d.uuid())
            if l not in label2int:
                label2int[l] = next_int
                next_int += 1
            ofile.write("%d " % label2int[l] + ' '.join([
                "%d:%.12f" % (j + 1, f)
                for j, f in enumerate(d.vector()) if f != 0.0
            ]) + '\n')

    log.info("Integer label association:")
    for i, l in sorted((i, l) for l, i in label2int.iteritems()):
        log.info('\t%d :: %s', i, l)
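For reference, the write loop above emits standard sparse libSVM lines: a 1-based feature index for each non-zero component. A toy illustration with hypothetical values:

label2int = {"cat": 1}
vec = [0.0, 0.25, 0.5]
line = "%d " % label2int["cat"] + ' '.join(
    "%d:%.12f" % (j + 1, f) for j, f in enumerate(vec) if f != 0.0)
# line == "1 2:0.250000000000 3:0.500000000000"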
Example #3
def main():
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    uuids_list_filepath = config['uuids_list_filepath']

    log.info("Initializing ITQ functor")
    #: :type: smqtk.algorithms.nn_index.lsh.functors.itq.ItqFunctor
    functor = ItqFunctor.from_config(config['itq_config'])

    log.info("Initializing DescriptorIndex [type=%s]",
             config['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['descriptor_index'],
        get_descriptor_index_impls(),
    )

    if uuids_list_filepath and os.path.isfile(uuids_list_filepath):
        def uuids_iter():
            with open(uuids_list_filepath) as f:
                for l in f:
                    yield l.strip()
        log.info("Loading UUIDs list from file: %s", uuids_list_filepath)
        d_iter = descriptor_index.get_many_descriptors(uuids_iter())
    else:
        log.info("Using UUIDs from loaded DescriptorIndex (count=%d)",
                 len(descriptor_index))
        d_iter = descriptor_index

    log.info("Fitting ITQ model")
    functor.fit(d_iter)
    log.info("Done")
Example #4
def main():
    args = cli_parser().parse_args()
    config = utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    output_filepath = args.output_map
    if not output_filepath:
        raise ValueError("No path given for output map file (pickle).")

    #: :type: smqtk.representation.DescriptorIndex
    index = from_plugin_config(config['descriptor_index'],
                               get_descriptor_index_impls())
    mbkm = MiniBatchKMeans(verbose=args.verbose,
                           compute_labels=False,
                           **config['minibatch_kmeans_params'])
    initial_fit_size = int(config['initial_fit_size'])

    d_classes = mb_kmeans_build_apply(index, mbkm, initial_fit_size)

    log.info("Saving KMeans centroids to: %s",
             config['centroids_output_filepath_npy'])
    numpy.save(config['centroids_output_filepath_npy'], mbkm.cluster_centers_)

    log.info("Saving result classification map to: %s", output_filepath)
    safe_create_dir(os.path.dirname(output_filepath))
    with open(output_filepath, 'w') as f:
        cPickle.dump(d_classes, f, -1)

    log.info("Done")
Example #6
def main():
    args = cli_parser().parse_args()
    config = utility_main_helper(default_config, args)
    l = logging.getLogger(__name__)

    completed_files_fp = args.completed_files
    filelist_fp = args.file_list
    batch_size = args.batch_size
    check_image = args.check_image

    # Input checking
    if not filelist_fp:
        l.error("No file-list file specified")
        exit(102)
    elif not os.path.isfile(filelist_fp):
        l.error("Invalid file list path: %s", filelist_fp)
        exit(103)

    if not completed_files_fp:
        l.error("No complete files output specified")
        exit(104)

    if batch_size < 0:
        l.error("Batch size must be >= 0.")
        exit(105)

    run_file_list(config, filelist_fp, completed_files_fp, batch_size,
                  check_image)
Example #7
def main():
    description = """
    Utility for fetching remotely stored image paths from the JPL Solr index.

    Files will be transferred with their entire containing directories. For
    example, if the file was stored in "/data/things/image.png" remotely, it
    will be transferred locally to "<output_dir>/data/things/image.png".

    Assumptions:
        - JPL MEMEX Solr index key structure
            - `id` == "file:<abs-filepath>"
            - `mainType` is the first component of the MIMETYPE
            - `indexedAt` timestamp
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    paths_file = args.paths_file
    after_time = args.after_time
    before_time = args.before_time

    #
    # Check dir/file locations
    #
    if paths_file is None:
        raise ValueError("Need a file path to to output transferred file "
                         "paths!")

    file_utils.safe_create_dir(os.path.dirname(paths_file))

    #
    # Start collection
    #
    remote_paths = solr_image_paths(
        config['solr_address'],
        after_time or '*', before_time or '*',
        config['solr_username'], config['solr_password'],
        config['batch_size']
    )

    log.info("Writing file paths")
    s = [0] * 7  # progress-reporting state list consumed by report_progress
    with open(paths_file, 'w') as of:
        for rp in remote_paths:
            of.write(rp + '\n')
            bin_utils.report_progress(log.info, s, 1.)
    # Final report
    s[1] -= 1
    bin_utils.report_progress(log.info, s, 0)
Example #9
def main():
    parser = cli_parser()
    args = parser.parse_args()

    # Default config options for this util are technically valid for running;
    # it's just a bad authkey.
    config = bin_utils.utility_main_helper(default_config, args,
                                           default_config_valid=True)

    port = int(config['port'])
    authkey = str(config['authkey'])

    mgr = ProxyManager(('', port), authkey)
    mgr.get_server().serve_forever()
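``ProxyManager`` follows the ``multiprocessing.managers.BaseManager`` server pattern, so a client would presumably attach using the same address/authkey pair. A hypothetical sketch (the import path, host, port, and authkey below are all assumptions):

from smqtk.utils.proxy_manager import ProxyManager  # assumed import path

client = ProxyManager(('localhost', 5000), 'changeme')  # hypothetical values
client.connect()  # attach to the serve_forever() process started above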
Example #10
def main():
    args = get_cli_parser().parse_args()
    config = utility_main_helper(get_default_config, args)

    log = logging.getLogger(__name__)
    log.debug("Showing debug messages.")

    iqr_state_fp = args.iqr_state

    if not os.path.isfile(iqr_state_fp):
        log.error("IQR Session info JSON filepath was invalid")
        exit(102)

    train_classifier_iqr(config, iqr_state_fp)
Example #11
def main():
    description = """
    Utility for fetching remotely stored image paths from the JPL Solr index.

    Files will be transferred with their entire containing directories. For
    example, if the file was stored in "/data/things/image.png" remotely, it
    will be transferred locally to "<output_dir>/data/things/image.png".

    Assumptions:
        - JPL MEMEX Solr index key structure
            - `id` == "file:<abs-filepath>"
            - `mainType` is the first component of the MIMETYPE
            - `indexedAt` timestamp
    """
    args, config = bin_utils.utility_main_helper(default_config, description,
                                                 extend_parser)
    log = logging.getLogger(__name__)

    paths_file = args.paths_file
    after_time = args.after_time
    before_time = args.before_time

    #
    # Check dir/file locations
    #
    if paths_file is None:
        raise ValueError("Need a file path to to output transferred file "
                         "paths!")

    file_utils.safe_create_dir(os.path.dirname(paths_file))

    #
    # Start collection
    #
    remote_paths = solr_image_paths(
        config['solr_address'],
        after_time or '*', before_time or '*',
        config['solr_username'], config['solr_password'],
        config['batch_size']
    )

    log.info("Writing file paths")
    s = [0] * 7
    with open(paths_file, 'w') as of:
        for rp in remote_paths:
            of.write(rp + '\n')
            bin_utils.report_progress(log.info, s, 1.)
    # Final report
    s[1] -= 1
    bin_utils.report_progress(log.info, s, 0)
Example #12
def main():
    # Print help and exit if no arguments were passed
    if len(sys.argv) == 1:
        get_cli_parser().print_help()
        sys.exit(1)

    args = get_cli_parser().parse_args()
    config = utility_main_helper(get_default_config, args)

    log = logging.getLogger(__name__)
    log.debug('Showing debug messages.')

    #: :type: smqtk.representation.DescriptorIndex
    descriptor_set = plugin.from_plugin_config(
        config['plugins']['descriptor_set'], get_descriptor_index_impls()
    )
    #: :type: smqtk.algorithms.NearestNeighborsIndex
    nearest_neighbor_index = plugin.from_plugin_config(
        config['plugins']['nn_index'], get_nn_index_impls()
    )

    # noinspection PyShadowingNames
    def nearest_neighbors(descriptor, n):
        if n == 0:
            n = len(nearest_neighbor_index)

        neighbors, distances = nearest_neighbor_index.nn(descriptor, n)
        # Strip first result (itself) and create list of (uuid, distance)
        return list(zip([x.uuid() for x in neighbors[1:]], distances[1:]))

    if args.uuid_list is not None and not os.path.exists(args.uuid_list):
        log.error('Invalid file list path: %s', args.uuid_list)
        exit(103)
    elif args.num < 0:
        log.error('Number of nearest neighbors must be >= 0')
        exit(105)

    if args.uuid_list is not None:
        with open(args.uuid_list, 'r') as infile:
            for line in infile:
                descriptor = descriptor_set.get_descriptor(line.strip())
                print(descriptor.uuid())
                for neighbor in nearest_neighbors(descriptor, args.num):
                    print('%s,%f' % neighbor)
    else:
        for (uuid, descriptor) in descriptor_set.iteritems():
            print(uuid)
            for neighbor in nearest_neighbors(descriptor, args.num):
                print('%s,%f' % neighbor)
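The self-exclusion in ``nearest_neighbors`` relies on the query descriptor coming back as its own first neighbor at distance zero. The slicing in isolation, with hypothetical stand-ins for the values ``nn()`` returns:

neighbor_uuids = ['query-uuid', 'uuid-a', 'uuid-b']  # query is its own 1st hit
distances = [0.0, 0.3, 0.7]
pairs = list(zip(neighbor_uuids[1:], distances[1:]))
# pairs == [('uuid-a', 0.3), ('uuid-b', 0.7)]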
Example #13
def main():
    parser = cli_parser()
    args = parser.parse_args()

    # Default config options for this util are technically valid for running;
    # it's just a bad authkey.
    config = bin_utils.utility_main_helper(default_config,
                                           args,
                                           default_config_valid=True)

    port = int(config['port'])
    authkey = str(config['authkey'])

    mgr = ProxyManager(('', port), authkey)
    mgr.get_server().serve_forever()
Example #14
def main():
    description = """
    Tool for training the ITQ functor algorithm's model on descriptors in an
    index.

    By default, we use all descriptors in the configured index
    (``uuids_list_filepath`` is not given a value).

    The ``uuids_list_filepath`` configuration property is optional and should
    be used to specify a sub-set of descriptors in the configured index to
    train on. This only works if the stored descriptors' UUID is a type of
    string.
    """
    args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    uuids_list_filepath = config['uuids_list_filepath']

    log.info("Initializing ITQ functor")
    #: :type: smqtk.algorithms.nn_index.lsh.functors.itq.ItqFunctor
    functor = ItqFunctor.from_config(config['itq_config'])

    log.info("Initializing DescriptorIndex [type=%s]",
             config['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['descriptor_index'],
        get_descriptor_index_impls(),
    )

    if uuids_list_filepath and os.path.isfile(uuids_list_filepath):

        def uuids_iter():
            with open(uuids_list_filepath) as f:
                for l in f:
                    yield l.strip()

        log.info("Loading UUIDs list from file: %s", uuids_list_filepath)
        d_iter = descriptor_index.get_many_descriptors(uuids_iter())
    else:
        log.info("Using UUIDs from loaded DescriptorIndex (count=%d)",
                 len(descriptor_index))
        d_iter = descriptor_index

    log.info("Fitting ITQ model")
    functor.fit(d_iter)
    log.info("Done")
Example #15
def main():
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    api_root = config['tool']['girder_api_root']
    api_key = config['tool']['api_key']
    api_query_batch = config['tool']['api_query_batch']
    insert_batch_size = config['tool']['dataset_insert_batch_size']

    # Collect N folder/item/file references on CL and any files referenced.
    #: :type: list[str]
    ids_folder = args.folder
    #: :type: list[str]
    ids_item = args.item
    #: :type: list[str]
    ids_file = args.file

    if args.folder_list:
        with open(args.folder_list) as f:
            ids_folder.extend([fid.strip() for fid in f])
    if args.item_list:
        with open(args.item_list) as f:
            ids_item.extend([iid.strip() for iid in f])
    if args.file_list:
        with open(args.file_list) as f:
            ids_file.extend([fid.strip() for fid in f])

    #: :type: smqtk.representation.DataSet
    data_set = plugin.from_plugin_config(config['plugins']['data_set'],
                                         get_data_set_impls())

    batch = collections.deque()
    rps = [0] * 7
    for e in find_girder_files(api_root, ids_folder, ids_item, ids_file,
                               api_key, api_query_batch):
        batch.append(e)
        if insert_batch_size and len(batch) >= insert_batch_size:
            data_set.add_data(*batch)
            batch.clear()
        bin_utils.report_progress(log.info, rps, 1.0)

    if batch:
        data_set.add_data(*batch)

    log.info('Done')
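The ``deque`` batching above flushes every ``insert_batch_size`` elements and once more at the end for any remainder. The pattern in isolation, with stand-ins for the Girder iterator and the data set (batch size hypothetical):

import collections

batch = collections.deque()
for e in range(10):                     # stand-in for find_girder_files(...)
    batch.append(e)
    if len(batch) >= 4:                 # hypothetical insert_batch_size
        print("add_data", list(batch))  # stand-in for data_set.add_data(*batch)
        batch.clear()
if batch:
    print("add_data", list(batch))      # flush the final partial batch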
Example #16
def main():
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    api_root = config['tool']['girder_api_root']
    api_key = config['tool']['api_key']
    api_query_batch = config['tool']['api_query_batch']
    insert_batch_size = config['tool']['dataset_insert_batch_size']

    # Collect N folder/item/file references on CL and any files referenced.
    #: :type: list[str]
    ids_folder = args.folder
    #: :type: list[str]
    ids_item = args.item
    #: :type: list[str]
    ids_file = args.file

    if args.folder_list:
        with open(args.folder_list) as f:
            ids_folder.extend([fid.strip() for fid in f])
    if args.item_list:
        with open(args.item_list) as f:
            ids_item.extend([iid.strip() for iid in f])
    if args.file_list:
        with open(args.file_list) as f:
            ids_file.extend([fid.strip() for fid in f])

    #: :type: smqtk.representation.DataSet
    data_set = plugin.from_plugin_config(config['plugins']['data_set'],
                                         get_data_set_impls())

    batch = collections.deque()
    rps = [0]*7
    for e in find_girder_files(api_root, ids_folder, ids_item, ids_file,
                               api_key, api_query_batch):
        batch.append(e)
        if insert_batch_size and len(batch) >= insert_batch_size:
            data_set.add_data(*batch)
            batch.clear()
        bin_utils.report_progress(log.info, rps, 1.0)

    if batch:
        data_set.add_data(*batch)

    log.info('Done')
Example #17
def main():
    description = """
    Tool for training the ITQ functor algorithm's model on descriptors in an
    index.

    By default, we use all descriptors in the configured index
    (``uuids_list_filepath`` is not given a value).

    The ``uuids_list_filepath`` configuration property is optional and should
    be used to specify a sub-set of descriptors in the configured index to
    train on. This only works if the stored descriptors' UUID is a type of
    string.
    """
    args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    uuids_list_filepath = config['uuids_list_filepath']

    log.info("Initializing ITQ functor")
    #: :type: smqtk.algorithms.nn_index.lsh.functors.itq.ItqFunctor
    functor = ItqFunctor.from_config(config['itq_config'])

    log.info("Initializing DescriptorIndex [type=%s]",
             config['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['descriptor_index'],
        get_descriptor_index_impls(),
    )

    if uuids_list_filepath and os.path.isfile(uuids_list_filepath):
        def uuids_iter():
            with open(uuids_list_filepath) as f:
                for l in f:
                    yield l.strip()
        log.info("Loading UUIDs list from file: %s", uuids_list_filepath)
        d_iter = descriptor_index.get_many_descriptors(uuids_iter())
    else:
        log.info("Using UUIDs from loaded DescriptorIndex (count=%d)",
                 len(descriptor_index))
        d_iter = descriptor_index

    log.info("Fitting ITQ model")
    functor.fit(d_iter)
    log.info("Done")
Example #18
def main():
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    labels_filepath = args.f
    output_filepath = args.o

    # Run through labeled UUIDs in input file, getting the descriptor from the
    # configured index, applying the appropriate integer label and then writing
    # the formatted line out to the output file.
    input_uuid_labels = csv.reader(open(labels_filepath))

    with open(output_filepath, 'w') as ofile:
        label2int = {}
        next_int = 1
        uuids, labels = list(zip(*input_uuid_labels))

        log.info("Scanning input descriptors and labels")
        for i, (l, d) in enumerate(
                    zip(labels, descriptor_index.get_many_descriptors(uuids))):
            log.debug("%d %s", i, d.uuid())
            if l not in label2int:
                label2int[l] = next_int
                next_int += 1
            ofile.write(
                "%d " % label2int[l] +
                ' '.join(["%d:%.12f" % (j+1, f)
                          for j, f in enumerate(d.vector())
                          if f != 0.0]) +
                '\n'
            )

    log.info("Integer label association:")
    for i, l in sorted((i, l) for l, i in six.iteritems(label2int)):
        log.info('\t%d :: %s', i, l)
Example #19
def main():
    description = """
    Descriptor computation helper utility. Checks data content type with
    respect to the configured descriptor generator and skips content that does
    not match the accepted types. Optionally, we can also filter out image
    content whose bytes we cannot load via ``PIL.Image.open``.
    """

    args, config = utility_main_helper(default_config, description,
                                       extend_parser)
    l = logging.getLogger(__name__)

    completed_files_fp = args.completed_files
    filelist_fp = args.file_list
    batch_size = args.batch_size
    check_image = args.check_image

    # Input checking
    if not filelist_fp:
        l.error("No file-list file specified")
        exit(102)
    elif not os.path.isfile(filelist_fp):
        l.error("Invalid file list path: %s", filelist_fp)
        exit(103)

    if not completed_files_fp:
        l.error("No complete files output specified")
        exit(104)

    if batch_size < 0:
        l.error("Batch size must be >= 0.")
        exit(105)

    run_file_list(
        config,
        filelist_fp,
        completed_files_fp,
        batch_size,
        check_image
    )
Example #20
def main():
    parser = cli_parser()
    args = parser.parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    log.debug("Script arguments:\n%s" % args)

    def iter_input_elements():
        for f in args.input_files:
            f = osp.expanduser(f)
            if osp.isfile(f):
                yield DataFileElement(f)
            else:
                log.debug("Expanding glob: %s" % f)
                for g in glob.glob(f):
                    yield DataFileElement(g)

    log.info("Adding elements to data set")
    #: :type: smqtk.representation.DataSet
    ds = plugin.from_plugin_config(config['data_set'], get_data_set_impls())
    ds.add_data(*iter_input_elements())
Example #21
def main():
    parser = cli_parser()
    args = parser.parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    output_filepath = args.output_filepath
    overwrite = args.overwrite

    if not args.input_file:
        log.error("Failed to provide an input file path")
        exit(1)
    elif not os.path.isfile(args.input_file):
        log.error("Given path does not point to a file.")
        exit(1)

    input_filepath = args.input_file
    data_element = DataFileElement(input_filepath)

    factory = DescriptorElementFactory.from_config(
        config['descriptor_factory'])
    #: :type: smqtk.algorithms.descriptor_generator.DescriptorGenerator
    cd = plugin.from_plugin_config(config['content_descriptor'],
                                   get_descriptor_generator_impls())
    descr_elem = cd.compute_descriptor(data_element, factory, overwrite)
    vec = descr_elem.vector()

    if vec is None:
        log.error("Failed to generate a descriptor vector for the input data!")
        exit(1)

    if output_filepath:
        numpy.save(output_filepath, vec)
    else:
        # Construct string, because numpy
        s = []
        # noinspection PyTypeChecker
        for f in vec:
            s.append('%15f' % f)
        print(' '.join(s))
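When an output path is given, the saved vector can be reloaded later with ``numpy.load``; note that ``numpy.save`` appends a ``.npy`` extension if one is absent (the path below is hypothetical):

import numpy
vec = numpy.load('descriptor.npy')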
Example #23
def main():
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'], get_descriptor_index_impls())

    labels_filepath = args.f
    output_filepath = args.o

    # Run through labeled UUIDs in input file, getting the descriptor from the
    # configured index, applying the appropriate integer label and then writing
    # the formatted line out to the output file.
    input_uuid_labels = csv.reader(open(labels_filepath))

    with open(output_filepath, 'w') as ofile:
        label2int = {}
        next_int = 1
        uuids, labels = zip(*input_uuid_labels)

        log.info("Scanning input descriptors and labels")
        for i, (l, d) in enumerate(
                itertools.izip(labels,
                               descriptor_index.get_many_descriptors(uuids))):
            log.debug("%d %s", i, d.uuid())
            if l not in label2int:
                label2int[l] = next_int
                next_int += 1
            ofile.write("%d " % label2int[l] + ' '.join([
                "%d:%.12f" % (j + 1, f)
                for j, f in enumerate(d.vector()) if f != 0.0
            ]) + '\n')

    log.info("Integer label association:")
    for i, l in sorted((i, l) for l, i in label2int.iteritems()):
        log.info('\t%d :: %s', i, l)
Example #24
def main():
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    uuids_list_filepath = config['uuids_list_filepath']

    log.info("Initializing ITQ functor")
    #: :type: smqtk.algorithms.nn_index.lsh.functors.itq.ItqFunctor
    functor = ItqFunctor.from_config(config['itq_config'])

    log.info("Initializing DescriptorIndex [type=%s]",
             config['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['descriptor_index'],
        get_descriptor_index_impls(),
    )

    if uuids_list_filepath and os.path.isfile(uuids_list_filepath):

        def uuids_iter():
            with open(uuids_list_filepath) as f:
                for l in f:
                    yield l.strip()

        log.info("Loading UUIDs list from file: %s", uuids_list_filepath)
        d_iter = descriptor_index.get_many_descriptors(uuids_iter())
    else:
        log.info("Using UUIDs from loaded DescriptorIndex (count=%d)",
                 len(descriptor_index))
        d_iter = descriptor_index

    log.info("Fitting ITQ model")
    functor.fit(d_iter)
    log.info("Done")
Example #25
def main():
    description = """
    Compute LSH hash codes based on the provided functor on specific
    descriptors from the configured index given a file-list of UUIDs.

    When using an input file-list of UUIDs, we require that the UUIDs of
    indexed descriptors be strings, or equality comparable to the UUIDs' string
    representation.

    This script can be used to live-update the ``hash2uuid_cache_filepath``
    model file for the ``LSHNearestNeighborIndex`` algorithm, as the output
    dictionary format is the same as that used by the implementation.
    """
    args, config = bin_utils.utility_main_helper(default_config, description,
                                                 extend_parser)
    log = logging.getLogger(__name__)

    #
    # Load configuration contents
    #
    uuid_list_filepath = args.uuids_list
    hash2uuids_input_filepath = args.input_hash2uuids
    hash2uuids_output_filepath = args.output_hash2uuids
    report_interval = config['utility']['report_interval']
    use_multiprocessing = config['utility']['use_multiprocessing']
    pickle_protocol = config['utility']['pickle_protocol']

    #
    # Checking parameters
    #
    if not hash2uuids_output_filepath:
        raise ValueError("No hash2uuids map output file provided!")

    #
    # Loading stuff
    #
    log.info("Loading descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )
    log.info("Loading LSH functor")
    #: :type: smqtk.algorithms.LshFunctor
    lsh_functor = plugin.from_plugin_config(
        config['plugins']['lsh_functor'],
        get_lsh_functor_impls()
    )

    def iter_uuids():
        if uuid_list_filepath:
            log.info("Using UUIDs list file")
            with open(uuid_list_filepath) as f:
                for l in f:
                    yield l.strip()
        else:
            log.info("Using all UUIDs resent in descriptor index")
            for k in descriptor_index.iterkeys():
                yield k

    # load map if it exists, else start with empty dictionary
    if hash2uuids_input_filepath and os.path.isfile(hash2uuids_input_filepath):
        log.info("Loading hash2uuids mapping")
        with open(hash2uuids_input_filepath, 'rb') as f:
            hash2uuids = cPickle.load(f)
    else:
        log.info("Creating new hash2uuids mapping for output")
        hash2uuids = {}

    #
    # Compute codes
    #
    log.info("Starting hash code computation")
    compute_hash_codes(
        uuids_for_processing(iter_uuids(), hash2uuids),
        descriptor_index,
        lsh_functor,
        hash2uuids,
        report_interval=report_interval,
        use_mp=use_multiprocessing,
    )

    #
    # Output results
    #
    tmp_output_filepath = hash2uuids_output_filepath + '.WRITING'
    log.info("Writing hash-to-uuids map to disk: %s", tmp_output_filepath)
    file_utils.safe_create_dir(os.path.dirname(hash2uuids_output_filepath))
    with open(tmp_output_filepath, 'wb') as f:
        cPickle.dump(hash2uuids, f, pickle_protocol)
    log.info("Moving on top of input: %s", hash2uuids_output_filepath)
    os.rename(tmp_output_filepath, hash2uuids_output_filepath)
    log.info("Done")
Example #26
def main():
    parser = cli_parser()
    args = parser.parse_args()

    debug_smqtk = args.debug_smqtk or args.verbose
    debug_server = args.debug_server or args.verbose

    bin_utils.initialize_logging(logging.getLogger("__main__"),
                                 logging.INFO - (10 * debug_smqtk))
    bin_utils.initialize_logging(logging.getLogger("smqtk"),
                                 logging.INFO - (10*debug_smqtk))
    bin_utils.initialize_logging(logging.getLogger("werkzeug"),
                                 logging.WARN - (20*debug_server))
    log = logging.getLogger(__name__)

    web_applications = smqtk.web.get_web_applications()

    if args.list:
        log.info("")
        log.info("Available applications:")
        log.info("")
        for l, cls in six.iteritems(web_applications):
            log.info("\t" + l)
            if debug_smqtk:
                log.info('\t' + ('^' * len(l)) + '\n' +
                         cls.__doc__ + '\n' +
                         ('*' * 80) + '\n')
        log.info("")
        exit(0)

    application_name = args.application

    if application_name is None:
        log.error("No application name given!")
        exit(1)
    elif application_name not in web_applications:
        log.error("Invalid application label '%s'", application_name)
        exit(1)

    #: :type: smqtk.web.SmqtkWebApp
    app_class = web_applications[application_name]

    config = bin_utils.utility_main_helper(app_class.get_default_config, args,
                                           skip_logging_init=True)

    host = args.host
    port = args.port and int(args.port)
    use_reloader = args.reload
    use_threading = args.threaded
    use_basic_auth = args.use_basic_auth

    # noinspection PyUnresolvedReferences
    #: :type: smqtk.web.SmqtkWebApp
    app = app_class.from_config(config)
    if use_basic_auth:
        app.config["BASIC_AUTH_FORCE"] = True
        BasicAuth(app)
    app.config['DEBUG'] = debug_server

    log.info("Starting application")
    app.run(host=host, port=port, debug=debug_server, use_reloader=use_reloader,
            threaded=use_threading)
Example #27
def main():
    description = """
    Utility for validating a given classifier implementation's model against
    some labeled testing data, outputting PR and ROC curve plots with
    area-under-curve score values.

    This utility can optionally be used to train a supervised classifier model
    if the given classifier model configuration does not exist and a second
    CSV file listing labeled training data is provided. Training will be
    attempted if ``train`` is set to true. If training is performed, we exit
    after training completes. A ``SupervisedClassifier`` sub-classing
    implementation must be configured.

    We expect the test and train CSV files in the column format:

        ...
        <UUID>,<label>
        ...

    The UUID is of the descriptor to which the label applies. The label may be
    any arbitrary string value, but all labels must be consistent in
    application.

    Some metrics presented assume the highest confidence class as the single
    predicted class for an element:

        - confusion matrix

    The output UUID confusion matrix is a JSON dictionary where the top-level
    keys are the true labels, and the inner dictionary is the mapping of
    predicted labels to the UUIDs of the classifications/descriptors that
    yielded the prediction. Again, this is based on the maximum probability
    label for a classification result (T=0.5).
    """
    args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(config['plugins']['classifier'],
                                           get_classifier_impls())
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'], get_descriptor_index_impls())

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of label to the DescriptorElement instances described
    # by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate through UUIDs in specified file """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch descriptors from configured index """
        uuid, truth_label = r
        return truth_label, descriptor_index.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr,
        iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train classifier if the one given has a ``train`` method and training
    # was enabled.
    if do_train:
        if isinstance(classifier, SupervisedClassifier):
            log.info("Training classifier model")
            classifier.train(tlabel2descriptors)
            exit(0)
        else:
            ValueError("Configured classifier is not a SupervisedClassifier "
                       "type and does not support training.")

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    for tlabel, descriptors in tlabel2descriptors.iteritems():
        tlabel2classifications[tlabel] = \
            set(classifier.classify_async(
                descriptors, classification_factory,
                use_multiprocessing=True,
                procs=config['parallelism']['classification_cores'],
                ri=1.0,
            ).values())
    log.info("Truth label counts:")
    for l in sorted(tlabel2classifications):
        log.info("  %s :: %d", l, len(tlabel2classifications[l]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # CM of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are UUID
        # predicted labels.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[str, set | list]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists
            for plabel in uuid_cm[tlabel]:
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr, plot_ci,
                       plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc, plot_ci,
                        plot_ci_alpha)
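Per the description, the saved UUID confusion matrix JSON nests true labels over predicted labels over UUID lists. A hypothetical instance (labels and UUIDs made up):

uuid_cm = {
    "cat": {"cat": ["uuid-1", "uuid-2"], "dog": ["uuid-3"]},
    "dog": {"dog": ["uuid-4"]},
}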
Example #28
def main():
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        config['plugins']['classifier'],
        get_classifier_impls()
    )
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory']
    )
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of label to the DescriptorElement instances described
    # by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate through UUIDs in specified file """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch descriptors from configured index """
        uuid, truth_label = r
        return truth_label, descriptor_index.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr, iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train classifier if the one given has a ``train`` method and training
    # was enabled.
    if do_train:
        if isinstance(classifier, SupervisedClassifier):
            log.info("Training classifier model")
            classifier.train(tlabel2descriptors)
            exit(0)
        else:
            ValueError("Configured classifier is not a SupervisedClassifier "
                       "type and does not support training.")

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    for tlabel, descriptors in six.iteritems(tlabel2descriptors):
        tlabel2classifications[tlabel] = \
            set(classifier.classify_async(
                descriptors, classification_factory,
                use_multiprocessing=True,
                procs=config['parallelism']['classification_cores'],
                ri=1.0,
            ).values())
    log.info("Truth label counts:")
    for l in sorted(tlabel2classifications):
        log.info("  %s :: %d", l, len(tlabel2classifications[l]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # Confusion Matrix of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are UUID
        # predicted labels.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[collections.Hashable, set]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists
            for plabel in uuid_cm[tlabel]:
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr,
                       plot_ci, plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc,
                        plot_ci, plot_ci_alpha)
Example #29
def main():
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    # - parallel_map UUIDs to load from the configured index
    # - classify iterated descriptors

    uuids_list_filepath = args.uuids_list
    output_csv_filepath = args.csv_data
    output_csv_header_filepath = args.csv_header
    classify_overwrite = config['utility']['classify_overwrite']

    p_use_multiprocessing = \
        config['utility']['parallel']['use_multiprocessing']
    p_index_extraction_cores = \
        config['utility']['parallel']['index_extraction_cores']
    p_classification_cores = \
        config['utility']['parallel']['classification_cores']

    if not uuids_list_filepath:
        raise ValueError("No uuids_list_filepath specified.")
    elif not os.path.isfile(uuids_list_filepath):
        raise ValueError("Given uuids_list_filepath did not point to a file.")
    if output_csv_header_filepath is None:
        raise ValueError("Need a path to save CSV header labels")
    if output_csv_filepath is None:
        raise ValueError("Need a path to save CSV data.")

    #
    # Initialize configured plugins
    #

    log.info("Initializing descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'], get_descriptor_index_impls())

    log.info("Initializing classification factory")
    c_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])

    log.info("Initializing classifier")
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(config['plugins']['classifier'],
                                           get_classifier_impls())

    #
    # Setup/Process
    #
    def iter_uuids():
        with open(uuids_list_filepath) as f:
            for l in f:
                yield l.strip()

    def descr_for_uuid(uuid):
        """
        :type uuid: collections.Hashable
        :rtype: smqtk.representation.DescriptorElement
        """
        return descriptor_index.get_descriptor(uuid)

    def classify_descr(d):
        """
        :type d: smqtk.representation.DescriptorElement
        :rtype: smqtk.representation.ClassificationElement
        """
        return classifier.classify(d, c_factory, classify_overwrite)

    log.info("Initializing uuid-to-descriptor parallel map")
    #: :type: collections.Iterable[smqtk.representation.DescriptorElement]
    element_iter = parallel.parallel_map(
        descr_for_uuid,
        iter_uuids(),
        use_multiprocessing=p_use_multiprocessing,
        cores=p_index_extraction_cores,
        name="descr_for_uuid",
    )

    log.info("Initializing descriptor-to-classification parallel map")
    #: :type: collections.Iterable[smqtk.representation.ClassificationElement]
    classification_iter = parallel.parallel_map(
        classify_descr,
        element_iter,
        use_multiprocessing=p_use_multiprocessing,
        cores=p_classification_cores,
        name='classify_descr',
    )

    #
    # Write/Output files
    #

    c_labels = classifier.get_labels()

    def make_row(c):
        """
        :type c: smqtk.representation.ClassificationElement
        """
        c_m = c.get_classification()
        return [c.uuid] + [c_m[l] for l in c_labels]

    # column labels file
    log.info("Writing CSV column header file: %s", output_csv_header_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_header_filepath))
    with open(output_csv_header_filepath, 'wb') as f_csv:
        w = csv.writer(f_csv)
        w.writerow(['uuid'] + c_labels)

    # CSV file
    log.info("Writing CSV data file: %s", output_csv_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_filepath))
    r_state = [0] * 7
    with open(output_csv_filepath, 'wb') as f_csv:
        w = csv.writer(f_csv)
        for c in classification_iter:
            w.writerow(make_row(c))
            bin_utils.report_progress(log.info, r_state, 1.0)

    # Final report
    r_state[1] -= 1
    bin_utils.report_progress(log.info, r_state, 0)

    log.info("Done")
Example #30
def main():
    description = """
    Utility for validating a given classifier implementation's model against
    some labeled testing data, outputting PR and ROC curve plots with
    area-under-curve score values.

    This utility can optionally be used to train a supervised classifier model
    if the given classifier model configuration does not exist and a second
    CSV file listing labeled training data is provided. Training will be
    attempted if ``train`` is set to true. If training is performed, we exit
    after training completes. A ``SupervisedClassifier`` sub-classing
    implementation must be configured.

    We expect the test and train CSV files in the column format:

        ...
        <UUID>,<label>
        ...

    The UUID is of the descriptor to which the label applies. The label may be
    any arbitrary string value, but all labels must be consistent in
    application.

    Some metrics presented assume the highest confidence class as the single
    predicted class for an element:

        - confusion matrix

    The output UUID confusion matrix is a JSON dictionary where the top-level
    keys are the true labels, and the inner dictionary is the mapping of
    predicted labels to the UUIDs of the classifications/descriptors that
    yielded the prediction. Again, this is based on the maximum probability
    label for a classification result (T=0.5).
    """
    args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        config['plugins']['classifier'],
        get_classifier_impls()
    )
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory']
    )
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of label to the DescriptorElement instances described
    # by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate through UUIDs in specified file """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch descriptors from configured index """
        uuid, truth_label = r
        return truth_label, descriptor_index.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr, iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train classifier if the one given has a ``train`` method and training
    # was enabled.
    if do_train:
        if isinstance(classifier, SupervisedClassifier):
            log.info("Training classifier model")
            classifier.train(tlabel2descriptors)
            exit(0)
        else:
            ValueError("Configured classifier is not a SupervisedClassifier "
                       "type and does not support training.")

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    for tlabel, descriptors in tlabel2descriptors.iteritems():
        tlabel2classifications[tlabel] = \
            set(classifier.classify_async(
                descriptors, classification_factory,
                use_multiprocessing=True,
                procs=config['parallelism']['classification_cores'],
                ri=1.0,
            ).values())
    log.info("Truth label counts:")
    for l in sorted(tlabel2classifications):
        log.info("  %s :: %d", l, len(tlabel2classifications[l]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # CM of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are UUID
        # predicted labels.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[str, set | list]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists
            for plabel in uuid_cm[tlabel]:
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))


    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr,
                       plot_ci, plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc,
                        plot_ci, plot_ci_alpha)
def main():
    description = """
    Script for asynchronously computing classifications for DescriptorElements
    in a DescriptorIndex specified via a list of UUIDs. Results are output to a
    CSV file in the format:

        uuid, label1_confidence, label2_confidence, ...

    CSV column labels are output to the given CSV header file path. Label
    columns will be in the order reported by the classifier implementation's
    ``get_labels`` method.

    Due to using an input file-list of UUIDs, we require that the UUIDs of
    indexed descriptors be strings, or equality comparable to the UUIDs' string
    representation.
    """

    args, config = bin_utils.utility_main_helper(
        default_config,
        description,
        extend_parser,
    )
    log = logging.getLogger(__name__)

    # - parallel_map UUIDs to load from the configured index
    # - classify iterated descriptors

    uuids_list_filepath = args.uuids_list
    output_csv_filepath = args.csv_data
    output_csv_header_filepath = args.csv_header
    classify_overwrite = config['utility']['classify_overwrite']

    p_use_multiprocessing = \
        config['utility']['parallel']['use_multiprocessing']
    p_index_extraction_cores = \
        config['utility']['parallel']['index_extraction_cores']
    p_classification_cores = \
        config['utility']['parallel']['classification_cores']

    if not uuids_list_filepath:
        raise ValueError("No uuids_list_filepath specified.")
    elif not os.path.isfile(uuids_list_filepath):
        raise ValueError("Given uuids_list_filepath did not point to a file.")
    if output_csv_header_filepath is None:
        raise ValueError("Need a path to save CSV header labels")
    if output_csv_filepath is None:
        raise ValueError("Need a path to save CSV data.")

    #
    # Initialize configured plugins
    #

    log.info("Initializing descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    log.info("Initializing classification factory")
    c_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory']
    )

    log.info("Initializing classifier")
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        config['plugins']['classifier'], get_classifier_impls()
    )

    #
    # Setup/Process
    #
    def iter_uuids():
        with open(uuids_list_filepath) as f:
            for l in f:
                yield l.strip()

    def descr_for_uuid(uuid):
        """
        :type uuid: collections.Hashable
        :rtype: smqtk.representation.DescriptorElement
        """
        return descriptor_index.get_descriptor(uuid)

    def classify_descr(d):
        """
        :type d: smqtk.representation.DescriptorElement
        :rtype: smqtk.representation.ClassificationElement
        """
        return classifier.classify(d, c_factory, classify_overwrite)

    log.info("Initializing uuid-to-descriptor parallel map")
    #: :type: collections.Iterable[smqtk.representation.DescriptorElement]
    element_iter = parallel.parallel_map(
        descr_for_uuid, iter_uuids(),
        use_multiprocessing=p_use_multiprocessing,
        cores=p_index_extraction_cores,
        name="descr_for_uuid",
    )

    log.info("Initializing descriptor-to-classification parallel map")
    #: :type: collections.Iterable[smqtk.representation.ClassificationElement]
    classification_iter = parallel.parallel_map(
        classify_descr, element_iter,
        use_multiprocessing=p_use_multiprocessing,
        cores=p_classification_cores,
        name='classify_descr',
    )

    #
    # Write/Output files
    #

    c_labels = classifier.get_labels()

    def make_row(c):
        """
        :type c: smqtk.representation.ClassificationElement
        """
        c_m = c.get_classification()
        return [c.uuid] + [c_m[l] for l in c_labels]

    # column labels file
    log.info("Writing CSV column header file: %s", output_csv_header_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_header_filepath))
    with open(output_csv_header_filepath, 'wb') as f_csv:
        w = csv.writer(f_csv)
        w.writerow(['uuid'] + c_labels)

    # CSV file
    log.info("Writing CSV data file: %s", output_csv_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_filepath))
    r_state = [0] * 7
    with open(output_csv_filepath, 'wb') as f_csv:
        w = csv.writer(f_csv)
        for c in classification_iter:
            w.writerow(make_row(c))
            bin_utils.report_progress(log.info, r_state, 1.0)

    # Final report
    r_state[1] -= 1
    bin_utils.report_progress(log.info, r_state, 0)

    log.info("Done")
def main():
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    # - parallel_map UUIDs to load from the configured index
    # - classify iterated descriptors

    uuids_list_filepath = args.uuids_list
    output_csv_filepath = args.csv_data
    output_csv_header_filepath = args.csv_header
    classify_overwrite = config['utility']['classify_overwrite']

    p_use_multiprocessing = \
        config['utility']['parallel']['use_multiprocessing']
    p_index_extraction_cores = \
        config['utility']['parallel']['index_extraction_cores']
    p_classification_cores = \
        config['utility']['parallel']['classification_cores']

    if not uuids_list_filepath:
        raise ValueError("No uuids_list_filepath specified.")
    elif not os.path.isfile(uuids_list_filepath):
        raise ValueError("Given uuids_list_filepath did not point to a file.")
    if output_csv_header_filepath is None:
        raise ValueError("Need a path to save CSV header labels")
    if output_csv_filepath is None:
        raise ValueError("Need a path to save CSV data.")

    #
    # Initialize configured plugins
    #

    log.info("Initializing descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    log.info("Initializing classification factory")
    c_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory']
    )

    log.info("Initializing classifier")
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(
        config['plugins']['classifier'], get_classifier_impls()
    )

    #
    # Setup/Process
    #
    def iter_uuids():
        with open(uuids_list_filepath) as f:
            for l in f:
                yield l.strip()

    def descr_for_uuid(uuid):
        """
        :type uuid: collections.Hashable
        :rtype: smqtk.representation.DescriptorElement
        """
        return descriptor_index.get_descriptor(uuid)

    def classify_descr(d):
        """
        :type d: smqtk.representation.DescriptorElement
        :rtype: smqtk.representation.ClassificationElement
        """
        return classifier.classify(d, c_factory, classify_overwrite)

    log.info("Initializing uuid-to-descriptor parallel map")
    #: :type: collections.Iterable[smqtk.representation.DescriptorElement]
    element_iter = parallel.parallel_map(
        descr_for_uuid, iter_uuids(),
        use_multiprocessing=p_use_multiprocessing,
        cores=p_index_extraction_cores,
        name="descr_for_uuid",
    )

    log.info("Initializing descriptor-to-classification parallel map")
    #: :type: collections.Iterable[smqtk.representation.ClassificationElement]
    classification_iter = parallel.parallel_map(
        classify_descr, element_iter,
        use_multiprocessing=p_use_multiprocessing,
        cores=p_classification_cores,
        name='classify_descr',
    )

    #
    # Write/Output files
    #

    c_labels = classifier.get_labels()

    def make_row(e):
        """
        :type e: smqtk.representation.ClassificationElement
        """
        c_m = e.get_classification()
        return [e.uuid] + [c_m[l] for l in c_labels]

    # column labels file
    log.info("Writing CSV column header file: %s", output_csv_header_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_header_filepath))
    with open(output_csv_header_filepath, 'wb') as f_csv:
        w = csv.writer(f_csv)
        w.writerow(['uuid'] + [str(cl) for cl in c_labels])

    # CSV file
    log.info("Writing CSV data file: %s", output_csv_filepath)
    file_utils.safe_create_dir(os.path.dirname(output_csv_filepath))
    r_state = [0] * 7
    with open(output_csv_filepath, 'wb') as f_csv:
        w = csv.writer(f_csv)
        for c in classification_iter:
            w.writerow(make_row(c))
            bin_utils.report_progress(log.info, r_state, 1.0)

    # Final report
    r_state[1] -= 1
    bin_utils.report_progress(log.info, r_state, 0)

    log.info("Done")
Example #33
0
def main():
    description = """
    Utility for fetching remotely stored image data from the CDR ElasticSearch
    instance.

    Files will be transferred into the configured directory with the format::

        <output_dir>/<index>/<_type>/<id>.<type_extension>

    Configuration Notes:

        image_types
            This is a list of image MIMETYPE suffixes to include when querying
            the ElasticSearch instance. If all types should be considered, this
            should be set to an empty list.

        stored_http_auth
            This is only used for stored-data URLs and only if both a username
            and password are given.

        elastic_search
            batch_size
                The number of query hits to fetch at a time from the instance.

    """
    args, config = utility_main_helper(default_config, description,
                                       extend_parser)
    log = logging.getLogger(__name__)

    report_size = args.report_size
    crawled_after = args.crawled_after
    inserted_after = args.inserted_after

    #
    # Check config properties
    #
    m = mimetypes.MimeTypes()
    # non-strict types (see use of ``guess_extension`` above)
    m_img_types = set(m.types_map_inv[0].keys() + m.types_map_inv[1].keys())
    if not isinstance(config['image_types'], list):
        raise ValueError("The 'image_types' property was not set to a list.")
    for t in config['image_types']:
        if ('image/' + t) not in m_img_types:
            raise ValueError("Image type '%s' is not a valid image MIMETYPE "
                             "sub-type." % t)

    if not report_size and args.output_dir is None:
        raise ValueError("Require an output directory!")
    if not report_size and args.file_list is None:
        raise ValueError("Require an output CSV file path!")

    #
    # Initialize ElasticSearch stuff
    #
    es_auth = None
    if config['elastic_search']['username'] and config['elastic_search']['password']:
        es_auth = (config['elastic_search']['username'],
                   config['elastic_search']['password'])

    es = elasticsearch.Elasticsearch(
        config['elastic_search']['instance_address'],
        http_auth=es_auth,
        use_ssl=True, verify_certs=True,
        ca_certs=certifi.where(),
    )

    #
    # Query and Run
    #
    http_auth = None
    if config['stored_http_auth']['name'] and config['stored_http_auth']['pass']:
        http_auth = (config['stored_http_auth']['name'],
                     config['stored_http_auth']['pass'])

    ts_re = re.compile(r'(\d{4})-(\d{2})-(\d{2})T(\d{2}):(\d{2}):(\d{2})Z')
    if crawled_after:
        m = ts_re.match(crawled_after)
        if m is None:
            raise ValueError("Given 'crawled-after' timestamp not in correct "
                             "format: '%s'" % crawled_after)
        crawled_after = datetime.datetime(*[int(e) for e in m.groups()])
    if inserted_after:
        m = ts_re.match(inserted_after)
        if m is None:
            raise ValueError("Given 'inserted-after' timestamp not in correct "
                             "format: '%s'" % inserted_after)
        inserted_after = datetime.datetime(*[int(e) for e in m.groups()])

    q = cdr_images_after(es, config['elastic_search']['index'],
                         config['image_types'], crawled_after, inserted_after)

    log.info("Query Size: %d", q[0:0].execute().hits.total)
    if report_size:
        exit(0)

    fetch_cdr_query_images(q, args.output_dir, args.file_list,
                           cores=int(config['parallel']['cores']),
                           stored_http_auth=http_auth,
                           batch_size=int(config['elastic_search']['batch_size']))
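The regex-plus-``datetime`` construction above is equivalent to a ``strptime`` parse of the same ``YYYY-MM-DDTHH:MM:SSZ`` form. A hedged stdlib alternative (function name is a placeholder):

import datetime

def parse_cdr_timestamp(ts):
    # Raises ValueError on malformed input, mirroring the regex check above.
    return datetime.datetime.strptime(ts, '%Y-%m-%dT%H:%M:%SZ')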
def classifier_kfold_validation():
    description = """
    Helper utility for cross validating a supervised classifier configuration.
    The classifier used should NOT be configured to save its model since this
    process requires us to train the classifier multiple times.

    Configuration
    -------------
    - plugins
        - supervised_classifier
            Supervised Classifier implementation configuration to use. This
            should not be set to use a persistent model if able.

        - descriptor_index
            Index to draw descriptors to classify from.

    - cross_validation
        - truth_labels
            Path to a CSV file containing descriptor UUID to truth label
            associations. This defines what descriptors are used from the given
            index. We error if any descriptor UUIDs listed here are not
            available in the given descriptor index. This file should be in
            [uuid, label] column format.

        - num_folds
            Number of folds to make for cross validation.

        - random_seed
            Optional fixed seed for the cross-validation fold randomization.

        - classification_use_multiprocessing
            If we should use multiprocessing (vs threading) when classifying
            elements.

    - pr_curves
        - enabled
            If Precision/Recall plots should be generated.

        - show
            If we should attempt to show the graph after it has been generated
            (matplotlib).

        - output_directory
            Directory to save generated plots to. If None, we will not save
            plots. Otherwise we will create the directory (and required parent
            directories) if it does not exist.

        - file_prefix
            String prefix to prepend to standard plot file names.

    - roc_curves
        - enabled
            If ROC curves should be generated

        - show
            If we should attempt to show the plot after it has been generated
            (matplotlib).

        - output_directory
            Directory to save generated plots to. If None, we will not save
            plots. Otherwise we will create the directory (and required parent
            directories) if it does not exist.

        - file_prefix
            String prefix to prepend to standard plot file names.
    """
    args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    #
    # Load configurations / Setup data
    #
    use_mp = config['cross_validation']['classification_use_multiprocessing']

    pr_enabled = config['pr_curves']['enabled']
    pr_output_dir = config['pr_curves']['output_directory']
    pr_file_prefix = config['pr_curves']['file_prefix'] or ''
    pr_show = config['pr_curves']['show']

    roc_enabled = config['roc_curves']['enabled']
    roc_output_dir = config['roc_curves']['output_directory']
    roc_file_prefix = config['roc_curves']['file_prefix'] or ''
    roc_show = config['roc_curves']['show']

    log.info("Initializing DescriptorIndex (%s)",
             config['plugins']['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )
    log.info("Loading classifier configuration")
    #: :type: dict
    classifier_config = config['plugins']['supervised_classifier']
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory']
    )

    log.info("Loading truth data")
    #: :type: list[str]
    uuids = []
    #: :type: list[str]
    truth_labels = []
    with open(config['cross_validation']['truth_labels']) as f:
        f_csv = csv.reader(f)
        for row in f_csv:
            uuids.append(row[0])
            truth_labels.append(row[1])
    #: :type: numpy.ndarray[str]
    uuids = numpy.array(uuids)
    #: :type: numpy.ndarray[str]
    truth_labels = numpy.array(truth_labels)

    #
    # Cross validation
    #
    kfolds = sklearn.cross_validation.StratifiedKFold(
        truth_labels, config['cross_validation']['num_folds'],
        random_state=config['cross_validation']['random_seed']
    )

    """
    Truth and classification probability results for test data per fold.
    Format:
        {
            0: {
                '<label>':  {
                    "truth": [...],   # Parallel truth and classification
                    "proba": [...],   # probability values
                },
                ...
            },
            ...
        }
    """
    fold_data = {}

    i = 0
    for train, test in kfolds:
        log.info("Fold %d", i)
        log.info("-- %d training examples", len(train))
        log.info("-- %d test examples", len(test))
        fold_data[i] = {}

        log.info("-- creating classifier")
        #: :type: SupervisedClassifier
        classifier = plugin.from_plugin_config(
            classifier_config,
            get_supervised_classifier_impls()
        )

        log.info("-- gathering descriptors")
        #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
        pos_map = {}
        for idx in train:
            if truth_labels[idx] not in pos_map:
                pos_map[truth_labels[idx]] = []
            pos_map[truth_labels[idx]].append(
                descriptor_index.get_descriptor(uuids[idx])
            )

        log.info("-- Training classifier")
        classifier.train(pos_map)

        log.info("-- Classifying test set")
        m = classifier.classify_async(
            (descriptor_index.get_descriptor(uuids[idx]) for idx in test),
            classification_factory,
            use_multiprocessing=use_mp, ri=1.0
        )
        uuid2c = dict((d.uuid(), c.get_classification())
                      for d, c in m.iteritems())

        log.info("-- Pairing truth and computed probabilities")
        # Only considering positive labels
        for t_label in pos_map:
            fold_data[i][t_label] = {
                "truth": [l == t_label for l in truth_labels[test]],
                "proba": [uuid2c[uuid][t_label] for uuid in uuids[test]]
            }

        i += 1

    #
    # Curve generation
    #
    if pr_enabled:
        make_pr_curves(fold_data, pr_output_dir, pr_file_prefix, pr_show)
    if roc_enabled:
        make_roc_curves(fold_data, roc_output_dir, roc_file_prefix, roc_show)
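``make_pr_curves`` is defined elsewhere in this script. Given the ``fold_data`` layout documented above, one plausible sketch of the per-label curve computation it performs, using scikit-learn as an assumed stand-in:

from sklearn.metrics import auc, precision_recall_curve

def sketch_fold_pr(fold_data, fold, label):
    # fold_data[fold][label] holds parallel "truth" and "proba" lists.
    d = fold_data[fold][label]
    precision, recall, _ = precision_recall_curve(d['truth'], d['proba'])
    return precision, recall, auc(recall, precision)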
Example #35
0
def classifier_kfold_validation():
    description = """
    Helper utility for cross validating a supervised classifier configuration.
    The classifier used should NOT be configured to save its model since this
    process requires us to train the classifier multiple times.

    Configuration
    -------------
    - plugins
        - supervised_classifier
            Supervised Classifier implementation configuration to use. This
            should not be set to use a persistent model if able.

        - descriptor_index
            Index to draw descriptors to classify from.

    - cross_validation
        - truth_labels
            Path to a CSV file containing descriptor UUID to truth label
            associations. This defines what descriptors are used from the given
            index. We error if any descriptor UUIDs listed here are not
            available in the given descriptor index. This file should be in
            [uuid, label] column format.

        - num_folds
            Number of folds to make for cross validation.

        - random_seed
            Optional fixed seed for the cross-validation fold randomization.

        - classification_use_multiprocessing
            If we should use multiprocessing (vs threading) when classifying
            elements.

    - pr_curves
        - enabled
            If Precision/Recall plots should be generated.

        - show
            If we should attempt to show the graph after it has been generated
            (matplotlib).

        - output_directory
            Directory to save generated plots to. If None, we will not save
            plots. Otherwise we will create the directory (and required parent
            directories) if it does not exist.

        - file_prefix
            String prefix to prepend to standard plot file names.

    - roc_curves
        - enabled
            If ROC curves should be generated

        - show
            If we should attempt to show the plot after it has been generated
            (matplotlib).

        - output_directory
            Directory to save generated plots to. If None, we will not save
            plots. Otherwise we will create the directory (and required parent
            directories) if it does not exist.

        - file_prefix
            String prefix to prepend to standard plot file names.
    """
    args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    #
    # Load configurations / Setup data
    #
    use_mp = config['cross_validation']['classification_use_multiprocessing']

    pr_enabled = config['pr_curves']['enabled']
    pr_output_dir = config['pr_curves']['output_directory']
    pr_file_prefix = config['pr_curves']['file_prefix'] or ''
    pr_show = config['pr_curves']['show']

    roc_enabled = config['roc_curves']['enabled']
    roc_output_dir = config['roc_curves']['output_directory']
    roc_file_prefix = config['roc_curves']['file_prefix'] or ''
    roc_show = config['roc_curves']['show']

    log.info("Initializing DescriptorIndex (%s)",
             config['plugins']['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'], get_descriptor_index_impls())
    log.info("Loading classifier configuration")
    #: :type: dict
    classifier_config = config['plugins']['supervised_classifier']
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])

    log.info("Loading truth data")
    #: :type: list[str]
    uuids = []
    #: :type: list[str]
    truth_labels = []
    with open(config['cross_validation']['truth_labels']) as f:
        f_csv = csv.reader(f)
        for row in f_csv:
            uuids.append(row[0])
            truth_labels.append(row[1])
    #: :type: numpy.ndarray[str]
    uuids = numpy.array(uuids)
    #: :type: numpy.ndarray[str]
    truth_labels = numpy.array(truth_labels)

    #
    # Cross validation
    #
    kfolds = sklearn.cross_validation.StratifiedKFold(
        truth_labels,
        config['cross_validation']['num_folds'],
        random_state=config['cross_validation']['random_seed'])
    """
    Truth and classification probability results for test data per fold.
    Format:
        {
            0: {
                '<label>':  {
                    "truth": [...],   # Parallel truth and classification
                    "proba": [...],   # probability values
                },
                ...
            },
            ...
        }
    """
    fold_data = {}

    i = 0
    for train, test in kfolds:
        log.info("Fold %d", i)
        log.info("-- %d training examples", len(train))
        log.info("-- %d test examples", len(test))
        fold_data[i] = {}

        log.info("-- creating classifier")
        #: :type: SupervisedClassifier
        classifier = plugin.from_plugin_config(
            classifier_config, get_supervised_classifier_impls())

        log.info("-- gathering descriptors")
        #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
        pos_map = {}
        for idx in train:
            if truth_labels[idx] not in pos_map:
                pos_map[truth_labels[idx]] = []
            pos_map[truth_labels[idx]].append(
                descriptor_index.get_descriptor(uuids[idx]))

        log.info("-- Training classifier")
        classifier.train(pos_map)

        log.info("-- Classifying test set")
        m = classifier.classify_async(
            (descriptor_index.get_descriptor(uuids[idx]) for idx in test),
            classification_factory,
            use_multiprocessing=use_mp,
            ri=1.0)
        uuid2c = dict(
            (d.uuid(), c.get_classification()) for d, c in m.iteritems())

        log.info("-- Pairing truth and computed probabilities")
        # Only considering positive labels
        for t_label in pos_map:
            fold_data[i][t_label] = {
                "truth": [l == t_label for l in truth_labels[test]],
                "proba": [uuid2c[uuid][t_label] for uuid in uuids[test]]
            }

        i += 1

    #
    # Curve generation
    #
    if pr_enabled:
        make_pr_curves(fold_data, pr_output_dir, pr_file_prefix, pr_show)
    if roc_enabled:
        make_roc_curves(fold_data, roc_output_dir, roc_file_prefix, roc_show)
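Similarly, the ROC helper can be sketched per fold and label with ``sklearn.metrics.roc_curve``; again an assumed stand-in for the script's own ``make_roc_curves``, not its actual implementation:

from sklearn.metrics import auc, roc_curve

def sketch_fold_roc(fold_data, fold, label):
    d = fold_data[fold][label]
    fpr, tpr, _ = roc_curve(d['truth'], d['proba'])
    return fpr, tpr, auc(fpr, tpr)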
def classifier_kfold_validation():
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Load configurations / Setup data
    #
    use_mp = config['cross_validation']['classification_use_multiprocessing']

    pr_enabled = config['pr_curves']['enabled']
    pr_output_dir = config['pr_curves']['output_directory']
    pr_file_prefix = config['pr_curves']['file_prefix'] or ''
    pr_show = config['pr_curves']['show']

    roc_enabled = config['roc_curves']['enabled']
    roc_output_dir = config['roc_curves']['output_directory']
    roc_file_prefix = config['roc_curves']['file_prefix'] or ''
    roc_show = config['roc_curves']['show']

    log.info("Initializing DescriptorIndex (%s)",
             config['plugins']['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )
    log.info("Loading classifier configuration")
    #: :type: dict
    classifier_config = config['plugins']['supervised_classifier']

    # Always use in-memory ClassificationElement since we are retraining the
    # classifier and don't want possible element caching
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory(
        MemoryClassificationElement, {}
    )

    log.info("Loading truth data")
    #: :type: list[str]
    uuids = []
    #: :type: list[str]
    truth_labels = []
    with open(config['cross_validation']['truth_labels']) as f:
        f_csv = csv.reader(f)
        for row in f_csv:
            uuids.append(row[0])
            truth_labels.append(row[1])
    #: :type: numpy.ndarray[str]
    uuids = numpy.array(uuids)
    #: :type: numpy.ndarray[str]
    truth_labels = numpy.array(truth_labels)

    #
    # Cross validation
    #
    kfolds = sklearn.model_selection.StratifiedKFold(
        n_splits=config['cross_validation']['num_folds'],
        shuffle=True,
        random_state=config['cross_validation']['random_seed'],
    ).split(numpy.zeros(len(truth_labels)), truth_labels)

    """
    Truth and classification probability results for test data per fold.
    Format:
        {
            0: {
                '<label>':  {
                    "truth": [...],   # Parallel truth and classification
                    "proba": [...],   # probability values
                },
                ...
            },
            ...
        }
    """
    fold_data = {}

    i = 0
    for train, test in kfolds:
        log.info("Fold %d", i)
        log.info("-- %d training examples", len(train))
        log.info("-- %d test examples", len(test))
        fold_data[i] = {}

        log.info("-- creating classifier")
        #: :type: SupervisedClassifier
        classifier = plugin.from_plugin_config(
            classifier_config,
            get_supervised_classifier_impls()
        )

        log.info("-- gathering descriptors")
        #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
        pos_map = {}
        for idx in train:
            if truth_labels[idx] not in pos_map:
                pos_map[truth_labels[idx]] = []
            pos_map[truth_labels[idx]].append(
                descriptor_index.get_descriptor(uuids[idx])
            )

        log.info("-- Training classifier")
        classifier.train(pos_map)

        log.info("-- Classifying test set")
        m = classifier.classify_async(
            (descriptor_index.get_descriptor(uuids[idx]) for idx in test),
            classification_factory,
            use_multiprocessing=use_mp, ri=1.0
        )
        uuid2c = dict((d.uuid(), c.get_classification())
                      for d, c in six.iteritems(m))

        log.info("-- Pairing truth and computed probabilities")
        # Only considering positive labels
        for t_label in pos_map:
            fold_data[i][t_label] = {
                "truth": [L == t_label for L in truth_labels[test]],
                "proba": [uuid2c[uuid][t_label] for uuid in uuids[test]]
            }

        i += 1

    #
    # Curve generation
    #
    if pr_enabled:
        make_pr_curves(fold_data, pr_output_dir, pr_file_prefix, pr_show)
    if roc_enabled:
        make_roc_curves(fold_data, roc_output_dir, roc_file_prefix, roc_show)
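The ``truth_labels`` input consumed above is a plain two-column CSV. A minimal sketch for producing one from an in-memory mapping (function name and path argument are placeholders):

import csv

def write_truth_labels(uuid2label, path):
    # One [uuid, label] row per descriptor, matching the expected input format.
    with open(path, 'w') as f:
        w = csv.writer(f)
        for uuid, label in sorted(uuid2label.items()):
            w.writerow([uuid, label])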
def main():
    description = """
    Utility script to transform a set of descriptors, specified by UUID, with
    matching class labels, to a test file usable by libSVM utilities for
    train/test experiments.

    The input CSV file is assumed to be of the format:

        uuid,label
        ...

    This is the same as the format requested for other scripts like
    ``classifier_model_validation.py``.

    This is very useful for searching for -c and -g parameter values for a
    training sample of data using the ``tools/grid.py`` script, found in the
    libSVM source tree. For example:

        <smqtk_source>/TPL/libsvm-3.1-custom/tools/grid.py \\
            -log2c -5,15,2 -log2c 3,-15,-2 -v 5 -out libsvm.grid.out \\
            -png libsvm.grid.png -t 0 -w1 3.46713615023 -w2 12.2613240418 \\
            output_of_this_script.txt
    """
    args, config = bin_utils.utility_main_helper(default_config, description,
                                                 extend_parser)
    log = logging.getLogger(__name__)

    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    labels_filepath = args.f
    output_filepath = args.o

    # Run through labeled UUIDs in input file, getting the descriptor from the
    # configured index, applying the appropriate integer label and then writing
    # the formatted line out to the output file.
    input_uuid_labels = csv.reader(open(labels_filepath))

    with open(output_filepath, 'w') as ofile:
        label2int = {}
        next_int = 1
        uuids, labels = zip(*input_uuid_labels)

        log.info("Scanning input descriptors and labels")
        for i, (l, d) in enumerate(
                    itertools.izip(labels,
                                   descriptor_index.get_many_descriptors(uuids))
                ):
            log.debug("%d %s", i, d.uuid())
            if l not in label2int:
                label2int[l] = next_int
                next_int += 1
            ofile.write(
                "%d " % label2int[l] +
                ' '.join(["%d:%.12f" % (j+1, f)
                          for j, f in enumerate(d.vector())
                          if f != 0.0]) +
                '\n'
            )

    log.info("Integer label association:")
    for i, l in sorted((i, l) for l, i in label2int.iteritems()):
        log.info('\t%d :: %s', i, l)
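Each output line above is libSVM's sparse format: an integer class label followed by 1-based ``index:value`` pairs for the non-zero vector components. A sketch of that transform for a single vector, factored out of the loop above:

def to_libsvm_line(int_label, vector):
    # e.g. to_libsvm_line(2, [0.0, 0.5]) -> "2 2:0.500000000000"
    pairs = ["%d:%.12f" % (j + 1, f) for j, f in enumerate(vector) if f != 0.0]
    return "%d %s" % (int_label, " ".join(pairs))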
Example #38
0
def main():
    args = cli_parser().parse_args()
    config = utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    report_size = args.report_size
    crawled_after = args.crawled_after
    inserted_after = args.inserted_after

    #
    # Check config properties
    #
    m = mimetypes.MimeTypes()
    # non-strict types (see use of ``guess_extension`` above)
    m_img_types = set(m.types_map_inv[0].keys() + m.types_map_inv[1].keys())
    if not isinstance(config['image_types'], list):
        raise ValueError("The 'image_types' property was not set to a list.")
    for t in config['image_types']:
        if ('image/' + t) not in m_img_types:
            raise ValueError("Image type '%s' is not a valid image MIMETYPE "
                             "sub-type." % t)

    if not report_size and args.output_dir is None:
        raise ValueError("Require an output directory!")
    if not report_size and args.file_list is None:
        raise ValueError("Require an output CSV file path!")

    #
    # Initialize ElasticSearch stuff
    #
    es_auth = None
    if config['elastic_search']['username'] and config['elastic_search']['password']:
        es_auth = (config['elastic_search']['username'],
                   config['elastic_search']['password'])

    es = elasticsearch.Elasticsearch(
        config['elastic_search']['instance_address'],
        http_auth=es_auth,
        use_ssl=True, verify_certs=True,
        ca_certs=certifi.where(),
    )

    #
    # Query and Run
    #
    http_auth = None
    if config['stored_http_auth']['name'] and config['stored_http_auth']['pass']:
        http_auth = (config['stored_http_auth']['name'],
                     config['stored_http_auth']['pass'])

    ts_re = re.compile(r'(\d{4})-(\d{2})-(\d{2})T(\d{2}):(\d{2}):(\d{2})Z')
    if crawled_after:
        m = ts_re.match(crawled_after)
        if m is None:
            raise ValueError("Given 'crawled-after' timestamp not in correct "
                             "format: '%s'" % crawled_after)
        crawled_after = datetime.datetime(*[int(e) for e in m.groups()])
    if inserted_after:
        m = ts_re.match(inserted_after)
        if m is None:
            raise ValueError("Given 'inserted-after' timestamp not in correct "
                             "format: '%s'" % inserted_after)
        inserted_after = datetime.datetime(*[int(e) for e in m.groups()])

    q = cdr_images_after(es, config['elastic_search']['index'],
                         config['image_types'], crawled_after, inserted_after)

    log.info("Query Size: %d", q[0:0].execute().hits.total)
    if report_size:
        exit(0)

    fetch_cdr_query_images(q, args.output_dir, args.file_list,
                           cores=int(config['parallel']['cores']),
                           stored_http_auth=http_auth,
                           batch_size=int(config['elastic_search']['batch_size']))
def main():
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Load configuration contents
    #
    uuid_list_filepath = args.uuids_list
    report_interval = config['utility']['report_interval']
    use_multiprocessing = config['utility']['use_multiprocessing']

    #
    # Checking input parameters
    #
    if (uuid_list_filepath is not None) and \
            not os.path.isfile(uuid_list_filepath):
        raise ValueError("UUIDs list file does not exist!")

    #
    # Loading stuff
    #
    log.info("Loading descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'], get_descriptor_index_impls())
    log.info("Loading LSH functor")
    #: :type: smqtk.algorithms.LshFunctor
    lsh_functor = plugin.from_plugin_config(config['plugins']['lsh_functor'],
                                            get_lsh_functor_impls())
    log.info("Loading Key/Value store")
    #: :type: smqtk.representation.KeyValueStore
    hash2uuids_kvstore = plugin.from_plugin_config(
        config['plugins']['hash2uuid_kvstore'], get_key_value_store_impls())

    # Iterate either over what's in the file given, or everything in the
    # configured index.
    def iter_uuids():
        if uuid_list_filepath:
            log.info("Using UUIDs list file")
            with open(uuid_list_filepath) as f:
                for l in f:
                    yield l.strip()
        else:
            log.info("Using all UUIDs resent in descriptor index")
            for k in descriptor_index.keys():
                yield k

    #
    # Compute codes
    #
    log.info("Starting hash code computation")
    kv_update = {}
    for uuid, hash_int in \
            compute_hash_codes(uuids_for_processing(iter_uuids(),
                                                    hash2uuids_kvstore),
                               descriptor_index, lsh_functor,
                               report_interval,
                               use_multiprocessing, True):
        # Get original value in KV-store if not in update dict.
        if hash_int not in kv_update:
            kv_update[hash_int] = hash2uuids_kvstore.get(hash_int, set())
        kv_update[hash_int] |= {uuid}

    if kv_update:
        log.info("Updating KV store... (%d keys)" % len(kv_update))
        hash2uuids_kvstore.add_many(kv_update)

    log.info("Done")
def main():
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = plugin.from_plugin_config(config['plugins']['classifier'],
                                           get_classifier_impls())
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'], get_descriptor_index_impls())

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of label to the DescriptorElement instances
    # described by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate through UUIDs in specified file """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch descriptors from configured index """
        uuid, truth_label = r
        return truth_label, descriptor_index.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr,
        iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train classifier if the one given has a ``train`` method and training
    # was enabled.
    if do_train:
        if isinstance(classifier, SupervisedClassifier):
            log.info("Training classifier model")
            classifier.train(tlabel2descriptors)
            exit(0)
        else:
            ValueError("Configured classifier is not a SupervisedClassifier "
                       "type and does not support training.")

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    for tlabel, descriptors in tlabel2descriptors.items():
        tlabel2classifications[tlabel] = \
            set(classifier.classify_async(
                descriptors, classification_factory,
                use_multiprocessing=True,
                procs=config['parallelism']['classification_cores'],
                ri=1.0,
            ).values())
    log.info("Truth label counts:")
    for l in sorted(tlabel2classifications):
        log.info("  %s :: %d", l, len(tlabel2classifications[l]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # CM of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are UUID
        # predicted labels.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[str, set | list]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists
            for plabel in uuid_cm[tlabel]:
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr, plot_ci,
                       plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc, plot_ci,
                        plot_ci_alpha)
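The UUID confusion matrix saved above is nested JSON: true label, then predicted label, then a list of descriptor UUIDs. A short sketch that loads one of these files and reports per-cell counts (function name and path argument are placeholders):

import json

def summarize_uuid_cm(path):
    with open(path) as f:
        uuid_cm = json.load(f)
    for tlabel in sorted(uuid_cm):
        for plabel in sorted(uuid_cm[tlabel]):
            print("%s -> %s: %d"
                  % (tlabel, plabel, len(uuid_cm[tlabel][plabel])))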
Example #41
0
def main():
    description = """
    Compute LSH hash codes based on the provided functor on specific
    descriptors from the configured index given a file-list of UUIDs.

    When using an input file-list of UUIDs, we require that the UUIDs of
    indexed descriptors be strings, or equality comparable to the UUIDs' string
    representation.

    This script can be used to live update the ``hash2uuid_cache_filepath``
    model file for the ``LSHNearestNeighborIndex`` algorithm as output
    dictionary format is the same as used by that implementation.
    """
    args, config = bin_utils.utility_main_helper(default_config, description,
                                                 extend_parser)
    log = logging.getLogger(__name__)

    #
    # Load configuration contents
    #
    uuid_list_filepath = args.uuids_list
    hash2uuids_input_filepath = args.input_hash2uuids
    hash2uuids_output_filepath = args.output_hash2uuids
    report_interval = config['utility']['report_interval']
    use_multiprocessing = config['utility']['use_multiprocessing']
    pickle_protocol = config['utility']['pickle_protocol']

    #
    # Checking parameters
    #
    if not hash2uuids_output_filepath:
        raise ValueError("No hash2uuids map output file provided!")

    #
    # Loading stuff
    #
    log.info("Loading descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'], get_descriptor_index_impls())
    log.info("Loading LSH functor")
    #: :type: smqtk.algorithms.LshFunctor
    lsh_functor = plugin.from_plugin_config(config['plugins']['lsh_functor'],
                                            get_lsh_functor_impls())

    def iter_uuids():
        if uuid_list_filepath:
            log.info("Using UUIDs list file")
            with open(uuid_list_filepath) as f:
                for l in f:
                    yield l.strip()
        else:
            log.info("Using all UUIDs resent in descriptor index")
            for k in descriptor_index.iterkeys():
                yield k

    # load map if it exists, else start with empty dictionary
    if hash2uuids_input_filepath and os.path.isfile(hash2uuids_input_filepath):
        log.info("Loading hash2uuids mapping")
        with open(hash2uuids_input_filepath, 'rb') as f:
            hash2uuids = cPickle.load(f)
    else:
        log.info("Creating new hash2uuids mapping for output")
        hash2uuids = {}

    #
    # Compute codes
    #
    log.info("Starting hash code computation")
    compute_hash_codes(
        uuids_for_processing(iter_uuids(), hash2uuids),
        descriptor_index,
        lsh_functor,
        hash2uuids,
        report_interval=report_interval,
        use_mp=use_multiprocessing,
    )

    #
    # Output results
    #
    tmp_output_filepath = hash2uuids_output_filepath + '.WRITING'
    log.info("Writing hash-to-uuids map to disk: %s", tmp_output_filepath)
    file_utils.safe_create_dir(os.path.dirname(hash2uuids_output_filepath))
    with open(tmp_output_filepath, 'wb') as f:
        cPickle.dump(hash2uuids, f, pickle_protocol)
    log.info("Moving on top of input: %s", hash2uuids_output_filepath)
    os.rename(tmp_output_filepath, hash2uuids_output_filepath)
    log.info("Done")
Example #42
0
def main():
    args = cli_parser().parse_args()
    config = utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    report_size = args.report_size
    crawled_after = args.crawled_after
    inserted_after = args.inserted_after

    #
    # Check config properties
    #
    m = mimetypes.MimeTypes()
    # non-strict types (see use of ``guess_extension`` above)
    m_img_types = set(m.types_map_inv[0].keys() + m.types_map_inv[1].keys())
    if not isinstance(config['image_types'], list):
        raise ValueError("The 'image_types' property was not set to a list.")
    for t in config['image_types']:
        if ('image/' + t) not in m_img_types:
            raise ValueError("Image type '%s' is not a valid image MIMETYPE "
                             "sub-type." % t)

    if not report_size and args.output_dir is None:
        raise ValueError("Require an output directory!")
    if not report_size and args.file_list is None:
        raise ValueError("Require an output CSV file path!")

    #
    # Initialize ElasticSearch stuff
    #
    es_auth = None
    if config['elastic_search']['username'] and config['elastic_search'][
            'password']:
        es_auth = (config['elastic_search']['username'],
                   config['elastic_search']['password'])

    es = elasticsearch.Elasticsearch(
        config['elastic_search']['instance_address'],
        http_auth=es_auth,
        use_ssl=True,
        verify_certs=True,
        ca_certs=certifi.where(),
    )

    #
    # Query and Run
    #
    http_auth = None
    if config['stored_http_auth']['name'] and config['stored_http_auth'][
            'pass']:
        http_auth = (config['stored_http_auth']['name'],
                     config['stored_http_auth']['pass'])

    ts_re = re.compile(r'(\d{4})-(\d{2})-(\d{2})T(\d{2}):(\d{2}):(\d{2})Z')
    if crawled_after:
        m = ts_re.match(crawled_after)
        if m is None:
            raise ValueError("Given 'crawled-after' timestamp not in correct "
                             "format: '%s'" % crawled_after)
        crawled_after = datetime.datetime(*[int(e) for e in m.groups()])
    if inserted_after:
        m = ts_re.match(inserted_after)
        if m is None:
            raise ValueError("Given 'inserted-after' timestamp not in correct "
                             "format: '%s'" % inserted_after)
        inserted_after = datetime.datetime(*[int(e) for e in m.groups()])

    q = cdr_images_after(es, config['elastic_search']['index'],
                         config['image_types'], crawled_after, inserted_after)

    log.info("Query Size: %d", q[0:0].execute().hits.total)
    if report_size:
        exit(0)

    fetch_cdr_query_images(q,
                           args.output_dir,
                           args.file_list,
                           cores=int(config['parallel']['cores']),
                           stored_http_auth=http_auth,
                           batch_size=int(
                               config['elastic_search']['batch_size']))
Example #43
0
def main():
    description = """
    Utility for fetching remotely stored image data from the CDR ElasticSearch
    instance.

    Files will be transferred into the configured directory with the format::

        <output_dir>/<index>/<_type>/<id>.<type_extension>

    Configuration Notes:

        image_types
            This is a list of image MIMETYPE suffixes to include when querying
            the ElasticSearch instance. If all types should be considered, this
            should be set to an empty list.

        stored_http_auth
            This is only used for stored-data URLs and only if both a username
            and password are given.

        elastic_search
            batch_size
                The number of query hits to fetch at a time from the instance.

    """
    args, config = utility_main_helper(default_config, description,
                                       extend_parser)
    log = logging.getLogger(__name__)

    report_size = args.report_size
    crawled_after = args.crawled_after
    inserted_after = args.inserted_after

    #
    # Check config properties
    #
    m = mimetypes.MimeTypes()
    # non-strict types (see use of ``guess_extension`` above)
    m_img_types = set(m.types_map_inv[0].keys() + m.types_map_inv[1].keys())
    if not isinstance(config['image_types'], list):
        raise ValueError("The 'image_types' property was not set to a list.")
    for t in config['image_types']:
        if ('image/' + t) not in m_img_types:
            raise ValueError("Image type '%s' is not a valid image MIMETYPE "
                             "sub-type." % t)

    if not report_size and args.output_dir is None:
        raise ValueError("Require an output directory!")
    if not report_size and args.file_list is None:
        raise ValueError("Require an output CSV file path!")

    #
    # Initialize ElasticSearch stuff
    #
    es_auth = None
    if config['elastic_search']['username'] and config['elastic_search'][
            'password']:
        es_auth = (config['elastic_search']['username'],
                   config['elastic_search']['password'])

    es = elasticsearch.Elasticsearch(
        config['elastic_search']['instance_address'],
        http_auth=es_auth,
        use_ssl=True,
        verify_certs=True,
        ca_certs=certifi.where(),
    )

    #
    # Query and Run
    #
    http_auth = None
    if config['stored_http_auth']['name'] and config['stored_http_auth'][
            'pass']:
        http_auth = (config['stored_http_auth']['name'],
                     config['stored_http_auth']['pass'])

    ts_re = re.compile(r'(\d{4})-(\d{2})-(\d{2})T(\d{2}):(\d{2}):(\d{2})Z')
    if crawled_after:
        m = ts_re.match(crawled_after)
        if m is None:
            raise ValueError("Given 'crawled-after' timestamp not in correct "
                             "format: '%s'" % crawled_after)
        crawled_after = datetime.datetime(*[int(e) for e in m.groups()])
    if inserted_after:
        m = ts_re.match(inserted_after)
        if m is None:
            raise ValueError("Given 'inserted-after' timestamp not in correct "
                             "format: '%s'" % inserted_after)
        inserted_after = datetime.datetime(*[int(e) for e in m.groups()])

    q = cdr_images_after(es, config['elastic_search']['index'],
                         config['image_types'], crawled_after, inserted_after)

    log.info("Query Size: %d", q[0:0].execute().hits.total)
    if report_size:
        exit(0)

    fetch_cdr_query_images(q,
                           args.output_dir,
                           args.file_list,
                           cores=int(config['parallel']['cores']),
                           stored_http_auth=http_auth,
                           batch_size=int(
                               config['elastic_search']['batch_size']))
def classifier_kfold_validation():
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Load configurations / Setup data
    #
    use_mp = config['cross_validation']['classification_use_multiprocessing']

    pr_enabled = config['pr_curves']['enabled']
    pr_output_dir = config['pr_curves']['output_directory']
    pr_file_prefix = config['pr_curves']['file_prefix'] or ''
    pr_show = config['pr_curves']['show']

    roc_enabled = config['roc_curves']['enabled']
    roc_output_dir = config['roc_curves']['output_directory']
    roc_file_prefix = config['roc_curves']['file_prefix'] or ''
    roc_show = config['roc_curves']['show']

    log.info("Initializing DescriptorIndex (%s)",
             config['plugins']['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'], get_descriptor_index_impls())
    log.info("Loading classifier configuration")
    #: :type: dict
    classifier_config = config['plugins']['supervised_classifier']
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])

    log.info("Loading truth data")
    #: :type: list[str]
    uuids = []
    #: :type: list[str]
    truth_labels = []
    with open(config['cross_validation']['truth_labels']) as f:
        f_csv = csv.reader(f)
        for row in f_csv:
            uuids.append(row[0])
            truth_labels.append(row[1])
    #: :type: numpy.ndarray[str]
    uuids = numpy.array(uuids)
    #: :type: numpy.ndarray[str]
    truth_labels = numpy.array(truth_labels)

    #
    # Cross validation
    #
    kfolds = sklearn.cross_validation.StratifiedKFold(
        truth_labels,
        config['cross_validation']['num_folds'],
        random_state=config['cross_validation']['random_seed'])
    """
    Truth and classification probability results for test data per fold.
    Format:
        {
            0: {
                '<label>':  {
                    "truth": [...],   # Parallel truth and classification
                    "proba": [...],   # probability values
                },
                ...
            },
            ...
        }
    """
    fold_data = {}

    for i, (train, test) in enumerate(kfolds):
        log.info("Fold %d", i)
        log.info("-- %d training examples", len(train))
        log.info("-- %d test examples", len(test))
        fold_data[i] = {}

        log.info("-- creating classifier")
        #: :type: SupervisedClassifier
        classifier = plugin.from_plugin_config(
            classifier_config, get_supervised_classifier_impls())

        log.info("-- gathering descriptors")
        #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
        pos_map = {}
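        # Map each truth label to the training descriptors bearing it;
        # this mapping is what ``classifier.train`` consumes below.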
        for idx in train:
            if truth_labels[idx] not in pos_map:
                pos_map[truth_labels[idx]] = []
            pos_map[truth_labels[idx]].append(
                descriptor_index.get_descriptor(uuids[idx]))

        log.info("-- Training classifier")
        classifier.train(pos_map)

        log.info("-- Classifying test set")
        m = classifier.classify_async(
            (descriptor_index.get_descriptor(uuids[idx]) for idx in test),
            classification_factory,
            use_multiprocessing=use_mp,
            ri=1.0)
        uuid2c = dict(
            (d.uuid(), c.get_classification()) for d, c in m.items())

        log.info("-- Pairing truth and computed probabilities")
        # Only considering positive labels
        for t_label in pos_map:
            fold_data[i][t_label] = {
                "truth": [l == t_label for l in truth_labels[test]],
                "proba": [uuid2c[uuid][t_label] for uuid in uuids[test]]
            }


    #
    # Curve generation
    #
    if pr_enabled:
        make_pr_curves(fold_data, pr_output_dir, pr_file_prefix, pr_show)
    if roc_enabled:
        make_roc_curves(fold_data, roc_output_dir, roc_file_prefix, roc_show)
Exemple #45
0
def main():
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Load configuration contents
    #
    uuid_list_filepath = args.uuids_list
    report_interval = config['utility']['report_interval']
    use_multiprocessing = config['utility']['use_multiprocessing']

    #
    # Checking input parameters
    #
    if (uuid_list_filepath is not None) and \
            not os.path.isfile(uuid_list_filepath):
        raise ValueError("UUIDs list file does not exist!")

    #
    # Loading stuff
    #
    log.info("Loading descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )
    log.info("Loading LSH functor")
    #: :type: smqtk.algorithms.LshFunctor
    lsh_functor = plugin.from_plugin_config(
        config['plugins']['lsh_functor'],
        get_lsh_functor_impls()
    )
    log.info("Loading Key/Value store")
    #: :type: smqtk.representation.KeyValueStore
    hash2uuids_kvstore = plugin.from_plugin_config(
        config['plugins']['hash2uuid_kvstore'],
        get_key_value_store_impls()
    )

    # Iterate either over what's in the file given, or everything in the
    # configured index.
    def iter_uuids():
        if uuid_list_filepath:
            log.info("Using UUIDs list file")
            with open(uuid_list_filepath) as f:
                for l in f:
                    yield l.strip()
        else:
            log.info("Using all UUIDs resent in descriptor index")
            for k in descriptor_index.keys():
                yield k

    #
    # Compute codes
    #
    log.info("Starting hash code computation")
    kv_update = {}
    for uuid, hash_int in \
            compute_hash_codes(uuids_for_processing(iter_uuids(),
                                                    hash2uuids_kvstore),
                               descriptor_index, lsh_functor,
                               report_interval,
                               use_multiprocessing, True):
        # Get original value in KV-store if not in update dict.
        if hash_int not in kv_update:
            kv_update[hash_int] = hash2uuids_kvstore.get(hash_int, set())
        kv_update[hash_int] |= {uuid}

    if kv_update:
        log.info("Updating KV store... (%d keys)" % len(kv_update))
        hash2uuids_kvstore.add_many(kv_update)

    log.info("Done")