Example #1
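This example pairs two default-configuration helpers with the main() entry point of an SMQTK command-line utility that loads descriptors by UUID from a configured descriptor set, classifies them in parallel, and writes the results to CSV. Several names are imported or defined elsewhere in the original script (cli_parser() among them); the following is only a plausible import sketch, assuming SMQTK's usual module layout, and exact paths differ between SMQTK releases.

import csv
import logging
import os

from smqtk.algorithms import Classifier, DescriptorGenerator
from smqtk.representation import (
    ClassificationElementFactory,
    DescriptorElementFactory,
    DescriptorSet,
)
from smqtk.utils import cli, parallel
from smqtk.utils.configuration import from_config_dict, make_default_config
from smqtk.utils.file import safe_create_dir  # older releases: smqtk.utils.file_utils
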
def get_default_config():
    return {
        "descriptor_factory":
        DescriptorElementFactory.get_default_config(),
        "descriptor_generator":
        make_default_config(DescriptorGenerator.get_impls()),
        "classification_factory":
        ClassificationElementFactory.get_default_config(),
        "classifier":
        make_default_config(Classifier.get_impls()),
    }
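
# Note: get_default_config() above builds the flat configuration layout that
# classify_files() in Example #4 reads (descriptor_factory, descriptor_generator,
# classification_factory, classifier), while default_config() below builds the
# nested "utility"/"plugins" layout consumed by main() through
# cli.utility_main_helper().
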
def default_config():
    return {
        "utility": {
            "classify_overwrite": False,
            "parallel": {
                "use_multiprocessing": False,
                "index_extraction_cores": None,
                "classification_cores": None,
            }
        },
        "plugins": {
            "classifier":
            make_default_config(Classifier.get_impls()),
            "classification_factory":
        ClassificationElementFactory.get_default_config(),
            "descriptor_set":
            make_default_config(DescriptorSet.get_impls()),
        }
    }
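
# For orientation: each make_default_config(...) call above appears to follow
# SMQTK's plugin-configuration convention of returning a dictionary with a
# "type" selector plus one nested block of constructor defaults per discovered
# implementation, roughly:
#
#     {
#         "type": None,              # set to the implementation name to use
#         "SomeClassifierImpl": {    # illustrative placeholder; one entry per
#             # that implementation's constructor defaults
#         },
#     }
#
# from_config_dict() in main() below reads the "type" key to decide which
# implementation to instantiate.
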
def main():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    # Process overview:
    # - parallel_map UUIDs to load from the configured index
    # - classify iterated descriptors

    uuids_list_filepath = args.uuids_list
    output_csv_filepath = args.csv_data
    output_csv_header_filepath = args.csv_header
    classify_overwrite = config['utility']['classify_overwrite']

    p_use_multiprocessing = \
        config['utility']['parallel']['use_multiprocessing']
    p_index_extraction_cores = \
        config['utility']['parallel']['index_extraction_cores']
    p_classification_cores = \
        config['utility']['parallel']['classification_cores']

    if not uuids_list_filepath:
        raise ValueError("No uuids_list_filepath specified.")
    elif not os.path.isfile(uuids_list_filepath):
        raise ValueError("Given uuids_list_filepath did not point to a file.")
    if output_csv_header_filepath is None:
        raise ValueError("Need a path to save CSV header labels")
    if output_csv_filepath is None:
        raise ValueError("Need a path to save CSV data.")

    #
    # Initialize configured plugins
    #

    log.info("Initializing descriptor index")
    #: :type: smqtk.representation.DescriptorSet
    descriptor_set = from_config_dict(config['plugins']['descriptor_set'],
                                      DescriptorSet.get_impls())

    log.info("Initializing classification factory")
    c_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])

    log.info("Initializing classifier")
    #: :type: smqtk.algorithms.Classifier
    classifier = from_config_dict(config['plugins']['classifier'],
                                  Classifier.get_impls())

    #
    # Setup/Process
    #
    def iter_uuids():
        with open(uuids_list_filepath) as f:
            for line in f:
                yield line.strip()

    def descr_for_uuid(uuid):
        """
        :type uuid: collections.Hashable
        :rtype: smqtk.representation.DescriptorElement
        """
        return descriptor_set.get_descriptor(uuid)

    def classify_descr(d):
        """
        :type d: smqtk.representation.DescriptorElement
        :rtype: smqtk.representation.ClassificationElement
        """
        return classifier.classify_one_element(d, c_factory,
                                               classify_overwrite)

    log.info("Initializing uuid-to-descriptor parallel map")
    #: :type: collections.Iterable[smqtk.representation.DescriptorElement]
    element_iter = parallel.parallel_map(
        descr_for_uuid,
        iter_uuids(),
        use_multiprocessing=p_use_multiprocessing,
        cores=p_index_extraction_cores,
        name="descr_for_uuid",
    )

    log.info("Initializing descriptor-to-classification parallel map")
    #: :type: collections.Iterable[smqtk.representation.ClassificationElement]
    classification_iter = parallel.parallel_map(
        classify_descr,
        element_iter,
        use_multiprocessing=p_use_multiprocessing,
        cores=p_classification_cores,
        name='classify_descr',
    )
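
    # Note: parallel_map yields results lazily, so the two maps above form a
    # streaming pipeline: descriptors are fetched and classified only as the
    # CSV-writing loop below consumes classification_iter.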

    #
    # Write/Output files
    #

    c_labels = classifier.get_labels()

    def make_row(e):
        """
        :type e: smqtk.representation.ClassificationElement
        """
        c_m = e.get_classification()
        return [e.uuid] + [c_m[l] for l in c_labels]

    # column labels file
    log.info("Writing CSV column header file: %s", output_csv_header_filepath)
    safe_create_dir(os.path.dirname(output_csv_header_filepath))
    # csv.writer expects a text-mode file handle on Python 3.
    with open(output_csv_header_filepath, 'w', newline='') as f_csv:
        w = csv.writer(f_csv)
        w.writerow(['uuid'] + [str(cl) for cl in c_labels])

    # CSV file
    log.info("Writing CSV data file: %s", output_csv_filepath)
    safe_create_dir(os.path.dirname(output_csv_filepath))
    pr = cli.ProgressReporter(log.info, 1.0)
    pr.start()
    with open(output_csv_filepath, 'w', newline='') as f_csv:
        w = csv.writer(f_csv)
        for c in classification_iter:
            w.writerow(make_row(c))
            pr.increment_report()
        pr.report()

    log.info("Done")
Example #4
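This example classifies files given as literal paths or glob patterns and prints the paths whose top-scoring class matches the requested label. As with Example #1, the imports are not part of the original snippet; the sketch below assumes the same SMQTK module layout and is only approximate.

import glob
import logging
import os

import six

from smqtk.algorithms import Classifier, DescriptorGenerator
from smqtk.representation import (
    ClassificationElementFactory,
    DescriptorElementFactory,
)
from smqtk.representation.data_element.file_element import DataFileElement
from smqtk.utils.configuration import from_config_dict
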
def classify_files(config, label, file_globs):
    log = logging.getLogger(__name__)

    #: :type: smqtk.algorithms.Classifier
    classifier = \
        from_config_dict(config['classifier'], Classifier.get_impls())

    def log_available_labels():
        log.info("Available classifier labels:")
        for class_label in classifier.get_labels():
            log.info("- %s", class_label)

    if label is None:
        log_available_labels()
        return
    elif label not in classifier.get_labels():
        log.error(
            "Invalid classification label provided to compute and filter "
            "on: '%s'", label)
        log_available_labels()
        return

    log.info("Collecting files from globs")
    #: :type: list[DataFileElement]
    data_elements = []
    uuid2filepath = {}
    for g in file_globs:
        if os.path.isfile(g):
            d = DataFileElement(g)
            data_elements.append(d)
            uuid2filepath[d.uuid()] = g
        else:
            log.debug("expanding glob: %s", g)
            for fp in glob.iglob(g):
                d = DataFileElement(fp)
                data_elements.append(d)
                uuid2filepath[d.uuid()] = fp
    if not data_elements:
        raise RuntimeError("No files provided for classification.")

    log.info("Computing descriptors")
    descriptor_factory = \
        DescriptorElementFactory.from_config(config['descriptor_factory'])
    #: :type: smqtk.algorithms.DescriptorGenerator
    descriptor_generator = \
        from_config_dict(config['descriptor_generator'],
                         DescriptorGenerator.get_impls())
    descr_map = descriptor_generator\
        .compute_descriptor_async(data_elements, descriptor_factory)
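    # descr_map: mapping of input data element UUID to its computed
    # DescriptorElement, as returned by compute_descriptor_async().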

    log.info("Classifying descriptors")
    classification_factory = ClassificationElementFactory \
        .from_config(config['classification_factory'])
    classification_map = classifier\
        .classify_async(list(descr_map.values()), classification_factory)

    log.info("Printing input file paths that classified as the given label.")
    # Map each data element UUID to its ClassificationElement.
    uuid2c = dict((c.uuid, c) for c in six.itervalues(classification_map))
    for data in data_elements:
        d_uuid = data.uuid()
        log.debug("'{}' classification map: {}".format(
            uuid2filepath[d_uuid], uuid2c[d_uuid].get_classification()))
        if uuid2c[d_uuid].max_label() == label:
            print(uuid2filepath[d_uuid])
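
A minimal, hypothetical driver for classify_files() is sketched below. It assumes the flat configuration layout produced by get_default_config() from Example #1; the 'positive' label and the glob pattern are placeholders, and concrete plugin selections would need to be filled into the configuration before this would actually run.

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # Hypothetical usage: start from the default skeleton (see Example #1) and
    # fill in a concrete "type" selection and parameters for each plugin slot.
    config = get_default_config()
    # config['classifier']['type'] = ...            # placeholder, not a real value
    # config['descriptor_generator']['type'] = ...  # placeholder, not a real value
    classify_files(config, 'positive', ['images/*.jpg'])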