예제 #1
0
    def from_config(cls, config, parent_app):
        """
        Build an instance of this class from a JSON-compliant configuration
        dictionary of initialization arguments.

        The given ``config`` is layered key-by-key (shallow ``dict.update``) on
        top of ``get_default_config()``. The ``data_set``,
        ``descr_generator`` and ``nn_index`` sub-configurations are then
        replaced by the plugin instances they describe, and
        ``descriptor_factory`` by a ``DescriptorElementFactory``.

        :param config: JSON compliant dictionary encapsulating
            a configuration.
        :type config: dict

        :param parent_app: Parent containing flask app instance
        :type parent_app: smqtk.web.search_app.app.search_app

        :return: Constructed instance from the provided config.
        :rtype: IqrSearch

        """
        cfg = cls.get_default_config()
        cfg.update(config)

        # Swap each nested plugin configuration for the instance it describes.
        plugin_getters = (
            ('data_set', get_data_set_impls),
            ('descr_generator', get_descriptor_generator_impls),
            ('nn_index', get_nn_index_impls),
        )
        for key, get_impls in plugin_getters:
            cfg[key] = plugin.from_plugin_config(cfg[key], get_impls())

        cfg['descriptor_factory'] = DescriptorElementFactory.from_config(
            cfg['descriptor_factory'])

        return cls(parent_app, **cfg)
예제 #2
0
파일: iqr_search.py 프로젝트: dhandeo/SMQTK
    def from_config(cls, config, parent_app):
        """
        Build an instance of this class from a JSON-compliant configuration
        dictionary of initialization arguments.

        The given ``config`` is layered key-by-key (shallow ``dict.update``) on
        top of ``get_default_config()``. The ``data_set``,
        ``descr_generator`` and ``nn_index`` sub-configurations are then
        replaced by the plugin instances they describe, and
        ``descriptor_factory`` by a ``DescriptorElementFactory``.

        :param config: JSON compliant dictionary encapsulating
            a configuration.
        :type config: dict

        :param parent_app: Parent containing flask app instance
        :type parent_app: smqtk.web.search_app.app.search_app

        :return: Constructed instance from the provided config.
        :rtype: IqrSearch

        """
        cfg = cls.get_default_config()
        cfg.update(config)

        # Swap each nested plugin configuration for the instance it describes.
        plugin_getters = (
            ('data_set', get_data_set_impls),
            ('descr_generator', get_descriptor_generator_impls),
            ('nn_index', get_nn_index_impls),
        )
        for key, get_impls in plugin_getters:
            cfg[key] = plugin.from_plugin_config(cfg[key], get_impls())

        cfg['descriptor_factory'] = DescriptorElementFactory.from_config(
            cfg['descriptor_factory'])

        return cls(parent_app, **cfg)
예제 #3
0
def run_file_list(c,
                  filelist_filepath,
                  checkpoint_filepath,
                  batch_size=None,
                  check_image=False):
    """
    Top level function handling configuration and inputs/outputs.

    :param c: Configuration dictionary (JSON)
    :type c: dict

    :param filelist_filepath: Path to a text file that lists paths to data
        files, separated by new lines.
    :type filelist_filepath: str

    :param checkpoint_filepath: Output file to which we write input filepath to
        SHA1 (UUID) relationships.
    :type checkpoint_filepath: str

    :param batch_size: Optional batch size (None default) of data elements to
        process / descriptors to compute at a time. This causes files and
        stores to be written to incrementally during processing instead of
        one single batch transaction at a time.
    :type batch_size: int | None

    :param check_image: Enable checking image loading from file before queueing
        that file for processing. If the check fails, the file is skipped
        instead of a halting exception being raised.
    :type check_image: bool

    """
    log = logging.getLogger(__name__)

    # Read the whole list up front and close the handle right away instead of
    # leaving it open for the rest of processing.
    with open(filelist_filepath) as filelist:
        file_paths = [l.strip() for l in filelist]

    log.info("Making descriptor factory")
    factory = DescriptorElementFactory.from_config(c['descriptor_factory'])

    log.info("Making descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(c['descriptor_index'],
                                                 get_descriptor_index_impls())

    log.info("Making descriptor generator '%s'",
             c['descriptor_generator']['type'])
    #: :type: smqtk.algorithms.DescriptorGenerator
    generator = plugin.from_plugin_config(c['descriptor_generator'],
                                          get_descriptor_generator_impls())

    def test_image_load(dfe):
        """
        Return True if PIL can open the bytes of data element ``dfe``;
        otherwise log a warning naming the file and return False.
        """
        try:
            PIL.Image.open(io.BytesIO(dfe.get_bytes()))
            return True
        except IOError as ex:
            # noinspection PyProtectedMember
            log.warning(
                "Failed to convert '%s' bytes into an image "
                "(error: %s). Skipping", dfe._filepath, str(ex))
            return False
예제 #4
0
    def from_config(cls, config_dict, merge_default=True):
        """
        Instantiate a new instance of this class given the configuration
        JSON-compliant dictionary encapsulating initialization arguments.

        This method should not be called via super unless an instance of the
        class is desired.

        Nested plugin configurations (``lsh_functor``, ``descriptor_index``,
        ``hash_index``, ``hash2uuids_kvstore``) are replaced by constructed
        instances before the base ``from_config`` is called. A falsy
        ``hash_index`` config, or one without a ``type``, yields ``None``.

        :param config_dict: JSON compliant dictionary encapsulating
            a configuration.
        :type config_dict: dict

        :param merge_default: Merge the given configuration on top of the
            default provided by ``get_default_config``.
        :type merge_default: bool

        :return: Constructed instance from the provided config.
        :rtype: LSHNearestNeighborIndex

        """
        # Controlling merge here so we can control known comment stripping from
        # default config.
        if merge_default:
            merged = cls.get_default_config()
            merge_dict(merged, config_dict)
        else:
            # Shallow-copy so that the key replacements below do not mutate
            # the caller's dictionary.
            merged = dict(config_dict)

        merged['lsh_functor'] = \
            plugin.from_plugin_config(merged['lsh_functor'],
                                      get_lsh_functor_impls())
        merged['descriptor_index'] = \
            plugin.from_plugin_config(merged['descriptor_index'],
                                      get_descriptor_index_impls())

        # Hash index may be None for a default at-query-time linear indexing
        if merged['hash_index'] and merged['hash_index']['type']:
            merged['hash_index'] = \
                plugin.from_plugin_config(merged['hash_index'],
                                          get_hash_index_impls())
        else:
            cls.get_logger().debug(
                "No HashIndex impl given. Passing ``None``.")
            merged['hash_index'] = None

        # remove possible comment added by default generator
        if 'hash_index_comment' in merged:
            del merged['hash_index_comment']

        merged['hash2uuids_kvstore'] = \
            plugin.from_plugin_config(merged['hash2uuids_kvstore'],
                                      get_key_value_store_impls())

        return super(LSHNearestNeighborIndex, cls).from_config(merged, False)
예제 #5
0
파일: faiss.py 프로젝트: spongezhang/SMQTK
    def from_config(cls, config_dict, merge_default=True):
        """
        Instantiate a new instance of this class given the configuration
        JSON-compliant dictionary encapsulating initialization arguments.

        This method should not be called via super unless an instance of the
        class is desired.

        ``descriptor_set``, ``uid2idx_kvs`` and ``idx2uid_kvs`` plugin
        configurations are replaced by constructed instances.
        ``index_element`` and ``index_param_element`` are optional: a falsy
        config, or one without a ``type``, yields ``None``.

        :param config_dict: JSON compliant dictionary encapsulating
            a configuration.
        :type config_dict: dict

        :param merge_default: Merge the given configuration on top of the
            default provided by ``get_default_config``.
        :type merge_default: bool

        :return: Constructed instance from the provided config.
        :rtype: FaissNearestNeighborsIndex

        """
        if merge_default:
            cfg = cls.get_default_config()
            merge_dict(cfg, config_dict)
        else:
            # Shallow-copy so that the key replacements below do not mutate
            # the caller's dictionary.
            cfg = dict(config_dict)

        cfg['descriptor_set'] = plugin.from_plugin_config(
            cfg['descriptor_set'], get_descriptor_index_impls()
        )
        cfg['uid2idx_kvs'] = plugin.from_plugin_config(
            cfg['uid2idx_kvs'], get_key_value_store_impls()
        )
        cfg['idx2uid_kvs'] = plugin.from_plugin_config(
            cfg['idx2uid_kvs'], get_key_value_store_impls()
        )

        # Optional persistence elements share the same "may be disabled"
        # handling.
        for key in ('index_element', 'index_param_element'):
            if cfg[key] and cfg[key]['type']:
                cfg[key] = plugin.from_plugin_config(
                    cfg[key], get_data_element_impls())
            else:
                cfg[key] = None

        return super(FaissNearestNeighborsIndex, cls).from_config(cfg, False)
예제 #6
0
파일: __init__.py 프로젝트: Kitware/SMQTK
    def from_config(cls, config_dict, merge_default=True):
        """
        Instantiate a new instance of this class given the configuration
        JSON-compliant dictionary encapsulating initialization arguments.

        This method should not be called via super unless an instance of the
        class is desired.

        Nested plugin configurations (``lsh_functor``, ``descriptor_index``,
        ``hash_index``, ``hash2uuids_kvstore``) are replaced by constructed
        instances before the base ``from_config`` is called. A falsy
        ``hash_index`` config, or one without a ``type``, yields ``None``.

        :param config_dict: JSON compliant dictionary encapsulating
            a configuration.
        :type config_dict: dict

        :param merge_default: Merge the given configuration on top of the
            default provided by ``get_default_config``.
        :type merge_default: bool

        :return: Constructed instance from the provided config.
        :rtype: LSHNearestNeighborIndex

        """
        # Controlling merge here so we can control known comment stripping from
        # default config.
        if merge_default:
            merged = cls.get_default_config()
            merge_dict(merged, config_dict)
        else:
            # Shallow-copy so that the key replacements below do not mutate
            # the caller's dictionary.
            merged = dict(config_dict)

        merged['lsh_functor'] = \
            plugin.from_plugin_config(merged['lsh_functor'],
                                      get_lsh_functor_impls())
        merged['descriptor_index'] = \
            plugin.from_plugin_config(merged['descriptor_index'],
                                      get_descriptor_index_impls())

        # Hash index may be None for a default at-query-time linear indexing
        if merged['hash_index'] and merged['hash_index']['type']:
            merged['hash_index'] = \
                plugin.from_plugin_config(merged['hash_index'],
                                          get_hash_index_impls())
        else:
            cls.get_logger().debug("No HashIndex impl given. Passing ``None``.")
            merged['hash_index'] = None

        # remove possible comment added by default generator
        if 'hash_index_comment' in merged:
            del merged['hash_index_comment']

        merged['hash2uuids_kvstore'] = \
            plugin.from_plugin_config(merged['hash2uuids_kvstore'],
                                      get_key_value_store_impls())

        return super(LSHNearestNeighborIndex, cls).from_config(merged, False)
예제 #7
0
    def __init__(self, json_config):
        """
        Set up the IQR service from its JSON configuration.

        Reads ``json_config['iqr_service']['session_control']`` and
        ``json_config['iqr_service']['plugins']``, builds the classification
        factory, descriptor index and nearest-neighbor index, and creates the
        session controller, which is given a callback that discards
        per-session classifier state.

        :param json_config: Service configuration dictionary.
        :type json_config: dict
        """
        super(IqrService, self).__init__(json_config)
        sc_config = json_config['iqr_service']['session_control']

        # Initialize from config
        self.positive_seed_neighbors = sc_config['positive_seed_neighbors']
        plugins_config = json_config['iqr_service']['plugins']
        self.classifier_config = plugins_config['classifier_config']
        self.classification_factory = \
            ClassificationElementFactory.from_config(
                plugins_config['classification_factory']
            )

        #: :type: smqtk.representation.DescriptorIndex
        self.descriptor_index = plugin.from_plugin_config(
            plugins_config['descriptor_index'],
            get_descriptor_index_impls(),
        )
        #: :type: smqtk.algorithms.NearestNeighborsIndex
        self.neighbor_index = plugin.from_plugin_config(
            plugins_config['neighbor_index'],
            get_nn_index_impls(),
        )

        self.rel_index_config = plugins_config['relevancy_index_config']

        # Record of trained classifiers for a session. Session classifier
        # modifications locked under the parent session's global lock.
        #: :type: dict[collections.Hashable, smqtk.algorithms.SupervisedClassifier | None]
        self.session_classifiers = {}
        # Control for knowing when a new classifier should be trained for a
        # session (True == train new classifier). Modification for specific
        # sessions under parent session's lock.
        #: :type: dict[collections.Hashable, bool]
        self.session_classifier_dirty = {}

        def session_expire_callback(session):
            """
            Discard the expiring session's classifier state.

            :type session: smqtk.iqr.IqrSession
            """
            with session:
                self._log.debug("Removing session %s classifier", session.uuid)
                # ``pop`` with a default: a missing entry in one map must not
                # raise and leave the other map uncleaned.
                self.session_classifiers.pop(session.uuid, None)
                self.session_classifier_dirty.pop(session.uuid, None)

        self.controller = iqr_controller.IqrController(
            sc_config['session_expiration']['enabled'],
            sc_config['session_expiration']['check_interval_seconds'],
            session_expire_callback)
        self.session_timeout = \
            sc_config['session_expiration']['session_timeout']
예제 #8
0
def run_file_list(c, filelist_filepath, checkpoint_filepath, batch_size=None,
                  check_image=False):
    """
    Top level function handling configuration and inputs/outputs.

    :param c: Configuration dictionary (JSON)
    :type c: dict

    :param filelist_filepath: Path to a text file that lists paths to data
        files, separated by new lines.
    :type filelist_filepath: str

    :param checkpoint_filepath: Output file to which we write input filepath to
        SHA1 (UUID) relationships.
    :type checkpoint_filepath: str

    :param batch_size: Optional batch size (None default) of data elements to
        process / descriptors to compute at a time. This causes files and
        stores to be written to incrementally during processing instead of
        one single batch transaction at a time.
    :type batch_size: int | None

    :param check_image: Enable checking image loading from file before queueing
        that file for processing. If the check fails, the file is skipped
        instead of a halting exception being raised.
    :type check_image: bool

    """
    log = logging.getLogger(__name__)

    # Read the whole list up front and close the handle right away instead of
    # leaving it open for the rest of processing.
    with open(filelist_filepath) as filelist:
        file_paths = [l.strip() for l in filelist]

    log.info("Making descriptor factory")
    factory = DescriptorElementFactory.from_config(c['descriptor_factory'])

    log.info("Making descriptor index")
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(c['descriptor_index'],
                                                 get_descriptor_index_impls())

    log.info("Making descriptor generator '%s'",
             c['descriptor_generator']['type'])
    #: :type: smqtk.algorithms.DescriptorGenerator
    generator = plugin.from_plugin_config(c['descriptor_generator'],
                                          get_descriptor_generator_impls())

    def test_image_load(dfe):
        """
        Return True if PIL can open the bytes of data element ``dfe``;
        otherwise log a warning naming the file and return False.
        """
        try:
            PIL.Image.open(io.BytesIO(dfe.get_bytes()))
            return True
        except IOError as ex:
            # noinspection PyProtectedMember
            log.warning("Failed to convert '%s' bytes into an image "
                        "(error: %s). Skipping",
                        dfe._filepath, str(ex))
            return False
예제 #9
0
def main():
    """
    Print the nearest neighbors of descriptors, one query at a time.

    For each query descriptor, its UUID is printed, followed by one
    ``uuid,distance`` line per neighbor. The queries come from the UUIDs
    listed in ``args.uuid_list``, one per line, if given. Otherwise every
    descriptor in the configured descriptor set is queried.

    Exit codes: 1 when called without arguments (help is printed), 103 for a
    missing UUID list file, 105 for a negative neighbor count.
    """
    # Print help and exit if no arguments were passed
    if len(sys.argv) == 1:
        get_cli_parser().print_help()
        sys.exit(1)

    args = get_cli_parser().parse_args()
    config = utility_main_helper(get_default_config, args)

    log = logging.getLogger(__name__)
    log.debug('Showing debug messages.')

    # Validate arguments before constructing the (potentially expensive)
    # plugins. ``sys.exit`` is used because the bare ``exit`` builtin is
    # injected by the ``site`` module and is not guaranteed to exist.
    if args.uuid_list is not None and not os.path.exists(args.uuid_list):
        log.error('Invalid file list path: %s', args.uuid_list)
        sys.exit(103)
    elif args.num < 0:
        log.error('Number of nearest neighbors must be >= 0')
        sys.exit(105)

    #: :type: smqtk.representation.DescriptorIndex
    descriptor_set = plugin.from_plugin_config(
        config['plugins']['descriptor_set'], get_descriptor_index_impls()
    )
    #: :type: smqtk.algorithms.NearestNeighborsIndex
    nearest_neighbor_index = plugin.from_plugin_config(
        config['plugins']['nn_index'], get_nn_index_impls()
    )

    # noinspection PyShadowingNames
    def nearest_neighbors(descriptor, n):
        """
        Query the ``n`` nearest neighbors of ``descriptor`` (the whole index
        when ``n == 0``) and return all but the first result as a list of
        ``(uuid, distance)`` pairs.
        """
        if n == 0:
            n = len(nearest_neighbor_index)

        neighbors, dists = nearest_neighbor_index.nn(descriptor, n)
        # Strip first result (itself) and create list of (uuid, distance)
        return list(zip([x.uuid() for x in neighbors[1:]], dists[1:]))

    if args.uuid_list is not None:
        with open(args.uuid_list, 'r') as infile:
            for line in infile:
                uuid = line.strip()
                # Skip blank lines rather than looking up an empty UUID.
                if not uuid:
                    continue
                descriptor = descriptor_set.get_descriptor(uuid)
                print(descriptor.uuid())
                for neighbor in nearest_neighbors(descriptor, args.num):
                    print('%s,%f' % neighbor)
    else:
        for (uuid, descriptor) in descriptor_set.iteritems():
            print(uuid)
            for neighbor in nearest_neighbors(descriptor, args.num):
                print('%s,%f' % neighbor)
예제 #10
0
def main():
    """
    Write labeled descriptors, listed in a ``uuid,label`` CSV, to a libSVM
    formatted text file.

    Each distinct label is assigned an increasing integer starting at 1, in
    order of first appearance. Only non-zero vector components are written,
    as 1-based ``index:value`` pairs. The label-to-integer association is
    logged at the end.
    """
    description = """
    Utility script to transform a set of descriptors, specified by UUID, with
    matching class labels, to a test file usable by libSVM utilities for
    train/test experiments.

    The input CSV file is assumed to be of the format:

        uuid,label
        ...

    This is the same as the format requested for other scripts like
    ``classifier_model_validation.py``.

    This is very useful for searching for -c and -g parameter values for a
    training sample of data using the ``tools/grid.py`` script, found in the
    libSVM source tree. For example:

        <smqtk_source>/TPL/libsvm-3.1-custom/tools/grid.py \\
            -log2c -5,15,2 -log2g 3,-15,-2 -v 5 -out libsvm.grid.out \\
            -png libsvm.grid.png -t 0 -w1 3.46713615023 -w2 12.2613240418 \\
            output_of_this_script.txt
    """
    args, config = bin_utils.utility_main_helper(default_config, description,
                                                 extend_parser)
    log = logging.getLogger(__name__)

    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'], get_descriptor_index_impls())

    labels_filepath = args.f
    output_filepath = args.o

    # Read all (uuid, label) rows up front so the input file is closed
    # promptly. csv.reader yields [] for blank lines; a single empty row would
    # make ``zip(*rows)`` yield nothing and break the unpacking below, so such
    # rows are dropped.
    with open(labels_filepath) as labels_file:
        input_uuid_labels = [row for row in csv.reader(labels_file) if row]

    # Run through labeled UUIDs in input file, getting the descriptor from the
    # configured index, applying the appropriate integer label and then writing
    # the formatted line out to the output file.
    label2int = {}
    with open(output_filepath, 'w') as ofile:
        if not input_uuid_labels:
            log.warning("No labeled UUIDs found in input file: %s",
                        labels_filepath)
        else:
            next_int = 1
            uuids, labels = zip(*input_uuid_labels)

            log.info("Scanning input descriptors and labels")
            for i, (l, d) in enumerate(
                    itertools.izip(labels,
                                   descriptor_index.get_many_descriptors(uuids))):
                log.debug("%d %s", i, d.uuid())
                if l not in label2int:
                    label2int[l] = next_int
                    next_int += 1
                # libSVM sparse format: "<label> <index>:<value> ..." with
                # 1-based ascending indices; zero components are omitted.
                ofile.write("%d " % label2int[l] + ' '.join([
                    "%d:%.12f" % (j + 1, f)
                    for j, f in enumerate(d.vector()) if f != 0.0
                ]) + '\n')

    log.info("Integer label association:")
    for i, l in sorted((i, l) for l, i in label2int.iteritems()):
        log.info('\t%d :: %s', i, l)
예제 #11
0
def main():
    """
    Cluster the configured descriptor index with MiniBatchKMeans.

    The cluster centroids are saved as a ``.npy`` file at
    ``config['centroids_output_filepath_npy']``. The classification map
    returned by ``mb_kmeans_build_apply`` is pickled to ``args.output_map``.

    :raises ValueError: No output map path was given.
    """
    args = cli_parser().parse_args()
    config = utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    output_filepath = args.output_map
    if not output_filepath:
        raise ValueError("No path given for output map file (pickle).")

    #: :type: smqtk.representation.DescriptorIndex
    index = from_plugin_config(config['descriptor_index'],
                               get_descriptor_index_impls())
    mbkm = MiniBatchKMeans(verbose=args.verbose,
                           compute_labels=False,
                           **config['minibatch_kmeans_params'])
    initial_fit_size = int(config['initial_fit_size'])

    d_classes = mb_kmeans_build_apply(index, mbkm, initial_fit_size)

    log.info("Saving KMeans centroids to: %s",
             config['centroids_output_filepath_npy'])
    numpy.save(config['centroids_output_filepath_npy'], mbkm.cluster_centers_)

    log.info("Saving result classification map to: %s", output_filepath)
    safe_create_dir(os.path.dirname(output_filepath))
    # Protocol -1 (highest) is a binary pickle format, so the file must be
    # opened in binary mode to avoid newline translation / text encoding.
    with open(output_filepath, 'wb') as f:
        cPickle.dump(d_classes, f, -1)

    log.info("Done")
예제 #12
0
    def from_config(cls, config_dict, merge_default=True):
        """
        Instantiate a new instance of this class given the configuration
        JSON-compliant dictionary encapsulating initialization arguments.

        The ``kvstore`` plugin configuration is replaced by the constructed
        key-value store instance before the base ``from_config`` is called.

        :param config_dict: JSON compliant dictionary encapsulating
            a configuration.
        :type config_dict: dict

        :param merge_default: Merge the given configuration on top of the
            default provided by ``get_default_config``.
        :type merge_default: bool

        :return: Constructed instance from the provided config.
        :rtype: KVSDataSet
        """
        if merge_default:
            config_dict = merge_dict(cls.get_default_config(), config_dict)
        else:
            # Shallow-copy so that swapping in the KVStore instance below does
            # not mutate the caller's dictionary.
            config_dict = dict(config_dict)

        # Convert KVStore config to instance for constructor.
        kvs_inst = plugin.from_plugin_config(config_dict['kvstore'],
                                             get_key_value_store_impls())
        config_dict['kvstore'] = kvs_inst

        return super(KVSDataSet, cls).from_config(config_dict, False)
예제 #13
0
파일: mrpt.py 프로젝트: spongezhang/SMQTK
    def from_config(cls, config_dict, merge_default=True):
        """
        Instantiate a new instance of this class given the configuration
        JSON-compliant dictionary encapsulating initialization arguments.

        This method should not be called via super unless an instance of the
        class is desired.

        The ``descriptor_set`` plugin configuration is replaced by the
        constructed descriptor index before the base ``from_config`` is
        called.

        :param config_dict: JSON compliant dictionary encapsulating
            a configuration.
        :type config_dict: dict

        :param merge_default: Merge the given configuration on top of the
            default provided by ``get_default_config``.
        :type merge_default: bool

        :return: Constructed instance from the provided config.
        :rtype: MRPTNearestNeighborsIndex

        """
        if merge_default:
            cfg = cls.get_default_config()
            merge_dict(cfg, config_dict)
        else:
            # Shallow-copy so that the key replacement below does not mutate
            # the caller's dictionary.
            cfg = dict(config_dict)

        cfg['descriptor_set'] = \
            plugin.from_plugin_config(cfg['descriptor_set'],
                                      get_descriptor_index_impls())

        return super(MRPTNearestNeighborsIndex, cls).from_config(cfg, False)
예제 #14
0
파일: linear.py 프로젝트: Kitware/SMQTK
    def from_config(cls, config_dict, merge_default=True):
        """
        Instantiate a new instance of this class given the configuration
        JSON-compliant dictionary encapsulating initialization arguments.

        This method should not be called via super unless an instance of the
        class is desired.

        ``cache_element`` is optional: a falsy config, or one without a
        ``type``, yields ``None``. Otherwise it is replaced by the constructed
        data element.

        :param config_dict: JSON compliant dictionary encapsulating
            a configuration.
        :type config_dict: dict

        :param merge_default: Merge the given configuration on top of the
            default provided by ``get_default_config``.
        :type merge_default: bool

        :return: Constructed instance from the provided config.
        :rtype: LinearHashIndex

        """
        if merge_default:
            config_dict = merge_dict(cls.get_default_config(), config_dict)
        else:
            # Shallow-copy so that the key replacement below does not mutate
            # the caller's dictionary.
            config_dict = dict(config_dict)

        cache_element = None
        if config_dict['cache_element'] \
                and config_dict['cache_element']['type']:
            cache_element = \
                plugin.from_plugin_config(config_dict['cache_element'],
                                          get_data_element_impls())
        config_dict['cache_element'] = cache_element

        return super(LinearHashIndex, cls).from_config(config_dict, False)
예제 #15
0
    def _configure(self):
        """
        Apply this process's configuration.

        Every available config key is captured into ``self.config_dict``.
        Outside of test mode, a memory-backed descriptor factory is also
        created, and a descriptor generator is built from the JSON file named
        by the ``config_file`` value. ``_base_configure`` is called last.
        """
        # Test extracting config as dictionary
        self.config_dict = dict(
            (key, self.config_value(key)) for key in self.available_config()
        )

        # If we're in test mode, don't do anything that requires smqtk.
        if not apply_descriptor_test_mode:
            # create descriptor factory
            self.factory = DescriptorElementFactory(DescriptorMemoryElement, {})

            # get config file name
            file_name = self.config_value("config_file")

            from smqtk.utils.jsmin import jsmin
            import json

            # Contents are passed through jsmin before JSON parsing
            # (presumably to allow comments in the config file).
            with open(file_name) as cfg_file:
                self.descr_config = json.loads(jsmin(cfg_file.read()))

            # ``from_plugin_config`` needs the mapping of available
            # implementations, so the getter must be called, not passed.
            self.generator = from_plugin_config(
                self.descr_config, get_descriptor_generator_impls())

        self._base_configure()
예제 #16
0
파일: memory.py 프로젝트: jbeezley/SMQTK
    def from_config(cls, config_dict, merge_default=True):
        """
        Instantiate a new instance of this class given the configuration
        JSON-compliant dictionary encapsulating initialization arguments.

        ``cache_element`` is optional: a falsy config, or one without a
        ``type``, yields ``None``. Otherwise it is replaced by the constructed
        data element.

        :param config_dict: JSON compliant dictionary encapsulating
            a configuration.
        :type config_dict: dict

        :param merge_default: Merge the given configuration on top of the
            default provided by ``get_default_config``.
        :type merge_default: bool

        :return: Constructed instance from the provided config.
        :rtype: MemoryDescriptorIndex

        """
        if merge_default:
            config_dict = merge_dict(cls.get_default_config(), config_dict)
        else:
            # Shallow-copy so that the key replacement below does not mutate
            # the caller's dictionary.
            config_dict = dict(config_dict)

        # Optionally construct cache element from sub-config.
        if config_dict['cache_element'] \
                and config_dict['cache_element']['type']:
            e = plugin.from_plugin_config(config_dict['cache_element'],
                                          get_data_element_impls())
            config_dict['cache_element'] = e
        else:
            config_dict['cache_element'] = None

        return super(MemoryDescriptorIndex,
                     cls).from_config(config_dict, False)
예제 #17
0
    def from_config(cls, c, merge_default=True):
        """
        Instantiate a new instance of this class given the configuration
        JSON-compliant dictionary encapsulating initialization arguments.

        This method should not be called via super unless an instance of the
        class is desired.

        ``cache_element`` is optional: a falsy config, or one without a
        ``type``, yields ``None``. Otherwise it is replaced by the constructed
        data element.

        :param c: JSON compliant dictionary encapsulating
            a configuration.
        :type c: dict

        :param merge_default: Merge the given configuration on top of the
            default provided by ``get_default_config``.
        :type merge_default: bool

        :return: Constructed instance from the provided config.
        :rtype: DataMemorySet

        """
        if merge_default:
            c = merge_dict(cls.get_default_config(), c)
        else:
            # Shallow-copy so that the key replacement below does not mutate
            # the caller's dictionary.
            c = dict(c)

        cache_element = None
        if c['cache_element'] and c['cache_element']['type']:
            cache_element = plugin.from_plugin_config(c['cache_element'],
                                                      get_data_element_impls())
        c['cache_element'] = cache_element

        return super(DataMemorySet, cls).from_config(c, False)
예제 #18
0
    def _configure(self):
        """
        Apply this process's configuration.

        Every available config key is captured into ``self.config_dict``.
        Outside of test mode, a memory-backed descriptor factory is also
        created, and a descriptor generator is built from the JSON file named
        by the ``config_file`` value. ``_base_configure`` is called last.
        """
        # Test extracting config as dictionary
        self.config_dict = dict(
            (key, self.config_value(key)) for key in self.available_config()
        )

        # If we're in test mode, don't do anything that requires smqtk.
        if not apply_descriptor_test_mode:
            # create descriptor factory
            self.factory = DescriptorElementFactory(DescriptorMemoryElement, {})

            # get config file name
            file_name = self.config_value("config_file")

            from smqtk.utils.jsmin import jsmin
            import json

            # Contents are passed through jsmin before JSON parsing
            # (presumably to allow comments in the config file).
            with open(file_name) as cfg_file:
                self.descr_config = json.loads(jsmin(cfg_file.read()))

            # ``from_plugin_config`` needs the mapping of available
            # implementations, so the getter must be called, not passed.
            self.generator = from_plugin_config(
                self.descr_config, get_descriptor_generator_impls())

        self._base_configure()
예제 #19
0
def main():
    """
    Cluster the descriptors of a configured index with MiniBatchKMeans.

    The cluster centers (``mbkm.cluster_centers_``) are saved with
    ``numpy.save`` to the ``centroids_output_filepath_npy`` configuration
    path, and the classification map returned by ``mb_kmeans_build_apply``
    is pickled to the ``args.output_map`` path.

    :raises ValueError: No output map path was given on the command line.
    """
    args = cli_parser().parse_args()
    config = utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    output_filepath = args.output_map
    if not output_filepath:
        raise ValueError("No path given for output map file (pickle).")

    #: :type: smqtk.representation.DescriptorIndex
    index = from_plugin_config(config['descriptor_index'],
                               get_descriptor_index_impls())
    mbkm = MiniBatchKMeans(verbose=args.verbose,
                           compute_labels=False,
                           **config['minibatch_kmeans_params'])
    initial_fit_size = int(config['initial_fit_size'])

    d_classes = mb_kmeans_build_apply(index, mbkm, initial_fit_size)

    log.info("Saving KMeans centroids to: %s",
             config['centroids_output_filepath_npy'])
    numpy.save(config['centroids_output_filepath_npy'], mbkm.cluster_centers_)

    log.info("Saving result classification map to: %s", output_filepath)
    safe_create_dir(os.path.dirname(output_filepath))
    # Protocol -1 selects the highest (binary) pickle protocol, so the file
    # must be opened in binary mode. Text mode can corrupt it through newline
    # translation, and it fails outright on Python 3.
    with open(output_filepath, 'wb') as f:
        cPickle.dump(d_classes, f, -1)

    log.info("Done")
예제 #20
0
    def from_config(cls, config_dict, merge_default=True):
        """
        Instantiate a new instance of this class given the configuration
        JSON-compliant dictionary encapsulating initialization arguments.

        A ``cache_element`` that is missing, ``None``, or whose ``type`` is
        ``None`` means "no cache element". Any other value is treated as a
        nested plugin configuration and instantiated. Only a top-level copy of
        ``config_dict`` is modified, never the caller's dictionary itself.

        :param config_dict: JSON compliant dictionary encapsulating
            a configuration.
        :type config_dict: dict

        :param merge_default: Merge the given configuration on top of the
            default provided by ``get_default_config``.
        :type merge_default: bool

        :return: Constructed instance from the provided config.
        :rtype: MemoryKeyValueStore

        """
        # Copy top-level of config in order to not modify input instance.
        c = config_dict.copy()
        cache_cfg = c.get('cache_element')
        # Simplify specification for "no cache element"
        if cache_cfg is None or cache_cfg['type'] is None:
            c['cache_element'] = None
        else:
            # Create from nested config.
            c['cache_element'] = from_plugin_config(cache_cfg,
                                                    get_data_element_impls())
        # Forward ``merge_default`` so the caller's documented choice is
        # honored by the parent implementation.
        return super(MemoryKeyValueStore, cls).from_config(c, merge_default)
예제 #21
0
    def from_config(cls, config_dict, merge_default=True):
        """
        Instantiate a new instance of this class given the configuration
        JSON-compliant dictionary encapsulating initialization arguments.

        The ``kvstore`` entry is a nested plugin configuration. It is turned
        into a key-value store instance before construction. This method does
        not assign into the caller's ``config_dict``.

        :param config_dict: JSON compliant dictionary encapsulating
            a configuration.
        :type config_dict: dict

        :param merge_default: Merge the given configuration on top of the
            default provided by ``get_default_config``.
        :type merge_default: bool

        :return: Constructed instance from the provided config.
        :rtype: KVSDataSet
        """
        if merge_default:
            config_dict = merge_dict(cls.get_default_config(), config_dict)
        else:
            # Copy top-level of config so replacing ``kvstore`` below does not
            # modify the caller's dictionary.
            config_dict = config_dict.copy()

        # Convert KVStore config to instance for constructor.
        kvs_inst = plugin.from_plugin_config(config_dict['kvstore'],
                                             get_key_value_store_impls())
        config_dict['kvstore'] = kvs_inst

        return super(KVSDataSet, cls).from_config(config_dict, False)
예제 #22
0
    def from_config(cls, config_dict, merge_default=True):
        """
        Instantiate a new instance of this class given the configuration
        JSON-compliant dictionary encapsulating initialization arguments.

        This method should not be called via super unless an instance of the
        class is desired.

        A falsy ``cache_element`` or one with a falsy ``type`` means "no cache
        element". Any other value is instantiated as a nested plugin
        configuration. This method does not assign into the caller's
        ``config_dict``.

        :param config_dict: JSON compliant dictionary encapsulating
            a configuration.
        :type config_dict: dict

        :param merge_default: Merge the given configuration on top of the
            default provided by ``get_default_config``.
        :type merge_default: bool

        :return: Constructed instance from the provided config.
        :rtype: SkLearnBallTreeHashIndex

        """
        if merge_default:
            config_dict = merge_dict(cls.get_default_config(), config_dict)
        else:
            # Copy top-level of config so replacing ``cache_element`` below
            # does not modify the caller's dictionary.
            config_dict = config_dict.copy()

        # Parse ``cache_element`` configuration if set.
        cache_element = None
        if config_dict['cache_element'] and \
                config_dict['cache_element']['type']:
            cache_element = \
                plugin.from_plugin_config(config_dict['cache_element'],
                                          get_data_element_impls())
        config_dict['cache_element'] = cache_element

        return super(SkLearnBallTreeHashIndex,
                     cls).from_config(config_dict, False)
예제 #23
0
파일: itq.py 프로젝트: msarahan/SMQTK
    def from_config(cls, config_dict):
        """
        Instantiate a new instance of this class given the configuration
        JSON-compliant dictionary.

        ``config_dict`` is applied over ``get_default_config()`` with a
        shallow ``dict.update``, so each top-level key in ``config_dict``
        replaces the default value wholesale. The resulting ``code_index``
        value is always parsed. It must be a nested plugin specification
        dictionary, as consumed by ``smqtk.utils.plugin.from_plugin_config``.
        It is replaced by the constructed CodeIndex instance before the parent
        ``from_config`` is called.

        :param config_dict: JSON compliant dictionary encapsulating
            a configuration.
        :type config_dict: dict

        :return: ITQ similarity index instance
        :rtype: ITQNearestNeighborsIndex

        """
        merged = cls.get_default_config()
        merged.update(config_dict)

        # Transform the nested plugin specification into an instance.
        merged['code_index'] = \
            plugin.from_plugin_config(merged['code_index'],
                                      get_code_index_impls)

        return super(ITQNearestNeighborsIndex, cls).from_config(merged)
예제 #24
0
파일: train_itq.py 프로젝트: Kitware/SMQTK
def main():
    """
    Train an ITQ functor's model on descriptors from a configured index.

    If the ``uuids_list_filepath`` configuration value names an existing
    file, only the descriptors whose UUIDs are listed in it (one per line;
    blank lines are skipped) are used for training. Otherwise every
    descriptor in the configured index is used. A warning is logged when a
    list path is configured but no such file exists.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    uuids_list_filepath = config['uuids_list_filepath']

    log.info("Initializing ITQ functor")
    #: :type: smqtk.algorithms.nn_index.lsh.functors.itq.ItqFunctor
    functor = ItqFunctor.from_config(config['itq_config'])

    log.info("Initializing DescriptorIndex [type=%s]",
             config['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['descriptor_index'],
        get_descriptor_index_impls(),
    )

    if uuids_list_filepath and os.path.isfile(uuids_list_filepath):
        def uuids_iter():
            with open(uuids_list_filepath) as f:
                for line in f:
                    uid = line.strip()
                    # Skip blank lines instead of requesting an empty UUID.
                    if uid:
                        yield uid
        log.info("Loading UUIDs list from file: %s", uuids_list_filepath)
        d_iter = descriptor_index.get_many_descriptors(uuids_iter())
    else:
        if uuids_list_filepath:
            # Make the fallback to the whole index visible instead of silent.
            log.warning("UUIDs list file not found, using whole index: %s",
                        uuids_list_filepath)
        log.info("Using UUIDs from loaded DescriptorIndex (count=%d)",
                 len(descriptor_index))
        d_iter = descriptor_index

    log.info("Fitting ITQ model")
    functor.fit(d_iter)
    log.info("Done")
예제 #25
0
파일: memory.py 프로젝트: Kitware/SMQTK
    def from_config(cls, config_dict, merge_default=True):
        """
        Instantiate a new instance of this class given the configuration
        JSON-compliant dictionary encapsulating initialization arguments.

        A ``cache_element`` that is missing, ``None``, or whose ``type`` is
        ``None`` means "no cache element". Any other value is treated as a
        nested plugin configuration and instantiated. Only a top-level copy of
        ``config_dict`` is modified, never the caller's dictionary itself.

        :param config_dict: JSON compliant dictionary encapsulating
            a configuration.
        :type config_dict: dict

        :param merge_default: Merge the given configuration on top of the
            default provided by ``get_default_config``.
        :type merge_default: bool

        :return: Constructed instance from the provided config.
        :rtype: MemoryKeyValueStore

        """
        # Copy top-level of config in order to not modify input instance.
        c = config_dict.copy()
        cache_cfg = c.get('cache_element')
        # Simplify specification for "no cache element"
        if cache_cfg is None or cache_cfg['type'] is None:
            c['cache_element'] = None
        else:
            # Create from nested config.
            c['cache_element'] = from_plugin_config(cache_cfg,
                                                    get_data_element_impls())
        # Forward ``merge_default`` so the caller's documented choice is
        # honored by the parent implementation.
        return super(MemoryKeyValueStore, cls).from_config(c, merge_default)
예제 #26
0
파일: memory.py 프로젝트: Kitware/SMQTK
    def from_config(cls, config_dict, merge_default=True):
        """
        Instantiate a new instance of this class given the configuration
        JSON-compliant dictionary encapsulating initialization arguments.

        A falsy ``cache_element`` or one with a falsy ``type`` means "no cache
        element". Any other value is instantiated as a nested plugin
        configuration. This method does not assign into the caller's
        ``config_dict``.

        :param config_dict: JSON compliant dictionary encapsulating
            a configuration.
        :type config_dict: dict

        :param merge_default: Merge the given configuration on top of the
            default provided by ``get_default_config``.
        :type merge_default: bool

        :return: Constructed instance from the provided config.
        :rtype: MemoryDescriptorIndex

        """
        if merge_default:
            config_dict = merge_dict(cls.get_default_config(), config_dict)
        else:
            # Copy top-level of config so replacing ``cache_element`` below
            # does not modify the caller's dictionary.
            config_dict = config_dict.copy()

        # Optionally construct cache element from sub-config.
        cache_cfg = config_dict['cache_element']
        if cache_cfg and cache_cfg['type']:
            config_dict['cache_element'] = plugin.from_plugin_config(
                cache_cfg, get_data_element_impls())
        else:
            config_dict['cache_element'] = None

        return super(MemoryDescriptorIndex, cls).from_config(config_dict, False)
예제 #27
0
파일: mrpt.py 프로젝트: Kitware/SMQTK
    def from_config(cls, config_dict, merge_default=True):
        """
        Instantiate a new instance of this class given the configuration
        JSON-compliant dictionary encapsulating initialization arguments.

        This method should not be called via super unless an instance of the
        class is desired.

        The ``descriptor_set`` entry is a nested plugin configuration. It is
        instantiated before construction. This method does not assign into the
        caller's ``config_dict``.

        :param config_dict: JSON compliant dictionary encapsulating
            a configuration.
        :type config_dict: dict

        :param merge_default: Merge the given configuration on top of the
            default provided by ``get_default_config``.
        :type merge_default: bool

        :return: Constructed instance from the provided config.
        :rtype: MRPTNearestNeighborsIndex

        """
        if merge_default:
            cfg = cls.get_default_config()
            merge_dict(cfg, config_dict)
        else:
            # Copy top-level of config so replacing ``descriptor_set`` below
            # does not modify the caller's dictionary.
            cfg = config_dict.copy()

        cfg['descriptor_set'] = \
            plugin.from_plugin_config(cfg['descriptor_set'],
                                      get_descriptor_index_impls())

        return super(MRPTNearestNeighborsIndex, cls).from_config(cfg, False)
예제 #28
0
def run_file_list(c, filelist_filepath, checkpoint_filepath):
    """
    Compute descriptors for the files listed in ``filelist_filepath``.

    Only files whose content type is in the generator's
    ``valid_content_types()`` are processed. Each ``(path, descriptor)``
    result is appended to ``checkpoint_filepath`` as a ``path,uuid`` line.
    Maps of valid and invalid file paths to their detected content types are
    pickled to ``valid_file_map.pickle`` and ``invalid_file_map.pickle`` in
    the current working directory.

    :param c: Configuration dictionary with ``descriptor_factory`` and
        ``descriptor_generator`` sections.
    :param filelist_filepath: Path to a text file with one file path per line.
    :param checkpoint_filepath: Path of the checkpoint file to append to.
    """
    log = logging.getLogger(__name__)

    # Read the list with a context manager so the handle is closed promptly.
    with open(filelist_filepath) as list_file:
        file_paths = [line.strip() for line in list_file]

    log.info("Making memory factory")
    factory = DescriptorElementFactory.from_config(c['descriptor_factory'])

    log.info("Making descriptor generator '%s'",
             c['descriptor_generator']['type'])
    #: :type: smqtk.algorithms.DescriptorGenerator
    generator = from_plugin_config(c['descriptor_generator'],
                                   get_descriptor_generator_impls)
    log.info("Making descriptor generator -- Done")

    valid_file_paths = dict()
    invalid_file_paths = dict()

    def iter_valid_files():
        """Yield the paths whose content type the generator supports."""
        for fp in file_paths:
            dfe = DataFileElement(fp)
            ct = dfe.content_type()
            if ct in generator.valid_content_types():
                valid_file_paths[fp] = ct
                yield fp
            else:
                invalid_file_paths[fp] = ct

    log.info("Computing descriptors")
    m = compute_many_descriptors(iter_valid_files(),
                                 generator,
                                 factory,
                                 batch_size=256,
                                 )

    # Recording computed file paths and associated file UUIDs (SHA1)
    with open(checkpoint_filepath, 'a') as cf:
        for fp, descr in m:
            cf.write("{:s},{:s}\n".format(
                fp, descr.uuid()
            ))
            # Flush per record so progress is kept if the run is interrupted.
            cf.flush()

    # Output valid file and invalid file dictionaries as pickle
    log.info("Writing valid filepaths map")
    with open('valid_file_map.pickle', 'wb') as f:
        cPickle.dump(valid_file_paths, f)
    log.info("Writing invalid filepaths map")
    with open('invalid_file_map.pickle', 'wb') as f:
        cPickle.dump(invalid_file_paths, f)

    log.info("Done")
예제 #29
0
    def update_working_index(self, nn_index):
        """
        Grow the working index with neighbors of not-yet-used positive seeds.

        Every positive descriptor (external or adjudicated) whose UUID is not
        yet in ``self._wi_seeds_used`` is queried against ``nn_index`` for
        ``self.pos_seed_neighbors`` neighbors. Those neighbors are added to the
        working index and the seed's UUID is recorded. The working index is
        never cleared here; the step is additive. When at least one new seed
        was queried, a fresh relevancy index is built from
        ``self.rel_index_config`` over the whole working index.

        :param nn_index: :class:`.NearestNeighborsIndex` to query from.
        :type nn_index: smqtk.algorithms.NearestNeighborsIndex

        :raises RuntimeError: There are no positive example descriptors in this
            session to use as a basis for querying.

        """
        external = self.external_positive_descriptors
        adjudicated = self.positive_descriptors
        seeds = external | adjudicated
        if not seeds:
            raise RuntimeError("No positive descriptors to query the neighbor "
                               "index with.")

        self._log.info("Building working index using %d positive examples "
                       "(%d external, %d adjudicated)",
                       len(seeds), len(external), len(adjudicated))

        # TODO: parallel_map and reduce with merge-dict
        new_seed_count = 0
        for seed in seeds:
            seed_uuid = seed.uuid()
            if seed_uuid in self._wi_seeds_used:
                continue
            self._log.debug("Querying neighbors to: %s", seed)
            neighbors = nn_index.nn(seed, n=self.pos_seed_neighbors)[0]
            self.working_index.add_many_descriptors(neighbors)
            self._wi_seeds_used.add(seed_uuid)
            new_seed_count += 1

        # Only rebuild the relevancy index when the working index grew.
        if new_seed_count:
            self._log.info("Creating new relevancy index over working index.")
            #: :type: smqtk.algorithms.relevancy_index.RelevancyIndex
            self.rel_index = plugin.from_plugin_config(
                self.rel_index_config, get_relevancy_index_impls()
            )
            self.rel_index.build_index(self.working_index.iterdescriptors())
예제 #30
0
파일: __init__.py 프로젝트: mrG7/SMQTK
    def get_descriptor_inst(self, label):
        """
        Return the descriptor generator instance for a configuration label.

        On the first request for ``label`` the instance is built from
        ``self.generator_label_configs[label]`` and stored in
        ``self.descriptor_cache``. Later requests return the stored instance.
        Lookup and construction both happen while holding
        ``self.descriptor_cache_lock``.

        :type label: str
        :rtype: smqtk.descriptor_generator.DescriptorGenerator
        """
        with self.descriptor_cache_lock:
            cache = self.descriptor_cache
            if label in cache:
                return cache[label]

            self.log.debug("Caching descriptor '%s'", label)
            label_config = self.generator_label_configs[label]
            cache[label] = plugin.from_plugin_config(
                label_config, get_descriptor_generator_impls
            )
            return cache[label]
예제 #31
0
파일: __init__.py 프로젝트: dhandeo/SMQTK
    def __init__(self, json_config):
        """
        Initialize the IQR service from a JSON-compliant configuration.

        ``json_config`` is first handed to the parent constructor. Its
        ``iqr_service`` section must provide ``positive_seed_neighbors`` and a
        ``plugins`` sub-section with ``classifier_config``,
        ``classification_factory``, ``descriptor_index``, ``neighbor_index``
        and ``relevancy_index_config`` entries. The classification factory,
        descriptor index and neighbor index are instantiated here. The
        classifier and relevancy index configurations are kept as
        configuration dictionaries.

        :param json_config: Service configuration dictionary.
        :type json_config: dict
        """
        super(IqrService, self).__init__(json_config)

        # Initialize from config
        service_section = json_config['iqr_service']
        self.positive_seed_neighbors = \
            service_section['positive_seed_neighbors']

        plugin_section = service_section['plugins']
        self.classifier_config = plugin_section['classifier_config']
        self.classification_factory = ClassificationElementFactory.from_config(
            plugin_section['classification_factory'])

        #: :type: smqtk.representation.DescriptorIndex
        self.descriptor_index = plugin.from_plugin_config(
            plugin_section['descriptor_index'], get_descriptor_index_impls())
        #: :type: smqtk.algorithms.NearestNeighborsIndex
        self.neighbor_index = plugin.from_plugin_config(
            plugin_section['neighbor_index'], get_nn_index_impls())

        self.rel_index_config = plugin_section['relevancy_index_config']

        self.controller = iqr_controller.IqrController()
        # Trained classifier per session (None when not trained). Changes for
        # a specific session are made under that session's lock.
        #: :type: dict[collections.Hashable, smqtk.algorithms.SupervisedClassifier | None]
        self.session_classifiers = {}
        # Per-session flag; True means a new classifier should be trained.
        # Changes for a specific session are made under that session's lock.
        #: :type: dict[collections.Hashable, bool]
        self.session_classifier_dirty = {}
예제 #32
0
파일: itq.py 프로젝트: Kitware/SMQTK
    def from_config(cls, config_dict, merge_default=True):
        """
        Instantiate a new instance of this class given the JSON-compliant
        configuration dictionary encapsulating initialization arguments.

        ``mean_vec_cache`` and ``rotation_cache`` are nested DataElement
        plugin configurations. Each is instantiated when it is truthy and has
        a truthy ``type``, and is set to ``None`` otherwise. This method does
        not assign into the caller's ``config_dict``.

        :param config_dict: JSON compliant dictionary encapsulating
            a configuration.
        :type config_dict: dict

        :param merge_default: Merge the given configuration on top of the
            default provided by ``get_default_config``.
        :type merge_default: bool

        :return: Constructed instance from the provided config.
        :rtype: ItqFunctor

        """
        if merge_default:
            config_dict = merge_dict(cls.get_default_config(), config_dict)
        else:
            # Copy top-level of config so replacing the cache entries below
            # does not modify the caller's dictionary.
            config_dict = config_dict.copy()

        data_element_impls = get_data_element_impls()

        def make_cache_element(key):
            """Instantiate the DataElement configured at ``key``, or None."""
            elem_cfg = config_dict[key]
            if elem_cfg and elem_cfg['type']:
                return plugin.from_plugin_config(elem_cfg, data_element_impls)
            return None

        # Mean vector cache element.
        config_dict['mean_vec_cache'] = make_cache_element('mean_vec_cache')
        # Rotation matrix cache element.
        config_dict['rotation_cache'] = make_cache_element('rotation_cache')

        return super(ItqFunctor, cls).from_config(config_dict, False)
예제 #33
0
    def from_config(cls, config_dict, merge_default=True):
        """
        Instantiate a new instance of this class given the JSON-compliant
        configuration dictionary encapsulating initialization arguments.

        ``mean_vec_cache`` and ``rotation_cache`` are nested DataElement
        plugin configurations. Each is instantiated when it is truthy and has
        a truthy ``type``, and is set to ``None`` otherwise. This method does
        not assign into the caller's ``config_dict``.

        :param config_dict: JSON compliant dictionary encapsulating
            a configuration.
        :type config_dict: dict

        :param merge_default: Merge the given configuration on top of the
            default provided by ``get_default_config``.
        :type merge_default: bool

        :return: Constructed instance from the provided config.
        :rtype: ItqFunctor

        """
        if merge_default:
            config_dict = merge_dict(cls.get_default_config(), config_dict)
        else:
            # Copy top-level of config so replacing the cache entries below
            # does not modify the caller's dictionary.
            config_dict = config_dict.copy()

        data_element_impls = get_data_element_impls()

        def make_cache_element(key):
            """Instantiate the DataElement configured at ``key``, or None."""
            elem_cfg = config_dict[key]
            if elem_cfg and elem_cfg['type']:
                return plugin.from_plugin_config(elem_cfg, data_element_impls)
            return None

        # Mean vector cache element.
        config_dict['mean_vec_cache'] = make_cache_element('mean_vec_cache')
        # Rotation matrix cache element.
        config_dict['rotation_cache'] = make_cache_element('rotation_cache')

        return super(ItqFunctor, cls).from_config(config_dict, False)
예제 #34
0
    def __init__(self, json_config):
        """
        Set up the IQR service from a JSON-compliant configuration dictionary.

        The whole ``json_config`` goes to the parent constructor first. The
        ``iqr_service`` section supplies ``positive_seed_neighbors``, and its
        ``plugins`` sub-section supplies ``classifier_config``,
        ``classification_factory``, ``descriptor_index``, ``neighbor_index``
        and ``relevancy_index_config``. The factory and both indices are
        instantiated immediately. The classifier and relevancy index
        configurations are stored unchanged.

        :param json_config: Service configuration dictionary.
        :type json_config: dict
        """
        super(IqrService, self).__init__(json_config)

        # Initialize from config
        svc = json_config['iqr_service']
        self.positive_seed_neighbors = svc['positive_seed_neighbors']
        plugins_cfg = svc['plugins']
        self.classifier_config = plugins_cfg['classifier_config']
        factory_cfg = plugins_cfg['classification_factory']
        self.classification_factory = \
            ClassificationElementFactory.from_config(factory_cfg)

        descr_index_cfg = plugins_cfg['descriptor_index']
        #: :type: smqtk.representation.DescriptorIndex
        self.descriptor_index = plugin.from_plugin_config(
            descr_index_cfg, get_descriptor_index_impls())
        nn_index_cfg = plugins_cfg['neighbor_index']
        #: :type: smqtk.algorithms.NearestNeighborsIndex
        self.neighbor_index = plugin.from_plugin_config(
            nn_index_cfg, get_nn_index_impls())

        self.rel_index_config = plugins_cfg['relevancy_index_config']

        self.controller = iqr_controller.IqrController()
        # Session -> trained classifier (or None). Modifications for a session
        # are made while holding that session's lock.
        #: :type: dict[collections.Hashable, smqtk.algorithms.SupervisedClassifier | None]
        self.session_classifiers = {}
        # Session -> "needs a newly trained classifier" flag (True == train).
        # Modifications for a session are made while holding that session's
        # lock.
        #: :type: dict[collections.Hashable, bool]
        self.session_classifier_dirty = {}
예제 #35
0
    def update_working_index(self, nn_index):
        """
        Add neighbors of newly seen positive examples to the working index.

        The union of external and adjudicated positive descriptors forms the
        seed set. Seeds whose UUID already appears in ``self._wi_seeds_used``
        are skipped. Each remaining seed is queried against ``nn_index`` for
        ``self.pos_seed_neighbors`` neighbors, which are appended to the
        working index (it is never cleared here). If any seed was queried, the
        relevancy index is rebuilt from ``self.rel_index_config`` over the
        working index's descriptors.

        :param nn_index: :class:`.NearestNeighborsIndex` to query from.
        :type nn_index: smqtk.algorithms.NearestNeighborsIndex

        :raises RuntimeError: There are no positive example descriptors in this
            session to use as a basis for querying.

        """
        n_external = len(self.external_positive_descriptors)
        n_adjudicated = len(self.positive_descriptors)
        pos_examples = (self.external_positive_descriptors
                        | self.positive_descriptors)
        if not pos_examples:
            raise RuntimeError("No positive descriptors to query the neighbor "
                               "index with.")

        self._log.info(
            "Building working index using %d positive examples "
            "(%d external, %d adjudicated)",
            len(pos_examples), n_external, n_adjudicated)

        # TODO: parallel_map and reduce with merge-dict
        fresh = [p for p in pos_examples
                 if p.uuid() not in self._wi_seeds_used]
        for seed in fresh:
            self._log.debug("Querying neighbors to: %s", seed)
            self.working_index.add_many_descriptors(
                nn_index.nn(seed, n=self.pos_seed_neighbors)[0])
            self._wi_seeds_used.add(seed.uuid())

        if fresh:
            self._log.info("Creating new relevancy index over working index.")
            #: :type: smqtk.algorithms.relevancy_index.RelevancyIndex
            self.rel_index = plugin.from_plugin_config(
                self.rel_index_config, get_relevancy_index_impls())
            self.rel_index.build_index(self.working_index.iterdescriptors())
예제 #36
0
    def get_descriptor_inst(self, label):
        """
        Fetch (building and caching on first use) the generator for ``label``.

        While ``self.descriptor_cache_lock`` is held, a missing entry is
        created from ``self.generator_label_configs[label]`` and stored in
        ``self.descriptor_cache``. The cached instance is then returned.

        :type label: str
        :rtype: smqtk.descriptor_generator.DescriptorGenerator
        """
        with self.descriptor_cache_lock:
            if label in self.descriptor_cache:
                return self.descriptor_cache[label]
            self.log.debug("Caching descriptor '%s'", label)
            generator = plugin.from_plugin_config(
                self.generator_label_configs[label],
                get_descriptor_generator_impls,
            )
            self.descriptor_cache[label] = generator
            return self.descriptor_cache[label]
예제 #37
0
def run_file_list(c, filelist_filepath, checkpoint_filepath, batch_size):
    """
    Compute descriptors for the files listed in ``filelist_filepath``.

    Files whose content type is in the generator's ``valid_content_types()``
    are wrapped as ``DataFileElement`` instances and sent to
    ``compute_many_descriptors`` with the given ``batch_size``. Each
    ``(item, descriptor)`` pair it yields is appended to
    ``checkpoint_filepath`` as an ``item,uuid`` line. Maps of valid and
    invalid file paths to their detected content types are pickled to
    ``file_map.valid.pickle`` and ``file_map.invalid.pickle`` in the current
    working directory.

    :param c: Configuration dictionary with ``descriptor_factory`` and
        ``descriptor_generator`` sections.
    :param filelist_filepath: Path to a text file with one file path per line.
    :param checkpoint_filepath: Path of the checkpoint file to append to.
    :param batch_size: Batch size passed to ``compute_many_descriptors``.
    """
    log = logging.getLogger(__name__)

    # Read the list with a context manager so the handle is closed promptly.
    with open(filelist_filepath) as list_file:
        file_paths = [line.strip() for line in list_file]

    log.info("Making descriptor factory")
    factory = DescriptorElementFactory.from_config(c["descriptor_factory"])

    log.info("Making descriptor generator '%s'", c["descriptor_generator"]["type"])
    #: :type: smqtk.algorithms.DescriptorGenerator
    generator = from_plugin_config(c["descriptor_generator"], get_descriptor_generator_impls)
    log.info("Making descriptor generator -- Done")

    valid_file_paths = dict()
    invalid_file_paths = dict()

    def iter_valid_elements():
        """Yield a DataFileElement for each file the generator supports."""
        for fp in file_paths:
            dfe = DataFileElement(fp)
            ct = dfe.content_type()
            if ct in generator.valid_content_types():
                valid_file_paths[fp] = ct
                yield dfe
            else:
                invalid_file_paths[fp] = ct

    log.info("Computing descriptors")
    m = compute_many_descriptors(iter_valid_elements(), generator, factory, batch_size=batch_size)

    # Recording computed file paths and associated file UUIDs (SHA1)
    # NOTE(review): DataFileElement instances (not path strings) are fed in
    # above. Confirm what ``compute_many_descriptors`` yields as the first
    # tuple item: "{:s}" formatting of a non-str object fails on Python 3
    # unless the object defines ``__format__``.
    with open(checkpoint_filepath, "a") as cf:
        for fp, descr in m:
            cf.write("{:s},{:s}\n".format(fp, descr.uuid()))
            # Flush per record so progress is kept if the run is interrupted.
            cf.flush()

    # Output valid file and invalid file dictionaries as pickle
    log.info("Writing valid filepaths map")
    with open("file_map.valid.pickle", "wb") as f:
        cPickle.dump(valid_file_paths, f)
    log.info("Writing invalid filepaths map")
    with open("file_map.invalid.pickle", "wb") as f:
        cPickle.dump(invalid_file_paths, f)

    log.info("Done")
예제 #38
0
def main():
    """
    Train an ITQ functor's model on descriptors from a configured index.

    If the ``uuids_list_filepath`` configuration value names an existing
    file, only the descriptors whose UUIDs are listed in it (one per line;
    blank lines are skipped) are used. Otherwise the whole configured index
    is used, and a warning is logged if a list path was configured but is
    not an existing file.
    """
    description = """
    Tool for training the ITQ functor algorithm's model on descriptors in an
    index.

    By default, we use all descriptors in the configured index
    (``uuids_list_filepath`` is not given a value).

    The ``uuids_list_filepath`` configuration property is optional and should
    be used to specify a sub-set of descriptors in the configured index to
    train on. This only works if the stored descriptors' UUID is a type of
    string.
    """
    args, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    uuids_list_filepath = config['uuids_list_filepath']

    log.info("Initializing ITQ functor")
    #: :type: smqtk.algorithms.nn_index.lsh.functors.itq.ItqFunctor
    functor = ItqFunctor.from_config(config['itq_config'])

    log.info("Initializing DescriptorIndex [type=%s]",
             config['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['descriptor_index'],
        get_descriptor_index_impls(),
    )

    if uuids_list_filepath and os.path.isfile(uuids_list_filepath):

        def uuids_iter():
            with open(uuids_list_filepath) as f:
                for line in f:
                    uid = line.strip()
                    # Skip blank lines instead of requesting an empty UUID.
                    if uid:
                        yield uid

        log.info("Loading UUIDs list from file: %s", uuids_list_filepath)
        d_iter = descriptor_index.get_many_descriptors(uuids_iter())
    else:
        if uuids_list_filepath:
            # Make the fallback to the whole index visible instead of silent.
            log.warning("UUIDs list file not found, using whole index: %s",
                        uuids_list_filepath)
        log.info("Using UUIDs from loaded DescriptorIndex (count=%d)",
                 len(descriptor_index))
        d_iter = descriptor_index

    log.info("Fitting ITQ model")
    functor.fit(d_iter)
    log.info("Done")
예제 #39
0
def main():
    """
    Add files found in Girder to a configured data set.

    Girder folder, item and file IDs are collected from the command line and
    from the optional ID list files (one ID per line; blank lines are
    skipped). Elements produced by ``find_girder_files`` are added to the
    configured ``data_set`` plugin. When ``dataset_insert_batch_size`` is
    set, they are added in batches of that size. Otherwise they are all
    added once iteration finishes.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    api_root = config['tool']['girder_api_root']
    api_key = config['tool']['api_key']
    api_query_batch = config['tool']['api_query_batch']
    insert_batch_size = config['tool']['dataset_insert_batch_size']

    def read_ids(filepath):
        """Return the stripped, non-blank lines of ``filepath``."""
        with open(filepath) as f:
            return [line.strip() for line in f if line.strip()]

    # Collect N folder/item/file references on CL and any files referenced.
    # Copy the CLI lists so extending them does not mutate ``args``. An
    # omitted option may parse as None (e.g. argparse "append" default), so
    # treat None as empty.
    #: :type: list[str]
    ids_folder = list(args.folder or [])
    #: :type: list[str]
    ids_item = list(args.item or [])
    #: :type: list[str]
    ids_file = list(args.file or [])

    if args.folder_list:
        ids_folder.extend(read_ids(args.folder_list))
    if args.item_list:
        ids_item.extend(read_ids(args.item_list))
    if args.file_list:
        ids_file.extend(read_ids(args.file_list))

    #: :type: smqtk.representation.DataSet
    data_set = plugin.from_plugin_config(config['plugins']['data_set'],
                                         get_data_set_impls())

    batch = collections.deque()
    rps = [0] * 7
    for e in find_girder_files(api_root, ids_folder, ids_item, ids_file,
                               api_key, api_query_batch):
        batch.append(e)
        if insert_batch_size and len(batch) >= insert_batch_size:
            data_set.add_data(*batch)
            batch.clear()
        bin_utils.report_progress(log.info, rps, 1.0)

    # Add whatever did not fill a complete batch.
    if batch:
        data_set.add_data(*batch)

    log.info('Done')
예제 #40
0
    def test_from_config(self):
        # The 'type' key selects which implementation section is used. Other
        # implementation sections and a non-implementation key sit alongside
        # it; only the 'DummyAlgo1' values are asserted below.
        dummy2_section = {
            'child': {'foo': -1, 'bar': 'some other value'},
            'alpha': 1.0,
            'beta': 'euclidean',
        }
        test_config = {
            'type': 'DummyAlgo1',
            'DummyAlgo1': {'foo': 256, 'bar': 'Some string value'},
            'DummyAlgo2': dummy2_section,
            'notAnImpl': {},
        }

        #: :type: DummyAlgo1
        inst = from_plugin_config(test_config, dummy_getter())
        self.assertIsInstance(inst, DummyAlgo1)
        self.assertEqual(inst.foo, 256)
        self.assertEqual(inst.bar, 'Some string value')
예제 #41
0
def main():
    """
    Add files found in Girder to a configured data set.

    Girder folder, item and file IDs are collected from the command line and
    from the optional ID list files (one ID per line; blank lines are
    skipped). Elements produced by ``find_girder_files`` are added to the
    configured ``data_set`` plugin. When ``dataset_insert_batch_size`` is
    set, they are added in batches of that size. Otherwise they are all
    added once iteration finishes.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    api_root = config['tool']['girder_api_root']
    api_key = config['tool']['api_key']
    api_query_batch = config['tool']['api_query_batch']
    insert_batch_size = config['tool']['dataset_insert_batch_size']

    def read_ids(filepath):
        """Return the stripped, non-blank lines of ``filepath``."""
        with open(filepath) as f:
            return [line.strip() for line in f if line.strip()]

    # Collect N folder/item/file references on CL and any files referenced.
    # Copy the CLI lists so extending them does not mutate ``args``. An
    # omitted option may parse as None (e.g. argparse "append" default), so
    # treat None as empty.
    #: :type: list[str]
    ids_folder = list(args.folder or [])
    #: :type: list[str]
    ids_item = list(args.item or [])
    #: :type: list[str]
    ids_file = list(args.file or [])

    if args.folder_list:
        ids_folder.extend(read_ids(args.folder_list))
    if args.item_list:
        ids_item.extend(read_ids(args.item_list))
    if args.file_list:
        ids_file.extend(read_ids(args.file_list))

    #: :type: smqtk.representation.DataSet
    data_set = plugin.from_plugin_config(config['plugins']['data_set'],
                                         get_data_set_impls())

    batch = collections.deque()
    rps = [0] * 7
    for e in find_girder_files(api_root, ids_folder, ids_item, ids_file,
                               api_key, api_query_batch):
        batch.append(e)
        if insert_batch_size and len(batch) >= insert_batch_size:
            data_set.add_data(*batch)
            batch.clear()
        bin_utils.report_progress(log.info, rps, 1.0)

    # Add whatever did not fill a complete batch.
    if batch:
        data_set.add_data(*batch)

    log.info('Done')
예제 #42
0
파일: train_itq.py 프로젝트: dhandeo/SMQTK
def main():
    """
    Train an ITQ functor model on descriptors from the configured index.

    When ``uuids_list_filepath`` names an existing file, only descriptors
    whose UUIDs are listed in it (one per line) are used. Otherwise every
    descriptor in the index is used.
    """
    description = """
    Tool for training the ITQ functor algorithm's model on descriptors in an
    index.

    By default, we use all descriptors in the configured index
    (``uuids_list_filepath`` is not given a value).

    The ``uuids_list_filepath`` configuration property is optional and should
    be used to specify a sub-set of descriptors in the configured index to
    train on. This only works if the stored descriptors' UUID is a type of
    string.
    """
    _, config = bin_utils.utility_main_helper(default_config, description)
    log = logging.getLogger(__name__)

    uuids_list_filepath = config['uuids_list_filepath']

    log.info("Initializing ITQ functor")
    #: :type: smqtk.algorithms.nn_index.lsh.functors.itq.ItqFunctor
    functor = ItqFunctor.from_config(config['itq_config'])

    log.info("Initializing DescriptorIndex [type=%s]",
             config['descriptor_index']['type'])
    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['descriptor_index'],
        get_descriptor_index_impls(),
    )

    if uuids_list_filepath and not os.path.isfile(uuids_list_filepath):
        # A configured path that does not exist falls back to the whole
        # index. Make that visible, since it likely indicates a typo.
        log.warning("UUIDs list file not found, training on all descriptors "
                    "in the index instead: %s", uuids_list_filepath)

    if uuids_list_filepath and os.path.isfile(uuids_list_filepath):
        def uuids_iter():
            # One UUID per line. Blank lines are skipped so they are not
            # looked up as an empty-string UUID.
            with open(uuids_list_filepath) as f:
                for line in f:
                    uuid = line.strip()
                    if uuid:
                        yield uuid
        log.info("Loading UUIDs list from file: %s", uuids_list_filepath)
        d_iter = descriptor_index.get_many_descriptors(uuids_iter())
    else:
        log.info("Using UUIDs from loaded DescriptorIndex (count=%d)",
                 len(descriptor_index))
        d_iter = descriptor_index

    log.info("Fitting ITQ model")
    functor.fit(d_iter)
    log.info("Done")
예제 #43
0
def main():
    """
    Add local files to the data set described by a JSON configuration.

    ``args.input_files`` are explicit file paths or shell-style glob patterns
    (``~`` is expanded). A configuration file (``args.config``) is required;
    its contents are merged over ``default_config()``. Exits with status 1
    if the configuration path is invalid or none was given.
    """
    import sys

    parser = cli_parser()
    args = parser.parse_args()

    bin_utils.initialize_logging(logging.getLogger(),
                                 logging.INFO - (10 * args.verbose))
    log = logging.getLogger("main")

    # Merge loaded config with default
    config_loaded = False
    config = default_config()
    if args.config:
        if not osp.isfile(args.config):
            log.error("Configuration file path not valid.")
            sys.exit(1)
        with open(args.config, 'r') as f:
            config.update(json.load(f))
        config_loaded = True

    # output configuration dictionary when asked for.
    bin_utils.output_config(args.output_config, config, log, True)

    if not config_loaded:
        log.error("No configuration provided")
        sys.exit(1)

    log.debug("Script arguments:\n%s", args)

    def iter_input_elements():
        for f in args.input_files:
            f = osp.expanduser(f)
            if osp.isfile(f):
                yield DataFileElement(f)
            else:
                log.debug("Expanding glob: %s", f)
                matches = glob.glob(f)
                if not matches:
                    # Surface typos instead of silently adding nothing.
                    log.warning("No files matched: %s", f)
                for g in matches:
                    yield DataFileElement(g)

    log.info("Adding elements to data set")
    #: :type: smqtk.representation.DataSet
    ds = plugin.from_plugin_config(config['data_set'], get_data_set_impls())
    ds.add_data(*iter_input_elements())
예제 #44
0
def main():
    """
    Add local files to the data set described by a JSON configuration.

    ``args.input_files`` are explicit file paths or shell-style glob patterns
    (``~`` is expanded). A configuration file (``args.config``) is required;
    its contents are merged over ``default_config()``. Exits with status 1
    if the configuration path is invalid or none was given.
    """
    import sys

    parser = cli_parser()
    args = parser.parse_args()

    bin_utils.initialize_logging(logging.getLogger(),
                                 logging.INFO - (10 * args.verbose))
    log = logging.getLogger("main")

    # Merge loaded config with default
    config_loaded = False
    config = default_config()
    if args.config:
        if not osp.isfile(args.config):
            log.error("Configuration file path not valid.")
            sys.exit(1)
        with open(args.config, 'r') as f:
            config.update(json.load(f))
        config_loaded = True

    # output configuration dictionary when asked for.
    bin_utils.output_config(args.output_config, config, log, True)

    if not config_loaded:
        log.error("No configuration provided")
        sys.exit(1)

    log.debug("Script arguments:\n%s", args)

    def iter_input_elements():
        for f in args.input_files:
            f = osp.expanduser(f)
            if osp.isfile(f):
                yield DataFileElement(f)
            else:
                log.debug("Expanding glob: %s", f)
                matches = glob.glob(f)
                if not matches:
                    # Surface typos instead of silently adding nothing.
                    log.warning("No files matched: %s", f)
                for g in matches:
                    yield DataFileElement(g)

    log.info("Adding elements to data set")
    #: :type: smqtk.representation.DataSet
    ds = plugin.from_plugin_config(config['data_set'], get_data_set_impls())
    ds.add_data(*iter_input_elements())
예제 #45
0
def train_classifier_iqr(config, iqr_state_fp):
    """
    Train the configured supervised classifier from saved IQR session state.

    :param config: Configuration dictionary. ``config['classifier']`` is the
        plugin configuration of a ``SupervisedClassifier`` implementation.
    :type config: dict

    :param iqr_state_fp: Path to a file holding serialized IqrSession state
        bytes.
    :type iqr_state_fp: str
    """
    #: :type: smqtk.algorithms.SupervisedClassifier
    classifier = from_plugin_config(
        config['classifier'],
        get_classifier_impls(sub_interface=SupervisedClassifier),
    )

    # Restore the saved state into a fresh IqrSession whose descriptors are
    # kept in memory.
    with open(iqr_state_fp, 'rb') as state_file:
        raw_state = state_file.read()
    memory_factory = DescriptorElementFactory(DescriptorMemoryElement, {})
    session = IqrSession()
    session.set_state_bytes(raw_state.strip(), memory_factory)

    # Each class's examples are the union of descriptors adjudicated within
    # the session and those supplied externally.
    examples = {
        'positive': (session.positive_descriptors |
                     session.external_positive_descriptors),
        'negative': (session.negative_descriptors |
                     session.external_negative_descriptors),
    }
    classifier.train(class_examples=examples)
예제 #46
0
def main():
    """
    Export labeled descriptors from the configured index in libSVM format.

    Reads ``uuid,label`` rows from the CSV file ``args.f``, fetches each
    UUID's descriptor from the configured ``DescriptorIndex`` and writes one
    line per row to ``args.o``::

        <int_label> <index>:<value> <index>:<value> ...

    Labels are mapped to integers starting at 1 in order of first
    appearance. Feature indices are 1-based, and zero-valued features are
    omitted. The label-to-integer mapping is logged at the end.
    """
    args = cli_parser().parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #: :type: smqtk.representation.DescriptorIndex
    descriptor_index = plugin.from_plugin_config(
        config['plugins']['descriptor_index'],
        get_descriptor_index_impls()
    )

    labels_filepath = args.f
    output_filepath = args.o

    # Read all (uuid, label) pairs up front and close the input file. Blank
    # rows are skipped, and only the first two columns are kept, so neither
    # breaks the pairing below. (Previously an empty input or extra column
    # crashed the ``zip(*rows)`` unpack, and the file was never closed.)
    with open(labels_filepath) as label_file:
        rows = [(row[0], row[1]) for row in csv.reader(label_file) if row]
    uuids = [uuid for uuid, _ in rows]
    labels = [label for _, label in rows]

    # Run through labeled UUIDs, getting each descriptor from the configured
    # index, applying the appropriate integer label and then writing the
    # formatted line out to the output file.
    label2int = {}
    with open(output_filepath, 'w') as ofile:
        if not rows:
            log.warning("No labeled UUIDs found in input file: %s",
                        labels_filepath)
        else:
            log.info("Scanning input descriptors and labels")
            descriptors = descriptor_index.get_many_descriptors(uuids)
            for i, (label, descr) in enumerate(zip(labels, descriptors)):
                log.debug("%d %s", i, descr.uuid())
                if label not in label2int:
                    # Next unused integer, starting from 1.
                    label2int[label] = len(label2int) + 1
                features = ' '.join("%d:%.12f" % (j + 1, v)
                                    for j, v in enumerate(descr.vector())
                                    if v != 0.0)
                ofile.write("%d %s\n" % (label2int[label], features))

    log.info("Integer label association:")
    for i, l in sorted((i, l) for l, i in six.iteritems(label2int)):
        log.info('\t%d :: %s', i, l)
예제 #47
0
def train_classifier_iqr(config, iqr_state_fp):
    """
    Train the configured supervised classifier from an IQR state zip file.

    The state file must be a zip archive with exactly one member. That member
    is JSON for an object with exactly the keys ``"pos"`` and ``"neg"``, each
    a list of descriptor vectors. Duplicate vectors within each list are
    dropped before training.

    :param config: Configuration dictionary. ``config['classifier']`` is a
        classifier plugin configuration.
    :type config: dict

    :param iqr_state_fp: Path to the IQR state zip file.
    :type iqr_state_fp: str

    :raises RuntimeError: The configured classifier is not a
        ``SupervisedClassifier``, or the state file is malformed.
    """
    log = logging.getLogger(__name__)

    #: :type: smqtk.algorithms.SupervisedClassifier
    classifier = from_plugin_config(config['classifier'],
                                    get_classifier_impls())

    if not isinstance(classifier, SupervisedClassifier):
        raise RuntimeError("Configured classifier must be of the "
                           "SupervisedClassifier type in order to train.")

    # Get pos/neg descriptors out of the IQR state zip. Let ZipFile open the
    # path itself, in binary mode; the previous text-mode ``open(..., 'r')``
    # fails under Python 3 and leaked the handle. The archive is closed once
    # its single member has been read.
    with zipfile.ZipFile(iqr_state_fp) as z:
        names = z.namelist()
        if len(names) != 1:
            raise RuntimeError("Invalid IqrState file!")
        iqrs = json.loads(z.read(names[0]).decode('utf-8'))
    if len(iqrs) != 2 or 'pos' not in iqrs or 'neg' not in iqrs:
        raise RuntimeError("Invalid IqrState file!")

    log.info("Loading pos/neg descriptors")
    #: :type: list[smqtk.representation.DescriptorElement]
    pos = []
    #: :type: list[smqtk.representation.DescriptorElement]
    neg = []
    # UUIDs keep counting across both sets so no two training descriptors
    # share an identifier.
    i = 0
    for vectors, dest in ((iqrs['pos'], pos), (iqrs['neg'], neg)):
        # Tuples are hashable, so the set removes duplicate vectors.
        for v in set(map(tuple, vectors)):
            d = DescriptorMemoryElement('train', i)
            d.set_vector(numpy.array(v))
            dest.append(d)
            i += 1
    log.info('    positive -> %d', len(pos))
    log.info('    negative -> %d', len(neg))

    classifier.train(positive=pos, negative=neg)
예제 #48
0
    def from_config(cls, config_dict, merge_default=True):
        """
        Build a ClassifierCollection from a JSON-compliant configuration.

        Every top-level key except ``cls.EXAMPLE_KEY`` is a classifier label.
        Its value is a classifier plugin configuration, instantiated via
        ``plugin.from_plugin_config``.

        Only call this through ``super`` when an instance of this exact class
        is wanted.

        :param config_dict: JSON compliant dictionary encapsulating
            a configuration.
        :type config_dict: dict

        :param merge_default: When True, ``config_dict`` is merged on top of
            ``get_default_config()`` before use.
        :type merge_default: bool

        :return: Constructed instance from the provided config.
        :rtype: ClassifierCollection

        """
        if merge_default:
            config_dict = merge_dict(cls.get_default_config(), config_dict)

        # Instantiate one classifier per label, skipping the example entry.
        # ``config_dict`` itself is only read here, never modified.
        classifier_map = {
            label: plugin.from_plugin_config(classifier_config,
                                             get_classifier_impls())
            for label, classifier_config in config_dict.items()
            if label != cls.EXAMPLE_KEY
        }

        # merge_default=False so the default "example" entry is not merged
        # back in by the parent implementation.
        return super(ClassifierCollection, cls).from_config(
            {'classifiers': classifier_map},
            merge_default=False
        )
예제 #49
0
    def initialize(self):
        """
        Initialize working index based on currently set positive exemplar data.

        Nearest neighbors of each positive descriptor are added to the
        working index. This covers the example positives in
        ``ex_pos_descriptors`` and those positively adjudicated during this
        session in ``positive_descriptors``. Descriptors whose UUID is
        already in ``_wi_init_seeds`` are not queried again, so repeated
        calls only query new positives. The working index is not cleared
        first. A new relevancy index is then built over the working index.

        :raises RuntimeError: There are no positive example descriptors in this
            session to use as a basis for querying.

        """
        if len(self.ex_pos_descriptors) + \
                len(self.positive_descriptors) <= 0:
            raise RuntimeError("No positive descriptors to query the neighbor "
                               "index with.")
        # Not clearing index because this step is intended to be additive

        # Build up the working index from both positive sources, in order.
        # ``.values()`` replaces the Python 2-only ``.itervalues()``.
        for source in (self.ex_pos_descriptors.values(),
                       self.positive_descriptors):
            for p in source:
                if p.uuid() not in self._wi_init_seeds:
                    self._log.info("Querying neighbors to: %s", p)
                    self.working_index.add_many_descriptors(
                        self.nn_index.nn(p, n=self.pos_seed_neighbors)[0]
                    )
                    self._wi_init_seeds.add(p.uuid())

        # Make new relevancy index
        self._log.info("Creating new relevancy index over working index.")
        #: :type: smqtk.algorithms.relevancy_index.RelevancyIndex
        self.rel_index = plugin.from_plugin_config(self.rel_index_config,
                                                   get_relevancy_index_impls)
        self.rel_index.build_index(self.working_index.iterdescriptors())
예제 #50
0
def train_classifier_iqr(config, iqr_state_fp):
    """
    Train the configured supervised classifier from an IQR state zip file.

    The state file must be a zip archive with exactly one member. That member
    is JSON for an object with exactly the keys ``"pos"`` and ``"neg"``, each
    a list of descriptor vectors. Duplicate vectors within each list are
    dropped before training.

    :param config: Configuration dictionary. ``config['classifier']`` is a
        classifier plugin configuration.
    :type config: dict

    :param iqr_state_fp: Path to the IQR state zip file.
    :type iqr_state_fp: str

    :raises RuntimeError: The configured classifier is not a
        ``SupervisedClassifier``, or the state file is malformed.
    """
    log = logging.getLogger(__name__)

    #: :type: smqtk.algorithms.SupervisedClassifier
    classifier = from_plugin_config(config['classifier'], get_classifier_impls)

    if not isinstance(classifier, SupervisedClassifier):
        raise RuntimeError("Configured classifier must be of the "
                           "SupervisedClassifier type in order to train.")

    # Get pos/neg descriptors out of the IQR state zip. Let ZipFile open the
    # path itself, in binary mode; the previous text-mode ``open(..., 'r')``
    # fails under Python 3 and leaked the handle. The archive is closed once
    # its single member has been read.
    with zipfile.ZipFile(iqr_state_fp) as z:
        names = z.namelist()
        if len(names) != 1:
            raise RuntimeError("Invalid IqrState file!")
        iqrs = json.loads(z.read(names[0]).decode('utf-8'))
    if len(iqrs) != 2 or 'pos' not in iqrs or 'neg' not in iqrs:
        raise RuntimeError("Invalid IqrState file!")

    log.info("Loading pos/neg descriptors")
    #: :type: list[smqtk.representation.DescriptorElement]
    pos = []
    #: :type: list[smqtk.representation.DescriptorElement]
    neg = []
    # UUIDs keep counting across both sets so no two training descriptors
    # share an identifier.
    i = 0
    for vectors, dest in ((iqrs['pos'], pos), (iqrs['neg'], neg)):
        # Tuples are hashable, so the set removes duplicate vectors.
        for v in set(map(tuple, vectors)):
            d = DescriptorMemoryElement('train', i)
            d.set_vector(numpy.array(v))
            dest.append(d)
            i += 1
    log.info('    positive -> %d', len(pos))
    log.info('    negative -> %d', len(neg))

    classifier.train({'positive': pos}, negatives=neg)
예제 #51
0
def main():
    """
    Add local system files to the data set described by a JSON config.

    Positional arguments are explicit file paths or shell-style glob
    patterns (``~`` is expanded). The configuration file given with
    ``-c``/``--config`` is required. The script exits with status 1 when it
    is missing or does not point to a file.
    """
    import sys

    usage = "%prog [options] GLOB [ GLOB [ ... ] ]"
    description = "Add a set of local system files to a data set via " \
                  "explicit paths or shell-style glob strings."

    parser = bin_utils.SMQTKOptParser(usage, description=description)
    parser.add_option('-c', '--config',
                      help="Path to the JSON configuration file")
    parser.add_option('--output-config',
                      help="Optional path to output a default configuration "
                           "file to. This output file should be modified and "
                           "used for this executable.")
    parser.add_option('-v', '--verbose', action='store_true', default=False,
                      help='Add debug messages to output logging.')
    opts, args = parser.parse_args()

    bin_utils.initialize_logging(logging.getLogger(),
                                 logging.INFO - (10*opts.verbose))
    log = logging.getLogger("main")

    # output configuration dictionary when asked for.
    bin_utils.output_config(opts.output_config, default_config(), log)

    # Fail with a clear message instead of a TypeError/IOError from open().
    if not opts.config:
        log.error("No configuration provided")
        sys.exit(1)
    if not osp.isfile(opts.config):
        log.error("Configuration file path not valid.")
        sys.exit(1)

    with open(opts.config, 'r') as f:
        config = json.load(f)

    #: :type: smqtk.representation.DataSet
    ds = plugin.from_plugin_config(config['data_set'], get_data_set_impls)
    log.debug("Script arguments:\n%s", args)

    def ingest_file(fp):
        ds.add_data(DataFileElement(fp))

    for f in args:
        f = osp.expanduser(f)
        if osp.isfile(f):
            ingest_file(f)
        else:
            log.debug("Expanding glob: %s", f)
            matches = glob.glob(f)
            if not matches:
                # Surface typos instead of silently adding nothing.
                log.warning("No files matched: %s", f)
            for g in matches:
                ingest_file(g)
예제 #52
0
    def initialize(self):
        """
        Initialize working index based on currently set positive exemplar data.

        Nearest neighbors of each positive descriptor are added to the
        working index. This covers the example positives in
        ``ex_pos_descriptors`` and those positively adjudicated during this
        session in ``positive_descriptors``. Descriptors whose UUID is
        already in ``_wi_init_seeds`` are not queried again, so repeated
        calls only query new positives. The working index is not cleared
        first. A new relevancy index is then built over the working index.

        :raises RuntimeError: There are no positive example descriptors in this
            session to use as a basis for querying.

        """
        if len(self.ex_pos_descriptors) + \
                len(self.positive_descriptors) <= 0:
            raise RuntimeError("No positive descriptors to query the neighbor "
                               "index with.")
        # Not clearing index because this step is intended to be additive

        # Build up the working index from both positive sources, in order.
        # ``.values()`` replaces the Python 2-only ``.itervalues()``.
        for source in (self.ex_pos_descriptors.values(),
                       self.positive_descriptors):
            for p in source:
                if p.uuid() not in self._wi_init_seeds:
                    self._log.info("Querying neighbors to: %s", p)
                    self.working_index.add_many_descriptors(
                        self.nn_index.nn(p, n=self.pos_seed_neighbors)[0])
                    self._wi_init_seeds.add(p.uuid())

        # Make new relevancy index
        self._log.info("Creating new relevancy index over working index.")
        #: :type: smqtk.algorithms.relevancy_index.RelevancyIndex
        self.rel_index = plugin.from_plugin_config(self.rel_index_config,
                                                   get_relevancy_index_impls)
        self.rel_index.build_index(self.working_index.iterdescriptors())
예제 #53
0
def main():
    """
    Add local files to the configured data set.

    ``args.input_files`` are explicit file paths or shell-style glob patterns
    (``~`` is expanded). Patterns that match nothing are reported with a
    warning.
    """
    parser = cli_parser()
    args = parser.parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    log.debug("Script arguments:\n%s", args)

    def iter_input_elements():
        for f in args.input_files:
            f = osp.expanduser(f)
            if osp.isfile(f):
                yield DataFileElement(f)
            else:
                log.debug("Expanding glob: %s", f)
                matches = glob.glob(f)
                if not matches:
                    # Surface typos instead of silently adding nothing.
                    log.warning("No files matched: %s", f)
                for g in matches:
                    yield DataFileElement(g)

    log.info("Adding elements to data set")
    #: :type: smqtk.representation.DataSet
    ds = plugin.from_plugin_config(config['data_set'], get_data_set_impls())
    ds.add_data(*iter_input_elements())
예제 #54
0
def main():
    """
    Add local files to the configured data set.

    ``args.input_files`` are explicit file paths or shell-style glob patterns
    (``~`` is expanded). Patterns that match nothing are reported with a
    warning.
    """
    parser = cli_parser()
    args = parser.parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    log.debug("Script arguments:\n%s", args)

    def iter_input_elements():
        for f in args.input_files:
            f = osp.expanduser(f)
            if osp.isfile(f):
                yield DataFileElement(f)
            else:
                log.debug("Expanding glob: %s", f)
                matches = glob.glob(f)
                if not matches:
                    # Surface typos instead of silently adding nothing.
                    log.warning("No files matched: %s", f)
                for g in matches:
                    yield DataFileElement(g)

    log.info("Adding elements to data set")
    #: :type: smqtk.representation.DataSet
    ds = plugin.from_plugin_config(config['data_set'], get_data_set_impls())
    ds.add_data(*iter_input_elements())
예제 #55
0
def main():
    """
    Compute a descriptor vector for one input file and output it.

    The vector is saved with ``numpy.save`` to ``args.output_filepath`` when
    given, otherwise printed to stdout as space-separated ``%15f`` values.
    Exits with status 1 when the input path is missing or not a file, or
    when no vector could be generated.
    """
    import sys

    parser = cli_parser()
    args = parser.parse_args()
    config = bin_utils.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    output_filepath = args.output_filepath
    overwrite = args.overwrite

    if not args.input_file:
        log.error("Failed to provide an input file path")
        sys.exit(1)
    elif not os.path.isfile(args.input_file):
        log.error("Given path does not point to a file.")
        sys.exit(1)

    input_filepath = args.input_file
    data_element = DataFileElement(input_filepath)

    factory = DescriptorElementFactory.from_config(
        config['descriptor_factory'])
    #: :type: smqtk.algorithms.descriptor_generator.DescriptorGenerator
    cd = plugin.from_plugin_config(config['content_descriptor'],
                                   get_descriptor_generator_impls())
    descr_elem = cd.compute_descriptor(data_element, factory, overwrite)
    vec = descr_elem.vector()

    if vec is None:
        # Previously execution continued here and crashed in numpy.save or
        # when iterating over None; stop with an error status instead.
        log.error("Failed to generate a descriptor vector for the input data!")
        sys.exit(1)

    if output_filepath:
        numpy.save(output_filepath, vec)
    else:
        # Construct string, because numpy
        # noinspection PyTypeChecker
        print(' '.join('%15f' % f for f in vec))
예제 #56
0
파일: iqr_session.py 프로젝트: mrG7/SMQTK
    def initialize(self):
        """
        Initialize working index based on currently set positive exemplar data.

        The working index is cleared first. It is then rebuilt from the
        nearest neighbors of every example positive in
        ``ex_pos_descriptors`` and every positively adjudicated descriptor in
        ``positive_descriptors``. Finally a new relevancy index is built over
        the rebuilt working index.

        :raises RuntimeError: There are no positive example descriptors in this
            session to use as a basis for querying.

        """
        if len(self.ex_pos_descriptors) + \
                len(self.positive_descriptors) <= 0:
            raise RuntimeError("No positive descriptors to query the neighbor "
                               "index with.")
        # Clear the current working index so we can put different things in it
        self._log.info("Clearing working index")
        self.working_index.clear()

        # Build up the new working index from both positive sources, in
        # order. ``.values()`` replaces the Python 2-only ``.itervalues()``.
        for source in (self.ex_pos_descriptors.values(),
                       self.positive_descriptors):
            for p in source:
                self._log.info("Querying neighbors to: %s", p)
                self.working_index.add_many_descriptors(
                    self.nn_index.nn(p, n=self.pos_seed_neighbors)[0]
                )

        # Make new relevancy index
        self._log.info("Creating new relevancy index over working index.")
        #: :type: smqtk.algorithms.relevancy_index.RelevancyIndex
        self.rel_index = plugin.from_plugin_config(self.rel_index_config,
                                                   get_relevancy_index_impls)
        self.rel_index.build_index(self.working_index.iterdescriptors())