Example #1
    def from_config(cls, config_dict, merge_default=True):
        """
        Instantiate a new instance of this class given a JSON-compliant
        configuration dictionary encapsulating initialization arguments.

        This method should not be called via super unless an instance of the
        class is desired.

        :param config_dict: JSON compliant dictionary encapsulating
            a configuration.
        :type config_dict: dict

        :param merge_default: Merge the given configuration on top of the
            default provided by ``get_default_config``.
        :type merge_default: bool

        :return: Constructed instance from the provided config.
        :rtype: MRPTNearestNeighborsIndex

        """
        if merge_default:
            cfg = cls.get_default_config()
            merge_dict(cfg, config_dict)
        else:
            cfg = config_dict

        cfg['descriptor_set'] = \
            from_config_dict(cfg['descriptor_set'],
                             DescriptorSet.get_impls())

        return super(MRPTNearestNeighborsIndex, cls).from_config(cfg, False)
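A minimal usage sketch for the method above (hedged: the plugin-block layout
with a "type" selector key follows other examples here, e.g. example #12;
"MemoryDescriptorSet" is an assumed implementation name taken from example
#11):

    # Start from the class defaults, select a DescriptorSet implementation
    # via its "type" key, then construct (merge_default=True by default).
    cfg = MRPTNearestNeighborsIndex.get_default_config()
    cfg['descriptor_set']['type'] = 'MemoryDescriptorSet'  # assumed impl name
    index = MRPTNearestNeighborsIndex.from_config(cfg)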
Example #2
    def get_default_config(cls):
        """
        Generate and return a default configuration dictionary for this class.
        This will be primarily used for generating what the configuration
        dictionary would look like for this class without instantiating it.

        :return: Default configuration dictionary for the class.
        :rtype: dict

        """
        c = super(NearestNeighborServiceServer, cls).get_default_config()
        merge_dict(
            c, {
                "descriptor_factory":
                DescriptorElementFactory.get_default_config(),
                "descriptor_generator":
                make_default_config(DescriptorGenerator.get_impls()),
                "nn_index":
                make_default_config(NearestNeighborsIndex.get_impls()),
                "descriptor_set":
                make_default_config(DescriptorSet.get_impls()),
                "update_descriptor_set":
                False,
            })
        return c
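A hedged sketch of the dictionary shape this returns. The nested
'"type": None' defaults are an assumption consistent with example #21, which
checks ``c['optional_data_set']['type'] is None`` for an unselected plugin:

    c = NearestNeighborServiceServer.get_default_config()
    # Each make_default_config block holds a "type" selector plus one
    # sub-dict of constructor defaults per available implementation.
    assert c['update_descriptor_set'] is False
    assert c['nn_index']['type'] is None  # assumed default: nothing selected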
Example #3
def default_config():
    return {
        'plugins': {
            'classifier':
            make_default_config(SupervisedClassifier.get_impls()),
            'classification_factory':
            ClassificationElementFactory.get_default_config(),
            'descriptor_set':
            make_default_config(DescriptorSet.get_impls())
        },
        'utility': {
            'train': False,
            'csv_filepath': 'CHANGEME :: PATH :: a csv file',
            'output_plot_pr': None,
            'output_plot_roc': None,
            'output_plot_confusion_matrix': None,
            'output_uuid_confusion_matrix': None,
            'curve_confidence_interval': False,
            'curve_confidence_interval_alpha': 0.4,
        },
        "parallelism": {
            "descriptor_fetch_cores": 4,
            # DEPRECATED
            "classification_cores": None,
        },
    }
Example #4
def default_config():

    # Trick for mixing in our Configurable class API on top of scikit-learn's
    # MiniBatchKMeans class in order to introspect construction parameters.
    # We never construct this class so we do not need to implement "pure
    # virtual" instance methods.
    # noinspection PyAbstractClass
    class MBKTemp(MiniBatchKMeans, Configurable):
        pass

    c: Dict[str, Any] = {
        "minibatch_kmeans_params": MBKTemp.get_default_config(),
        "descriptor_set": make_default_config(DescriptorSet.get_impls()),
        # Number of descriptors to run an initial fit with. This brings the
        # advantage of choosing a best initialization point from multiple.
        "initial_fit_size": 0,
        # Path to save generated KMeans centroids
        "centroids_output_filepath_npy": "centroids.npy"
    }

    # Change/Remove some KMeans params for more appropriate defaults
    del c['minibatch_kmeans_params']['compute_labels']
    del c['minibatch_kmeans_params']['verbose']
    c['minibatch_kmeans_params']['random_state'] = 0

    return c
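A hedged illustration of what the mixin trick yields: ``get_default_config``
introspects the MiniBatchKMeans constructor, so the resulting keys track the
installed scikit-learn version (parameter names below are typical examples,
not guaranteed):

    c = default_config()
    # Typical introspected keys (vary by scikit-learn version), e.g.:
    # ['batch_size', 'init', 'max_iter', 'n_clusters', 'random_state', ...]
    print(sorted(c['minibatch_kmeans_params']))
    assert c['minibatch_kmeans_params']['random_state'] == 0
    assert 'compute_labels' not in c['minibatch_kmeans_params']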
Example #5
def default_config():
    return {
        "plugins": {
            "supervised_classifier":
            make_default_config(SupervisedClassifier.get_impls()),
            "descriptor_set":
            make_default_config(DescriptorSet.get_impls()),
        },
        "cross_validation": {
            "truth_labels": None,
            "num_folds": 6,
            "random_seed": None,
            "classification_use_multiprocessing": True,
        },
        "pr_curves": {
            "enabled": True,
            "show": False,
            "output_directory": None,
            "file_prefix": None,
        },
        "roc_curves": {
            "enabled": True,
            "show": False,
            "output_directory": None,
            "file_prefix": None,
        },
    }
Example #6
def main():
    args = cli_parser().parse_args()
    config = utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    output_filepath = args.output_map
    if not output_filepath:
        raise ValueError("No path given for output map file (pickle).")

    #: :type: smqtk.representation.DescriptorSet
    descr_set = from_config_dict(config['descriptor_set'],
                                 DescriptorSet.get_impls())
    mbkm = MiniBatchKMeans(verbose=args.verbose,
                           compute_labels=False,
                           **config['minibatch_kmeans_params'])
    initial_fit_size = int(config['initial_fit_size'])

    d_classes = mb_kmeans_build_apply(descr_set, mbkm, initial_fit_size)

    log.info("Saving KMeans centroids to: %s",
             config['centroids_output_filepath_npy'])
    numpy.save(config['centroids_output_filepath_npy'], mbkm.cluster_centers_)

    log.info("Saving result classification map to: %s", output_filepath)
    safe_create_dir(os.path.dirname(output_filepath))
    with open(output_filepath, 'wb') as f:
        cPickle.dump(d_classes, f, -1)

    log.info("Done")
Example #7
    def get_default_config(cls):
        c = super(SmqtkClassifierService, cls).get_default_config()

        c[cls.CONFIG_ENABLE_CLASSIFIER_REMOVAL] = False

        # Static classifier configurations
        c[cls.CONFIG_CLASSIFIER_COLLECTION] = \
            ClassifierCollection.get_default_config()
        # Classification element factory for new classification results.
        c[cls.CONFIG_CLASSIFICATION_FACTORY] = \
            ClassificationElementFactory.get_default_config()
        # Descriptor generator for new content
        c[cls.CONFIG_DESCRIPTOR_GENERATOR] = make_default_config(
            DescriptorGenerator.get_impls()
        )
        # Descriptor factory for new content descriptors
        c[cls.CONFIG_DESCRIPTOR_FACTORY] = \
            DescriptorElementFactory.get_default_config()
        # Optional Descriptor set for "included" descriptors referenceable by
        # UID.
        c[cls.CONFIG_DESCRIPTOR_SET] = make_default_config(
            DescriptorSet.get_impls()
        )
        # from-IQR-state *supervised* classifier configuration
        c[cls.CONFIG_IQR_CLASSIFIER] = make_default_config(
            SupervisedClassifier.get_impls()
        )
        c[cls.CONFIG_IMMUTABLE_LABELS] = []

        return c
Example #8
    def get_default_config(cls):
        """
        Generate and return a default configuration dictionary for this class.
        This will be primarily used for generating what the configuration
        dictionary would look like for this class without instantiating it.

        By default, we observe what this class's constructor takes as
        arguments, turning those argument names into configuration dictionary
        keys. If any of those arguments have defaults, we will add those
        values into the configuration dictionary appropriately. The dictionary
        returned should only contain JSON compliant value types.

        It is not guaranteed that the configuration dictionary returned
        from this method is valid for construction of an instance of this
        class.

        :return: Default configuration dictionary for the class.
        :rtype: dict

        """
        default = super(MRPTNearestNeighborsIndex, cls).get_default_config()

        di_default = make_default_config(DescriptorSet.get_impls())
        default['descriptor_set'] = di_default

        return default
Example #9
def get_default_config():
    return {
        'plugins': {
            'descriptor_set': make_default_config(DescriptorSet.get_impls()),
            'nn_index': make_default_config(NearestNeighborsIndex.get_impls())
        }
    }
Example #10
def default_config():
    return {
        "descriptor_generator":
        make_default_config(DescriptorGenerator.get_impls()),
        "descriptor_factory":
        DescriptorElementFactory.get_default_config(),
        "descriptor_set":
        make_default_config(DescriptorSet.get_impls()),
        "optional_data_set":
        make_default_config(DataSet.get_impls())
    }
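Note the 'optional_data_set' block: leaving its "type" selector as None (the
assumed make_default_config default, and exactly what example #21's
run_file_list checks for) disables persisting the input data elements:

    c = default_config()
    # Assumed default: no concrete DataSet implementation selected, which
    # run_file_list (example #21) reads as "do not save loaded elements".
    assert c['optional_data_set']['type'] is None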
Example #11
    def __init__(self, json_config):
        super(SmqtkClassifierService, self).__init__(json_config)

        self.enable_classifier_removal = \
            bool(json_config[self.CONFIG_ENABLE_CLASSIFIER_REMOVAL])

        self.immutable_labels = set(json_config[self.CONFIG_IMMUTABLE_LABELS])

        # Convert configuration into SMQTK plugin instances.
        #   - Static classifier configurations.
        #       - Skip the example config key
        #   - Classification element factory
        #   - Descriptor generator
        #   - Descriptor element factory
        #   - from-IQR-state classifier configuration
        #       - There must at least be the default key defined for when no
        #         specific classifier type is specified at state POST.

        # Classification factory + classifier collection
        self.classification_factory = \
            ClassificationElementFactory.from_config(
                json_config[self.CONFIG_CLASSIFICATION_FACTORY]
            )
        #: :type: ClassifierCollection
        self.classifier_collection = ClassifierCollection.from_config(
            json_config[self.CONFIG_CLASSIFIER_COLLECTION]
        )

        # Descriptor generator + factory
        self.descriptor_factory = DescriptorElementFactory.from_config(
            json_config[self.CONFIG_DESCRIPTOR_FACTORY]
        )
        #: :type: smqtk.algorithms.DescriptorGenerator
        self.descriptor_gen = from_config_dict(
            json_config[self.CONFIG_DESCRIPTOR_GENERATOR],
            smqtk.algorithms.DescriptorGenerator.get_impls()
        )

        # Descriptor set bundled for classification-by-UID.
        try:
            self.descriptor_set = from_config_dict(
                json_config.get(self.CONFIG_DESCRIPTOR_SET, {}),
                DescriptorSet.get_impls()
            )
        except ValueError:
            # Default empty set.
            self.descriptor_set = MemoryDescriptorSet()

        # Classifier config for uploaded IQR states.
        self.iqr_state_classifier_config = \
            json_config[self.CONFIG_IQR_CLASSIFIER]

        self.add_routes()
Example #12
    def from_config(cls, config_dict, merge_default=True):
        """
        Instantiate a new instance of this class given a JSON-compliant
        configuration dictionary encapsulating initialization arguments.

        This method should not be called via super unless an instance of the
        class is desired.

        :param config_dict: JSON compliant dictionary encapsulating
            a configuration.
        :type config_dict: dict

        :param merge_default: Merge the given configuration on top of the
            default provided by ``get_default_config``.
        :type merge_default: bool

        :return: Constructed instance from the provided config.
        :rtype: LSHNearestNeighborIndex

        """
        # Controlling merge here so we can control known comment stripping from
        # default config.
        if merge_default:
            merged = cls.get_default_config()
            merge_dict(merged, config_dict)
        else:
            merged = config_dict

        merged['lsh_functor'] = \
            from_config_dict(merged['lsh_functor'], LshFunctor.get_impls())
        merged['descriptor_set'] = \
            from_config_dict(merged['descriptor_set'],
                             DescriptorSet.get_impls())

        # Hash index may be None for a default at-query-time linear indexing
        if merged['hash_index'] and merged['hash_index']['type']:
            merged['hash_index'] = \
                from_config_dict(merged['hash_index'],
                                 HashIndex.get_impls())
        else:
            cls.get_logger().debug("No HashIndex impl given. Passing "
                                   "``None``.")
            merged['hash_index'] = None

        # remove possible comment added by default generator
        if 'hash_index_comment' in merged:
            del merged['hash_index_comment']

        merged['hash2uuids_kvstore'] = \
            from_config_dict(merged['hash2uuids_kvstore'],
                             KeyValueStore.get_impls())

        return super(LSHNearestNeighborIndex, cls).from_config(merged, False)
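A hedged construction sketch for the branch above. Implementation names are
assumptions (ItqFunctor appears in example #19, MemoryDescriptorSet in
example #11, MemoryKeyValueStore is an assumed in-memory KeyValueStore
implementation); leaving hash_index's type unset selects the linear
at-query-time fallback:

    cfg = LSHNearestNeighborIndex.get_default_config()
    cfg['lsh_functor']['type'] = 'ItqFunctor'              # assumed impl names
    cfg['descriptor_set']['type'] = 'MemoryDescriptorSet'
    cfg['hash2uuids_kvstore']['type'] = 'MemoryKeyValueStore'
    cfg['hash_index']['type'] = None  # take the "no HashIndex" branch above
    lsh = LSHNearestNeighborIndex.from_config(cfg)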
Example #13
def default_config():
    return {
        "utility": {
            "report_interval": 1.0,
            "use_multiprocessing": False,
        },
        "plugins": {
            "descriptor_set": make_default_config(DescriptorSet.get_impls()),
            "lsh_functor": make_default_config(LshFunctor.get_impls()),
            "hash2uuid_kvstore":
            make_default_config(KeyValueStore.get_impls()),
        },
    }
Example #14
    def from_config(cls, config_dict, merge_default=True):
        """
        Instantiate a new instance of this class given a JSON-compliant
        configuration dictionary encapsulating initialization arguments.

        This method should not be called via super unless an instance of the
        class is desired.

        :param config_dict: JSON compliant dictionary encapsulating
            a configuration.
        :type config_dict: dict

        :param merge_default: Merge the given configuration on top of the
            default provided by ``get_default_config``.
        :type merge_default: bool

        :return: Constructed instance from the provided config.
        :rtype: FaissNearestNeighborsIndex

        """
        if merge_default:
            cfg = cls.get_default_config()
            merge_dict(cfg, config_dict)
        else:
            cfg = config_dict

        cfg['descriptor_set'] = from_config_dict(cfg['descriptor_set'],
                                                 DescriptorSet.get_impls())
        cfg['uid2idx_kvs'] = from_config_dict(cfg['uid2idx_kvs'],
                                              KeyValueStore.get_impls())
        cfg['idx2uid_kvs'] = from_config_dict(cfg['idx2uid_kvs'],
                                              KeyValueStore.get_impls())

        if (cfg['index_element'] and cfg['index_element']['type']):
            index_element = from_config_dict(cfg['index_element'],
                                             DataElement.get_impls())
            cfg['index_element'] = index_element
        else:
            cfg['index_element'] = None

        if (cfg['index_param_element'] and cfg['index_param_element']['type']):
            index_param_element = from_config_dict(cfg['index_param_element'],
                                                   DataElement.get_impls())
            cfg['index_param_element'] = index_param_element
        else:
            cfg['index_param_element'] = None

        return super(FaissNearestNeighborsIndex, cls).from_config(cfg, False)
Example #15
def main():
    # Print help and exit if no arguments were passed
    if len(sys.argv) == 1:
        get_cli_parser().print_help()
        sys.exit(1)

    args = get_cli_parser().parse_args()
    config = utility_main_helper(get_default_config, args)

    log = logging.getLogger(__name__)
    log.debug('Showing debug messages.')

    #: :type: smqtk.representation.DescriptorSet
    descriptor_set = from_config_dict(config['plugins']['descriptor_set'],
                                      DescriptorSet.get_impls())
    #: :type: smqtk.algorithms.NearestNeighborsIndex
    nearest_neighbor_index = from_config_dict(
        config['plugins']['nn_index'], NearestNeighborsIndex.get_impls())

    # noinspection PyShadowingNames
    def nearest_neighbors(descriptor, n):
        if n == 0:
            n = len(nearest_neighbor_index)

        neighbors, distances = nearest_neighbor_index.nn(descriptor, n)
        # Strip first result (itself) and create list of (uuid, distance)
        return list(zip([x.uuid() for x in neighbors[1:]], distances[1:]))

    if args.uuid_list is not None and not os.path.exists(args.uuid_list):
        log.error('Invalid file list path: %s', args.uuid_list)
        exit(103)
    elif args.num < 0:
        log.error('Number of nearest neighbors must be >= 0')
        exit(105)

    if args.uuid_list is not None:
        with open(args.uuid_list, 'r') as infile:
            for line in infile:
                descriptor = descriptor_set.get_descriptor(line.strip())
                print(descriptor.uuid())
                for neighbor in nearest_neighbors(descriptor, args.num):
                    print('%s,%f' % neighbor)
    else:
        for (uuid, descriptor) in descriptor_set.iteritems():
            print(uuid)
            for neighbor in nearest_neighbors(descriptor, args.num):
                print('%s,%f' % neighbor)
Example #16
def default_config():
    return {
        "utility": {
            "classify_overwrite": False,
            "parallel": {
                "use_multiprocessing": False,
                "index_extraction_cores": None,
                "classification_cores": None,
            }
        },
        "plugins": {
            "classifier":
            make_default_config(Classifier.get_impls()),
            "classification_factory":
            make_default_config(ClassificationElement.get_impls()),
            "descriptor_set":
            make_default_config(DescriptorSet.get_impls()),
        }
    }
Example #17
def cli_build(config_filepath):
    """
    Build a new nearest-neighbors index from the configured descriptor set's
    contents.
    """
    config_dict, success = load_config(config_filepath,
                                       defaults=build_default_config())
    # Defaults are insufficient so we assert that the configuration file was
    # (successfully) loaded.
    if not success:
        raise RuntimeError("Failed to load configuration file.")

    descr_set = from_config_dict(config_dict['descriptor_set'],
                                 DescriptorSet.get_impls())

    nn_index = from_config_dict(config_dict['neighbor_index'],
                                NearestNeighborsIndex.get_impls())

    # TODO: reduced amount used for building ("training") and remainder used
    #       for update.
    nn_index.build_index(descr_set)
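A hedged invocation sketch: write an editable config skeleton from
build_default_config (defined in example #23), fill in concrete "type"
values, then build. The file path is illustrative:

    import json

    # Configs are JSON-compliant per the docstrings above, so a skeleton
    # can be dumped directly for hand editing (path illustrative).
    with open('nn_index_config.json', 'w') as f:
        json.dump(build_default_config(), f, indent=2)
    # ... select descriptor_set / neighbor_index implementations, then:
    cli_build('nn_index_config.json')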
Example #18
def main():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #: :type: smqtk.representation.DescriptorSet
    descriptor_set = from_config_dict(config['plugins']['descriptor_set'],
                                      DescriptorSet.get_impls())

    labels_filepath = args.f
    output_filepath = args.o

    # Run through labeled UUIDs in input file, getting the descriptor from the
    # configured index, applying the appropriate integer label and then writing
    # the formatted line out to the output file.
    input_uuid_labels = csv.reader(open(labels_filepath))

    with open(output_filepath, 'w') as ofile:
        label2int = {}
        next_int = 1
        uuids, labels = list(zip(*input_uuid_labels))

        log.info("Scanning input descriptors and labels")
        for i, (l, d) in enumerate(
                zip(labels, descriptor_set.get_many_descriptors(uuids))):
            log.debug("%d %s", i, d.uuid())
            if l not in label2int:
                label2int[l] = next_int
                next_int += 1
            ofile.write("%d " % label2int[l] + ' '.join([
                "%d:%.12f" % (j + 1, f)
                for j, f in enumerate(d.vector()) if f != 0.0
            ]) + '\n')

    log.info("Integer label association:")
    for i, l in sorted((i, l) for l, i in six.iteritems(label2int)):
        log.info('\t%d :: %s', i, l)
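The lines written above follow the sparse libsvm/svmlight convention: an
integer label followed by 1-based index:value pairs, zero components
omitted. A sketch of two output lines (values illustrative):

    1 1:0.131415000000 3:0.220000000000
    2 2:0.874100000000 5:0.000500000000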
Example #19
def main():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    uuids_list_filepath = config['uuids_list_filepath']

    log.info("Initializing ITQ functor")
    #: :type: smqtk.algorithms.nn_index.lsh.functors.itq.ItqFunctor
    functor = ItqFunctor.from_config(config['itq_config'])

    log.info("Initializing DescriptorSet [type=%s]",
             config['descriptor_set']['type'])
    #: :type: smqtk.representation.DescriptorSet
    descriptor_set = from_config_dict(
        config['descriptor_set'],
        DescriptorSet.get_impls(),
    )

    if uuids_list_filepath and os.path.isfile(uuids_list_filepath):

        def uuids_iter():
            with open(uuids_list_filepath) as f:
                for l in f:
                    yield l.strip()

        log.info("Loading UUIDs list from file: %s", uuids_list_filepath)
        d_iter = descriptor_set.get_many_descriptors(uuids_iter())
    else:
        log.info("Using UUIDs from loaded DescriptorSet (count=%d)",
                 len(descriptor_set))
        d_iter = descriptor_set

    log.info("Fitting ITQ model")
    functor.fit(d_iter)
    log.info("Done")
Example #20
def default_config():
    return {
        'plugins': {
            'descriptor_set': make_default_config(DescriptorSet.get_impls()),
        }
    }
Example #21
def run_file_list(c,
                  filelist_filepath,
                  checkpoint_filepath,
                  batch_size=None,
                  check_image=False):
    """
    Top level function handling configuration and inputs/outputs.

    :param c: Configuration dictionary (JSON)
    :type c: dict

    :param filelist_filepath: Path to a text file that lists paths to data
        files, separated by new lines.
    :type filelist_filepath: str

    :param checkpoint_filepath: Output file to which we write input filepath to
        SHA1 (UUID) relationships.
    :type checkpoint_filepath: str

    :param batch_size: Optional batch size (None default) of data elements to
        process / descriptors to compute at a time. This causes files and
        stores to be written to incrementally during processing instead of
        one single batch transaction at a time.
    :type batch_size: int | None

    :param check_image: Enable checking image loading from file before queueing
        that file for processing. If the check fails, the file is skipped
        instead of a halting exception being raised.
    :type check_image: bool

    """
    log = logging.getLogger(__name__)

    file_paths = [line.strip() for line in open(filelist_filepath)]

    log.info("Making descriptor factory")
    factory = DescriptorElementFactory.from_config(c['descriptor_factory'])

    log.info("Making descriptor index")
    descriptor_set = cast(
        DescriptorSet,
        from_config_dict(c['descriptor_set'], DescriptorSet.get_impls()))

    # ``data_set`` added to within the ``iter_valid_elements`` function.
    data_set: Optional[DataSet] = None
    if c['optional_data_set']['type'] is None:
        log.info("Not saving loaded data elements to data set")
    else:
        log.info("Initializing data set to append to")
        data_set = cast(
            DataSet,
            from_config_dict(c['optional_data_set'], DataSet.get_impls()))

    log.info("Making descriptor generator '%s'",
             c['descriptor_generator']['type'])
    generator = cast(
        DescriptorGenerator,
        from_config_dict(c['descriptor_generator'],
                         DescriptorGenerator.get_impls()))

    def iter_valid_elements():
        def is_valid(file_path):
            e = DataFileElement(file_path)

            if is_valid_element(
                    e,
                    valid_content_types=generator.valid_content_types(),
                    check_image=check_image):
                return e
            else:
                return False

        data_elements: Deque[DataFileElement] = collections.deque()
        valid_files_filter = parallel.parallel_map(is_valid,
                                                   file_paths,
                                                   name="check-file-type",
                                                   use_multiprocessing=True)
        for dfe in valid_files_filter:
            if dfe:
                yield dfe
                if data_set is not None:
                    data_elements.append(dfe)
                    if batch_size and len(data_elements) == batch_size:
                        log.debug(
                            "Adding data element batch to set (size: %d)",
                            len(data_elements))
                        data_set.add_data(*data_elements)
                        data_elements.clear()
        # elements only collected if we have a data-set configured, so add any
        # still in the deque to the set
        if data_set is not None and data_elements:
            log.debug("Adding data elements to set (size: %d",
                      len(data_elements))
            data_set.add_data(*data_elements)

    log.info("Computing descriptors")
    m = compute_many_descriptors(
        iter_valid_elements(),
        generator,
        factory,
        descriptor_set,
        batch_size=batch_size,
    )

    # Recording computed file paths and associated file UUIDs (SHA1)
    cf = open(checkpoint_filepath, 'w')
    cf_writer = csv.writer(cf)
    try:
        pr = ProgressReporter(log.debug, 1.0).start()
        for de, descr in m:
            # We know that we are using DataFileElements going into the
            # compute_many_descriptors, so we can assume that's what comes out
            # of it as well.
            # noinspection PyProtectedMember
            cf_writer.writerow([de._filepath, descr.uuid()])
            pr.increment_report()
        pr.report()
    finally:
        del cf_writer
        cf.close()

    log.info("Done")
Example #22
def classifier_kfold_validation():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Load configurations / Setup data
    #
    pr_enabled = config['pr_curves']['enabled']
    pr_output_dir = config['pr_curves']['output_directory']
    pr_file_prefix = config['pr_curves']['file_prefix'] or ''
    pr_show = config['pr_curves']['show']

    roc_enabled = config['roc_curves']['enabled']
    roc_output_dir = config['roc_curves']['output_directory']
    roc_file_prefix = config['roc_curves']['file_prefix'] or ''
    roc_show = config['roc_curves']['show']

    log.info("Initializing DescriptorSet (%s)",
             config['plugins']['descriptor_set']['type'])
    #: :type: smqtk.representation.DescriptorSet
    descriptor_set = from_config_dict(config['plugins']['descriptor_set'],
                                      DescriptorSet.get_impls())
    log.info("Loading classifier configuration")
    #: :type: dict
    classifier_config = config['plugins']['supervised_classifier']

    # Always use in-memory ClassificationElement since we are retraining the
    # classifier and don't want possible element caching
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory(
        MemoryClassificationElement, {})

    log.info("Loading truth data")
    #: :type: list[str]
    uuids = []
    #: :type: list[str]
    truth_labels = []
    with open(config['cross_validation']['truth_labels']) as f:
        f_csv = csv.reader(f)
        for row in f_csv:
            uuids.append(row[0])
            truth_labels.append(row[1])
    #: :type: numpy.ndarray[str]
    uuids = numpy.array(uuids)
    #: :type: numpy.ndarray[str]
    truth_labels = numpy.array(truth_labels)

    #
    # Cross validation
    #
    kfolds = sklearn.model_selection.StratifiedKFold(
        n_splits=config['cross_validation']['num_folds'],
        shuffle=True,
        random_state=config['cross_validation']['random_seed'],
    ).split(numpy.zeros(len(truth_labels)), truth_labels)
    """
    Truth and classification probability results for test data per fold.
    Format:
        {
            0: {
                '<label>':  {
                    "truth": [...],   # Parallel truth and classification
                    "proba": [...],   # probability values
                },
                ...
            },
            ...
        }
    """
    fold_data: Dict[int, Any] = {}

    i = 0
    for train, test in kfolds:
        log.info("Fold %d", i)
        log.info("-- %d training examples", len(train))
        log.info("-- %d test examples", len(test))
        fold_data[i] = {}

        log.info("-- creating classifier")
        classifier = cast(
            SupervisedClassifier,
            from_config_dict(classifier_config,
                             SupervisedClassifier.get_impls()))

        log.info("-- gathering descriptors")
        pos_map: Dict[str, List[DescriptorElement]] = {}
        for idx in train:
            if truth_labels[idx] not in pos_map:
                pos_map[truth_labels[idx]] = []
            pos_map[truth_labels[idx]].append(
                descriptor_set.get_descriptor(uuids[idx]))

        log.info("-- Training classifier")
        classifier.train(pos_map)

        log.info("-- Classifying test set")
        c_iter = classifier.classify_elements(
            (descriptor_set.get_descriptor(uuids[idx]) for idx in test),
            classification_factory,
        )
        uuid2c = dict((c.uuid, c.get_classification()) for c in c_iter)

        log.info("-- Pairing truth and computed probabilities")
        # Only considering positive labels
        for t_label in pos_map:
            fold_data[i][t_label] = {
                "truth": [L == t_label for L in truth_labels[test]],
                "proba": [uuid2c[uuid][t_label] for uuid in uuids[test]]
            }

        i += 1

    #
    # Curve generation
    #
    if pr_enabled:
        make_pr_curves(fold_data, pr_output_dir, pr_file_prefix, pr_show)
    if roc_enabled:
        make_roc_curves(fold_data, roc_output_dir, roc_file_prefix, roc_show)
Example #23
def build_default_config():
    return {
        'descriptor_set': make_default_config(DescriptorSet.get_impls()),
        'neighbor_index':
        make_default_config(NearestNeighborsIndex.get_impls()),
    }
Example #24
def main():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    # Deprecations
    if (config.get('parallelism', {}).get('classification_cores', None)
            is not None):
        warnings.warn(
            "Usage of 'classification_cores' is deprecated. "
            "Classifier parallelism is not defined on a "
            "per-implementation basis. See classifier "
            "implementation parameterization.",
            category=DeprecationWarning)

    #
    # Initialize stuff from configuration
    #
    #: :type: smqtk.algorithms.Classifier
    classifier = from_config_dict(config['plugins']['classifier'],
                                  SupervisedClassifier.get_impls())
    #: :type: ClassificationElementFactory
    classification_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])
    #: :type: smqtk.representation.DescriptorSet
    descriptor_set = from_config_dict(config['plugins']['descriptor_set'],
                                      DescriptorSet.get_impls())

    uuid2label_filepath = config['utility']['csv_filepath']
    do_train = config['utility']['train']
    output_uuid_cm = config['utility']['output_uuid_confusion_matrix']
    plot_filepath_pr = config['utility']['output_plot_pr']
    plot_filepath_roc = config['utility']['output_plot_roc']
    plot_filepath_cm = config['utility']['output_plot_confusion_matrix']
    plot_ci = config['utility']['curve_confidence_interval']
    plot_ci_alpha = config['utility']['curve_confidence_interval_alpha']

    #
    # Construct mapping of label to the DescriptorElement instances
    # described by that label.
    #
    log.info("Loading descriptors by UUID")

    def iter_uuid_label():
        """ Iterate through UUIDs in specified file """
        with open(uuid2label_filepath) as uuid2label_file:
            reader = csv.reader(uuid2label_file)
            for r in reader:
                # TODO: This will need to be updated to handle multiple labels
                #       per descriptor.
                yield r[0], r[1]

    def get_descr(r):
        """ Fetch descriptors from configured index """
        uuid, truth_label = r
        return truth_label, descriptor_set.get_descriptor(uuid)

    tlabel_element_iter = parallel.parallel_map(
        get_descr,
        iter_uuid_label(),
        name="cmv_get_descriptors",
        use_multiprocessing=True,
        cores=config['parallelism']['descriptor_fetch_cores'],
    )

    # Map of truth labels to descriptors of labeled data
    #: :type: dict[str, list[smqtk.representation.DescriptorElement]]
    tlabel2descriptors = {}
    for tlabel, d in tlabel_element_iter:
        tlabel2descriptors.setdefault(tlabel, []).append(d)

    # Train classifier if the one given has a ``train`` method and training
    # was enabled.
    if do_train:
        log.info("Training supervised classifier model")
        classifier.train(tlabel2descriptors)
        exit(0)

    #
    # Apply classifier to descriptors for predictions
    #

    # Truth label to predicted classification results
    #: :type: dict[str, set[smqtk.representation.ClassificationElement]]
    tlabel2classifications = {}
    for tlabel, descriptors in six.iteritems(tlabel2descriptors):
        tlabel2classifications[tlabel] = \
            set(classifier.classify_elements(descriptors,
                                             classification_factory))
    log.info("Truth label counts:")
    for l in sorted(tlabel2classifications):
        log.info("  %s :: %d", l, len(tlabel2classifications[l]))

    #
    # Confusion Matrix
    #
    conf_mat, labels = gen_confusion_matrix(tlabel2classifications)
    log.info("Confusion_matrix")
    log_cm(log.info, conf_mat, labels)
    if plot_filepath_cm:
        plot_cm(conf_mat, labels, plot_filepath_cm)

    # Confusion Matrix of descriptor UUIDs to output json
    if output_uuid_cm:
        # Top dictionary keys are true labels, inner dictionary keys are UUID
        # predicted labels.
        log.info("Computing UUID Confusion Matrix")
        #: :type: dict[str, dict[collections.Hashable, set]]
        uuid_cm = {}
        for tlabel in tlabel2classifications:
            uuid_cm[tlabel] = collections.defaultdict(set)
            for c in tlabel2classifications[tlabel]:
                uuid_cm[tlabel][c.max_label()].add(c.uuid)
            # convert sets to lists for JSON output.
            for plabel in uuid_cm[tlabel]:
                # noinspection PyTypeChecker
                uuid_cm[tlabel][plabel] = list(uuid_cm[tlabel][plabel])
        with open(output_uuid_cm, 'w') as f:
            log.info("Saving UUID Confusion Matrix: %s", output_uuid_cm)
            json.dump(uuid_cm, f, indent=2, separators=(',', ': '))

    #
    # Create PR/ROC curves via scikit learn tools
    #
    if plot_filepath_pr:
        log.info("Making PR curve")
        make_pr_curves(tlabel2classifications, plot_filepath_pr, plot_ci,
                       plot_ci_alpha)
    if plot_filepath_roc:
        log.info("Making ROC curve")
        make_roc_curves(tlabel2classifications, plot_filepath_roc, plot_ci,
                        plot_ci_alpha)
Example #25
def default_config():
    return {
        "itq_config": ItqFunctor.get_default_config(),
        "uuids_list_filepath": None,
        "descriptor_set": make_default_config(DescriptorSet.get_impls()),
    }
Example #26
def main():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    # - parallel_map UUIDs to load from the configured index
    # - classify iterated descriptors

    uuids_list_filepath = args.uuids_list
    output_csv_filepath = args.csv_data
    output_csv_header_filepath = args.csv_header
    classify_overwrite = config['utility']['classify_overwrite']

    p_use_multiprocessing = \
        config['utility']['parallel']['use_multiprocessing']
    p_index_extraction_cores = \
        config['utility']['parallel']['index_extraction_cores']
    p_classification_cores = \
        config['utility']['parallel']['classification_cores']

    if not uuids_list_filepath:
        raise ValueError("No uuids_list_filepath specified.")
    elif not os.path.isfile(uuids_list_filepath):
        raise ValueError("Given uuids_list_filepath did not point to a file.")
    if output_csv_header_filepath is None:
        raise ValueError("Need a path to save CSV header labels")
    if output_csv_filepath is None:
        raise ValueError("Need a path to save CSV data.")

    #
    # Initialize configured plugins
    #

    log.info("Initializing descriptor index")
    #: :type: smqtk.representation.DescriptorSet
    descriptor_set = from_config_dict(config['plugins']['descriptor_set'],
                                      DescriptorSet.get_impls())

    log.info("Initializing classification factory")
    c_factory = ClassificationElementFactory.from_config(
        config['plugins']['classification_factory'])

    log.info("Initializing classifier")
    #: :type: smqtk.algorithms.Classifier
    classifier = from_config_dict(config['plugins']['classifier'],
                                  Classifier.get_impls())

    #
    # Setup/Process
    #
    def iter_uuids():
        with open(uuids_list_filepath) as f:
            for l in f:
                yield l.strip()

    def descr_for_uuid(uuid):
        """
        :type uuid: collections.Hashable
        :rtype: smqtk.representation.DescriptorElement
        """
        return descriptor_set.get_descriptor(uuid)

    def classify_descr(d):
        """
        :type d: smqtk.representation.DescriptorElement
        :rtype: smqtk.representation.ClassificationElement
        """
        return classifier.classify_one_element(d, c_factory,
                                               classify_overwrite)

    log.info("Initializing uuid-to-descriptor parallel map")
    #: :type: collections.Iterable[smqtk.representation.DescriptorElement]
    element_iter = parallel.parallel_map(
        descr_for_uuid,
        iter_uuids(),
        use_multiprocessing=p_use_multiprocessing,
        cores=p_index_extraction_cores,
        name="descr_for_uuid",
    )

    log.info("Initializing descriptor-to-classification parallel map")
    #: :type: collections.Iterable[smqtk.representation.ClassificationElement]
    classification_iter = parallel.parallel_map(
        classify_descr,
        element_iter,
        use_multiprocessing=p_use_multiprocessing,
        cores=p_classification_cores,
        name='classify_descr',
    )

    #
    # Write/Output files
    #

    c_labels = classifier.get_labels()

    def make_row(e):
        """
        :type e: smqtk.representation.ClassificationElement
        """
        c_m = e.get_classification()
        return [e.uuid] + [c_m[l] for l in c_labels]

    # column labels file
    log.info("Writing CSV column header file: %s", output_csv_header_filepath)
    safe_create_dir(os.path.dirname(output_csv_header_filepath))
    with open(output_csv_header_filepath, 'w', newline='') as f_csv:
        w = csv.writer(f_csv)
        w.writerow(['uuid'] + [str(cl) for cl in c_labels])

    # CSV file
    log.info("Writing CSV data file: %s", output_csv_filepath)
    safe_create_dir(os.path.dirname(output_csv_filepath))
    pr = cli.ProgressReporter(log.info, 1.0)
    pr.start()
    with open(output_csv_filepath, 'w', newline='') as f_csv:
        w = csv.writer(f_csv)
        for c in classification_iter:
            w.writerow(make_row(c))
            pr.increment_report()
        pr.report()

    log.info("Done")
Example #27
def main():
    args = cli_parser().parse_args()
    config = cli.utility_main_helper(default_config, args)
    log = logging.getLogger(__name__)

    #
    # Load configuration contents
    #
    uuid_list_filepath = args.uuids_list
    report_interval = config['utility']['report_interval']
    use_multiprocessing = config['utility']['use_multiprocessing']

    #
    # Checking input parameters
    #
    if (uuid_list_filepath is not None) and \
            not os.path.isfile(uuid_list_filepath):
        raise ValueError("UUIDs list file does not exist!")

    #
    # Loading stuff
    #
    log.info("Loading descriptor index")
    descriptor_set = from_config_dict(config['plugins']['descriptor_set'],
                                      DescriptorSet.get_impls())
    log.info("Loading LSH functor")
    lsh_functor = from_config_dict(config['plugins']['lsh_functor'],
                                   LshFunctor.get_impls())
    log.info("Loading Key/Value store")
    hash2uuids_kvstore = from_config_dict(
        config['plugins']['hash2uuid_kvstore'], KeyValueStore.get_impls())

    # Iterate either over what's in the file given, or everything in the
    # configured index.
    def iter_uuids():
        if uuid_list_filepath:
            log.info("Using UUIDs list file")
            with open(uuid_list_filepath) as f:
                for line in f:
                    yield line.strip()
        else:
            log.info("Using all UUIDs resent in descriptor index")
            for k in descriptor_set.keys():
                yield k

    #
    # Compute codes
    #
    log.info("Starting hash code computation")
    kv_update = {}
    for uuid, hash_int in \
            compute_hash_codes(uuids_for_processing(iter_uuids(),
                                                    hash2uuids_kvstore),
                               descriptor_set, lsh_functor,
                               report_interval,
                               use_multiprocessing, True):
        # Get original value in KV-store if not in update dict.
        if hash_int not in kv_update:
            kv_update[hash_int] = hash2uuids_kvstore.get(hash_int, set())
        kv_update[hash_int] |= {uuid}

    if kv_update:
        log.info("Updating KV store... (%d keys)" % len(kv_update))
        hash2uuids_kvstore.add_many(kv_update)

    log.info("Done")
Example #28
    def __init__(self, json_config):
        """
        Initialize application based on the supplied JSON configuration.

        :param json_config: JSON configuration dictionary
        :type json_config: dict

        """
        super(NearestNeighborServiceServer, self).__init__(json_config)

        self.update_index = json_config['update_descriptor_set']

        # Descriptor factory setup
        self._log.info("Initializing DescriptorElementFactory")
        self.descr_elem_factory = DescriptorElementFactory.from_config(
            self.json_config['descriptor_factory'])

        #: :type: smqtk.representation.DescriptorSet | None
        self.descr_index = None
        if self.update_index:
            self._log.info("Initializing DescriptorSet to update")
            #: :type: smqtk.representation.DescriptorSet | None
            self.descr_index = from_config_dict(json_config['descriptor_set'],
                                                DescriptorSet.get_impls())

        #: :type: smqtk.algorithms.NearestNeighborsIndex
        self.nn_index = from_config_dict(json_config['nn_index'],
                                         NearestNeighborsIndex.get_impls())

        #: :type: smqtk.algorithms.DescriptorGenerator
        self.descriptor_generator_inst = from_config_dict(
            self.json_config['descriptor_generator'],
            DescriptorGenerator.get_impls())

        @self.route("/count", methods=['GET'])
        def count():
            """
            Return the number of elements represented in this index.
            """
            return flask.jsonify(**{
                "count": self.nn_index.count(),
            })

        @self.route("/compute/<path:uri>", methods=["POST"])
        def compute(uri):
            """
            Compute the descriptor for a URI specified data element using the
            configured descriptor generator.

            See ``compute_nearest_neighbors`` method docstring for URI
            specifications accepted.

            If a descriptor index was configured and updating was turned on,
            we add the computed descriptor to the index.

            JSON Return format::
                {
                    "success": <bool>

                    "message": <str>

                    "descriptor": <None|list[float]>

                    "reference_uri": <str>
                }

            :param uri: URI data specification.

            """
            descriptor = None
            try:
                descriptor = self.generate_descriptor_for_uri(uri)
                message = "Descriptor generated"
                descriptor = list(map(float, descriptor.vector()))
            except ValueError as ex:
                message = "Input value issue: %s" % str(ex)
            except RuntimeError as ex:
                message = "Descriptor extraction failure: %s" % str(ex)

            return flask.jsonify(
                success=descriptor is not None,
                message=message,
                descriptor=descriptor,
                reference_uri=uri,
            )

        @self.route("/nn/<path:uri>")
        @self.route("/nn/n=<int:n>/<path:uri>")
        @self.route("/nn/n=<int:n>/<int:start_i>:<int:end_i>/<path:uri>")
        def compute_nearest_neighbors(uri, n=10, start_i=None, end_i=None):
            """
            Data modes for upload/use:

                - local filepath
                - base64
                - http/s URL
                - existing data/descriptor UUID

            The following sub-sections detail how different URIs can be used.

            Local Filepath
            --------------
            The URI string must be prefixed with ``file://``, followed by the
            full path to the data file to describe.

            Base 64 data
            ------------
            The URI string must be prefixed with "base64://", followed by the
            base64 encoded string. This mode also requires an additional
            ``?content_type=`` to provide data content type information. This
            mode saves the encoded data to temporary file for processing.

            HTTP/S address
            --------------
            This is the default mode when the URI prefix is none of the above.
            This uses the requests module to locally download a data file
            for processing.

            Existing Data/Descriptor by UUID
            --------------------------------
            When given a URI prefixed with "uuid://", we interpret the remainder
            of the URI as the UUID of a descriptor already present in the
            configured descriptor index. If the given UUID is not present in the
            index, a KeyError is raised.

            JSON Return format
            ------------------
                {
                    "success": <bool>

                    "message": <str>

                    "neighbors": <None|list[float]>

                    "reference_uri": <str>
                }

            :param n: Number of neighbors to query for
            :param start_i: The starting index of the neighbor vectors to slice
                into for return.
            :param end_i: The ending index of the neighbor vectors to slice
                into for return.
            :type uri: str

            """
            descriptor = None
            try:
                descriptor = self.generate_descriptor_for_uri(uri)
                message = "descriptor computed"
            except ValueError as ex:
                message = "Input data issue: %s" % str(ex)
            except RuntimeError as ex:
                message = "Descriptor generation failure: %s" % str(ex)

            # Base pagination slicing based on provided start and end indices,
            # otherwise clamp to beginning/ending of queried neighbor sequence.
            page_slice = slice(start_i or 0, end_i or n)
            neighbors = []
            dists = []
            if descriptor is not None:
                try:
                    neighbors, dists = \
                        self.nn_index.nn(descriptor, n)
                except ValueError as ex:
                    message = "Descriptor or index related issue: %s" % str(ex)

            # TODO: Return the optional descriptor vectors for the neighbors
            # noinspection PyTypeChecker
            d = {
                "success": bool(descriptor is not None),
                "message": message,
                "neighbors": [n.uuid() for n in neighbors[page_slice]],
                "distances": dists[page_slice],
                "reference_uri": uri
            }
            return flask.jsonify(d)
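A hedged client-side sketch of the routes defined above (assumes the
``requests`` package; host, port, and file path are illustrative):

    import requests

    # Number of elements represented in the index.
    print(requests.get('http://localhost:5000/count').json()['count'])
    # Ten nearest neighbors for a local image file (path illustrative).
    r = requests.get(
        'http://localhost:5000/nn/n=10/file:///data/images/0001.jpg')
    print(r.json()['neighbors'], r.json()['distances'])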