Example #1
def test_compute_many_descriptors_batched(data_elements, descr_generator,
                                          mock_de, descr_factory, descr_index):
    """
    Test that compute_many_descriptors returns the correct number of
    elements in the correct order and calls wrapped methods the correct number
    of times when using an explicit batch size
    """
    batch_size = 2
    descriptors = compute_many_descriptors(data_elements,
                                           descr_generator,
                                           descr_factory,
                                           descr_index,
                                           batch_size=batch_size)

    descriptors_count = 0
    for desc, uuid in zip(descriptors, range(NUM_BASE_ELEMENTS)):
        # Make sure order is preserved
        assert desc[0].uuid() == uuid
        descriptors_count += 1
    # Make sure correct number of elements returned
    assert descriptors_count == NUM_BASE_ELEMENTS

    # Check number of calls: one call per batch, i.e. the ceiling of
    # NUM_BASE_ELEMENTS / batch_size.
    num_calls = NUM_BASE_ELEMENTS // batch_size + [0, 1][bool(
        NUM_BASE_ELEMENTS % batch_size)]
    assert descr_generator.compute_descriptor_async.call_count == num_calls
    assert descr_index.add_many_descriptors.call_count == num_calls
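The fixtures referenced by this test (and by Example #4 below) are defined in SMQTK's test module and are not shown on this page. The following is a minimal sketch of what mock-based stand-ins for them could look like; the constant NUM_BASE_ELEMENTS, the fixture behaviours, and the omission of mock_de are assumptions for illustration, not the actual SMQTK fixtures.

# Minimal sketch of fixtures the tests above might use. These are assumptions
# for illustration; the real fixtures live in SMQTK's test suite and may differ.
import unittest.mock as mock

import pytest

NUM_BASE_ELEMENTS = 10  # assumed module-level constant


@pytest.fixture
def data_elements():
    # Mock data elements whose uuid() values are 0..N-1, matching the order
    # check performed in the tests.
    elems = []
    for i in range(NUM_BASE_ELEMENTS):
        e = mock.MagicMock()
        e.uuid.return_value = i
        elems.append(e)
    return elems


@pytest.fixture
def descr_generator():
    # Stand-in for a DescriptorGenerator; the tests only inspect
    # compute_descriptor_async.call_count.
    return mock.MagicMock()


@pytest.fixture
def descr_factory():
    # Stand-in for a DescriptorElementFactory.
    return mock.MagicMock()


@pytest.fixture
def descr_index():
    # Stand-in for a descriptor index/set; the tests only inspect
    # add_many_descriptors.call_count.
    return mock.MagicMock()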
Example #2
File: tasks.py Project: spongezhang/SMQTK
def compute_descriptors(task, folderId, dataElementUris, **kwargs):
    """
    Celery task for computing descriptors for a series of data element URIs
    belonging to a single folder.

    After computing descriptors for a series of Girder files, the relevant items
    are updated within Girder to contain the smqtk_uuid (sha1) value as metadata.

    :param task: Celery provided task object.
    :param folderId: The folder these images are related to; this is used for
        namespacing the descriptor index table.
    :param dataElementUris: A list of data element URIs; these are assumed to
        be GirderDataElement URIs.
    """
    task.job_manager.updateProgress(message='Computing descriptors',
                                    forceFlush=True)
    generator = CaffeDescriptorGenerator(
        girderUriFromTask(
            task, getSetting(task.girder_client, 'caffe_network_prototxt')),
        girderUriFromTask(
            task, getSetting(task.girder_client, 'caffe_network_model')),
        girderUriFromTask(task,
                          getSetting(task.girder_client, 'caffe_image_mean')))

    factory = DescriptorElementFactory(
        PostgresDescriptorElement, {
            'db_name': getSetting(task.girder_client, 'db_name'),
            'db_host': getSetting(task.girder_client, 'db_host'),
            'db_user': getSetting(task.girder_client, 'db_user'),
            'db_pass': getSetting(task.girder_client, 'db_pass')
        })

    index = descriptorIndexFromFolderId(task.girder_client, folderId)

    valid_elements = iter_valid_elements([x[1] for x in dataElementUris],
                                         generator.valid_content_types())

    descriptors = compute_functions.compute_many_descriptors(valid_elements,
                                                             generator,
                                                             factory,
                                                             index,
                                                             use_mp=False)

    fileToItemId = dict([(y.split('/')[-1], x) for x, y in dataElementUris])

    for de, descriptor in descriptors:
        # TODO Catch errors that could occur here
        with task.girder_client.session():
            task.girder_client.addMetadataToItem(
                fileToItemId[de.file_id], {'smqtk_uuid': descriptor.uuid()})
Example #3
File: tasks.py Project: Kitware/SMQTK
def compute_descriptors(task, folderId, dataElementUris, **kwargs):
    """
    Celery task for computing descriptors for a series of data element URIs
    belonging to a single folder.

    After computing descriptors for a series of Girder files, the relevant items
    are updated within Girder to contain the smqtk_uuid (sha1) value as metadata.

    :param task: Celery provided task object.
    :param folderId: The folder these images are related to; this is used for
        namespacing the descriptor index table.
    :param dataElementUris: A list of data element URIs; these are assumed to
        be GirderDataElement URIs.
    """
    task.job_manager.updateProgress(message='Computing descriptors', forceFlush=True)
    generator = CaffeDescriptorGenerator(
        girderUriFromTask(task, getSetting(task.girder_client, 'caffe_network_prototxt')),
        girderUriFromTask(task, getSetting(task.girder_client, 'caffe_network_model')),
        girderUriFromTask(task, getSetting(task.girder_client, 'caffe_image_mean')))

    factory = DescriptorElementFactory(PostgresDescriptorElement, {
        'db_name': getSetting(task.girder_client, 'db_name'),
        'db_host': getSetting(task.girder_client, 'db_host'),
        'db_user': getSetting(task.girder_client, 'db_user'),
        'db_pass': getSetting(task.girder_client, 'db_pass')
    })

    index = descriptorIndexFromFolderId(task.girder_client, folderId)

    valid_elements = iter_valid_elements([x[1] for x in dataElementUris], generator.valid_content_types())

    descriptors = compute_functions.compute_many_descriptors(valid_elements,
                                                             generator,
                                                             factory,
                                                             index,
                                                             use_mp=False)

    fileToItemId = dict([(y.split('/')[-1], x) for x, y in dataElementUris])

    for de, descriptor in descriptors:
        # TODO Catch errors that could occur here
        with task.girder_client.session():
            task.girder_client.addMetadataToItem(fileToItemId[de.file_id], {
                    'smqtk_uuid': descriptor.uuid()
            })
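Both Celery task variants above assume dataElementUris is a sequence of (itemId, uri) pairs in which the last path segment of each URI is the Girder file id, which is why fileToItemId can later be looked up by de.file_id. The snippet below is a small sketch of that derivation; the ids and URIs are made up for illustration, and the exact URI scheme GirderDataElement uses is not shown here.

# Hypothetical input: (itemId, data element URI) pairs. Ids/URLs are made up.
dataElementUris = [
    ("item001", "https://girder.example.com/api/v1/file/file001"),
    ("item002", "https://girder.example.com/api/v1/file/file002"),
]

# Same derivation as in the tasks above: the last URI path segment is the
# Girder file id, mapped back to the item that owns it.
fileToItemId = {uri.split('/')[-1]: item_id for item_id, uri in dataElementUris}
assert fileToItemId == {"file001": "item001", "file002": "item002"}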
Example #4
def test_compute_many_descriptors(data_elements, descr_generator, mock_de,
                                  descr_factory, descr_index):
    """
    Test that compute_many_descriptors returns the correct number of
    elements in the correct order and calls wrapped methods the correct number
    of times when an explicit batch size is not given
    """
    descriptors = compute_many_descriptors(data_elements,
                                           descr_generator,
                                           descr_factory,
                                           descr_index,
                                           batch_size=None)

    descriptors_count = 0
    for desc, uuid in zip(descriptors, range(NUM_BASE_ELEMENTS)):
        # Make sure order is preserved
        assert desc[0].uuid() == uuid
        descriptors_count += 1
    # Make sure correct number of elements returned
    assert descriptors_count == NUM_BASE_ELEMENTS

    # Since batch_size is None, these should only be called once
    assert descr_generator.compute_descriptor_async.call_count == 1
    assert descr_index.add_many_descriptors.call_count == 1
Example #5
            return None

    def iter_valid_elements():
        valid_files_filter = parallel.parallel_map(is_valid_element,
                                                   file_paths,
                                                   name="check-file-type",
                                                   use_multiprocessing=True)
        for dfe in valid_files_filter:
            if dfe is not None:
                yield dfe

    log.info("Computing descriptors")
    m = compute_many_descriptors(
        iter_valid_elements(),
        generator,
        factory,
        descriptor_index,
        batch_size=batch_size,
    )

    # Recording computed file paths and associated file UUIDs (SHA1)
    cf = open(checkpoint_filepath, 'w')
    try:
        rps = [0] * 7
        for fp, descr in m:
            cf.write("{:s},{:s}\n".format(fp, descr.uuid()))
            report_progress(log.debug, rps, 1.)
    finally:
        cf.close()

    log.info("Done")
Example #6
def run_file_list(c,
                  filelist_filepath,
                  checkpoint_filepath,
                  batch_size=None,
                  check_image=False):
    """
    Top level function handling configuration and inputs/outputs.

    :param c: Configuration dictionary (JSON)
    :type c: dict

    :param filelist_filepath: Path to a text file that lists paths to data
        files, separated by new lines.
    :type filelist_filepath: str

    :param checkpoint_filepath: Output file to which we write input filepath to
        SHA1 (UUID) relationships.
    :type checkpoint_filepath: str

    :param batch_size: Optional batch size (None default) of data elements to
        process / descriptors to compute at a time. This causes files and
        stores to be written to incrementally during processing instead of in
        one single batch transaction at the end.
    :type batch_size: int or None

    :param check_image: Enable checking that an image can be loaded from a file
        before queueing that file for processing. If the check fails, the file
        is skipped instead of a halting exception being raised.
    :type check_image: bool

    """
    log = logging.getLogger(__name__)

    file_paths = [line.strip() for line in open(filelist_filepath)]

    log.info("Making descriptor factory")
    factory = DescriptorElementFactory.from_config(c['descriptor_factory'])

    log.info("Making descriptor index")
    descriptor_set = cast(
        DescriptorSet,
        from_config_dict(c['descriptor_set'], DescriptorSet.get_impls()))

    # ``data_set`` added to within the ``iter_valid_elements`` function.
    data_set: Optional[DataSet] = None
    if c['optional_data_set']['type'] is None:
        log.info("Not saving loaded data elements to data set")
    else:
        log.info("Initializing data set to append to")
        data_set = cast(
            DataSet,
            from_config_dict(c['optional_data_set'], DataSet.get_impls()))

    log.info("Making descriptor generator '%s'",
             c['descriptor_generator']['type'])
    generator = cast(
        DescriptorGenerator,
        from_config_dict(c['descriptor_generator'],
                         DescriptorGenerator.get_impls()))

    def iter_valid_elements():
        def is_valid(file_path):
            e = DataFileElement(file_path)

            if is_valid_element(
                    e,
                    valid_content_types=generator.valid_content_types(),
                    check_image=check_image):
                return e
            else:
                return False

        data_elements: Deque[DataFileElement] = collections.deque()
        valid_files_filter = parallel.parallel_map(is_valid,
                                                   file_paths,
                                                   name="check-file-type",
                                                   use_multiprocessing=True)
        for dfe in valid_files_filter:
            if dfe:
                yield dfe
                if data_set is not None:
                    data_elements.append(dfe)
                    if batch_size and len(data_elements) == batch_size:
                        log.debug(
                            "Adding data element batch to set (size: %d)",
                            len(data_elements))
                        data_set.add_data(*data_elements)
                        data_elements.clear()
        # elements only collected if we have a data-set configured, so add any
        # still in the deque to the set
        if data_set is not None and data_elements:
            log.debug("Adding data elements to set (size: %d",
                      len(data_elements))
            data_set.add_data(*data_elements)

    log.info("Computing descriptors")
    m = compute_many_descriptors(
        iter_valid_elements(),
        generator,
        factory,
        descriptor_set,
        batch_size=batch_size,
    )

    # Recording computed file paths and associated file UUIDs (SHA1)
    cf = open(checkpoint_filepath, 'w')
    cf_writer = csv.writer(cf)
    try:
        pr = ProgressReporter(log.debug, 1.0).start()
        for de, descr in m:
            # We know that we are using DataFileElements going into the
            # compute_many_descriptors, so we can assume that's what comes out
            # of it as well.
            # noinspection PyProtectedMember
            cf_writer.writerow([de._filepath, descr.uuid()])
            pr.increment_report()
        pr.report()
    finally:
        del cf_writer
        cf.close()

    log.info("Done")
Example #7
def run_file_list(c,
                  filelist_filepath,
                  checkpoint_filepath,
                  batch_size=None,
                  check_image=False):
    """
    Top level function handling configuration and inputs/outputs.

    :param c: Configuration dictionary (JSON)
    :type c: dict

    :param filelist_filepath: Path to a text file that lists paths to data
        files, separated by new lines.
    :type filelist_filepath: str

    :param checkpoint_filepath: Output file to which we write input filepath to
        SHA1 (UUID) relationships.
    :type checkpoint_filepath: str

    :param batch_size: Optional batch size (None default) of data elements to
        process / descriptors to compute at a time. This causes files and
        stores to be written to incrementally during processing instead of in
        one single batch transaction at the end.
    :type batch_size: int or None

    :param check_image: Enable checking that an image can be loaded from a file
        before queueing that file for processing. If the check fails, the file
        is skipped instead of a halting exception being raised.
    :type check_image: bool

    """
    log = logging.getLogger(__name__)

    file_paths = [line.strip() for line in open(filelist_filepath)]

    log.info("Making descriptor factory")
    factory = DescriptorElementFactory.from_config(c['descriptor_factory'])

    log.info("Making descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(c['descriptor_index'],
                                                 get_descriptor_index_impls())

    data_set = None
    if c['optional_data_set']['type'] is None:
        log.info("Not saving loaded data elements to data set")
    else:
        log.info("Initializing data set to append to")
        #: :type: smqtk.representation.DataSet
        data_set = plugin.from_plugin_config(c['optional_data_set'],
                                             get_data_set_impls())

    log.info("Making descriptor generator '%s'",
             c['descriptor_generator']['type'])
    #: :type: smqtk.algorithms.DescriptorGenerator
    generator = plugin.from_plugin_config(c['descriptor_generator'],
                                          get_descriptor_generator_impls())

    def iter_valid_elements():
        def is_valid(file_path):
            dfe = DataFileElement(file_path)

            if is_valid_element(
                    dfe,
                    valid_content_types=generator.valid_content_types(),
                    check_image=check_image):
                return dfe
            else:
                return False

        data_elements = collections.deque()
        valid_files_filter = parallel.parallel_map(is_valid,
                                                   file_paths,
                                                   name="check-file-type",
                                                   use_multiprocessing=True)
        for dfe in valid_files_filter:
            if dfe:
                yield dfe
                if data_set is not None:
                    data_elements.append(dfe)
                    if batch_size and len(data_elements) == batch_size:
                        log.debug(
                            "Adding data element batch to set (size: %d)",
                            len(data_elements))
                        data_set.add_data(*data_elements)
                        data_elements.clear()
        # elements only collected if we have a data-set configured, so add any
        # still in the deque to the set
        if data_elements:
            log.debug("Adding data elements to set (size: %d",
                      len(data_elements))
            data_set.add_data(*data_elements)

    log.info("Computing descriptors")
    m = compute_many_descriptors(
        iter_valid_elements(),
        generator,
        factory,
        descriptor_index,
        batch_size=batch_size,
    )

    # Recording computed file paths and associated file UUIDs (SHA1)
    cf = open(checkpoint_filepath, 'w')
    cf_writer = csv.writer(cf)
    try:
        rps = [0] * 7
        for fp, descr in m:
            cf_writer.writerow([fp, descr.uuid()])
            report_progress(log.debug, rps, 1.)
    finally:
        del cf_writer
        cf.close()

    log.info("Done")
Example #8
                      fp, ct)
            return None

    def iter_valid_elements():
        valid_files_filter = parallel.parallel_map(is_valid_element,
                                                   file_paths,
                                                   name="check-file-type",
                                                   use_multiprocessing=True)
        for dfe in valid_files_filter:
            if dfe is not None:
                yield dfe

    log.info("Computing descriptors")
    m = compute_many_descriptors(iter_valid_elements(),
                                 generator,
                                 factory,
                                 descriptor_index,
                                 batch_size=batch_size,
                                 )

    # Recording computed file paths and associated file UUIDs (SHA1)
    cf = open(checkpoint_filepath, 'w')
    try:
        rps = [0] * 7
        for fp, descr in m:
            cf.write("{:s},{:s}\n".format(
                fp, descr.uuid()
            ))
            report_progress(log.debug, rps, 1.)
        # Final report
        rps[1] -= 1
        report_progress(log.debug, rps, 0.)