示例#1
0
    def hack_remove_pystuff(self):
        """
        Drop Python-source noise from the grep results held in ``self``.

        For every group of matched lines in ``self.found_lines_list`` this
        removes:
          * comment lines (stripped line starts with ``'# '``)
          * doctest lines (stripped line starts with ``'>>> '``)
          * command-line test invocations ``--test-<pat>``, ``--exec-<pat>``,
            ``--test-<mod>.<pat>`` and ``--exec-<mod>.<pat>``
          * function definitions ``def <pat>``
        where ``<pat>`` is the regex ``self.extended_regex_list[0]``.

        ``self.found_lines_list`` is replaced by the filtered groups, and the
        indices of groups left empty are then removed with ``del self[idxs]``
        (the enclosing class's ``__delitem__`` decides what that deletes).
        """
        import re
        # Hack of a method
        found_lines_list = self.found_lines_list
        if found_lines_list:
            pat = self.extended_regex_list[0]
            # Module prefix such as ``ibeis.`` in ``--exec-ibeis.func``.
            # The original class ``[a-zA-z]`` was a typo that also matched
            # the punctuation between 'Z' and 'a'; underscore is kept
            # explicitly. Raw string so the backslash-dot is a literal dot.
            module_prefix = r'[a-zA-Z_]*\.'
            remove_regexes = [
                re.compile('--test-' + pat),
                re.compile('--exec-' + pat),
                re.compile('--exec-' + module_prefix + pat),
                re.compile('--test-' + module_prefix + pat),
                re.compile('def ' + pat),
            ]
        else:
            # Nothing to filter; avoid touching extended_regex_list at all
            remove_regexes = []

        def _keep(line):
            # Comment / doctest checks look at the stripped line, the regex
            # checks search the raw line (same as the original passes).
            if line.strip().startswith(('# ', '>>> ')):
                return False
            return not any(regex.search(line) for regex in remove_regexes)

        self.found_lines_list = [
            [line for line in lines if _keep(line)]
            for lines in found_lines_list
        ]

        # compress self: drop result groups that no longer have any lines
        empty_idxs = [
            idx for idx, lines_ in enumerate(self.found_lines_list)
            if len(lines_) == 0
        ]
        del self[empty_idxs]
示例#2
0
 def handle_cache_misses(ibs, getter_func, rowid_list, ismiss_list, vals_list, cache_, kwargs):
     """
     Resolve cache misses by calling the wrapped getter on the missing rowids.

     Args:
         ibs: controller object, forwarded as the first argument of getter_func
         getter_func: wrapped getter, called as getter_func(ibs, rowids, **kwargs)
         rowid_list (list): every requested rowid
         ismiss_list (list): truthy where the matching rowid was a cache miss
         vals_list (list): output values; miss positions are overwritten in place
         cache_: mapping from rowid to value; fetched values are stored into it
         kwargs (dict): keyword arguments forwarded to getter_func

     Returns:
         None -- results are written into ``vals_list`` and ``cache_``.
     """
     missing_positions = ut.list_where(ismiss_list)
     missing_rowids = ut.compress(rowid_list, ismiss_list)
     fetched_vals = getter_func(ibs, missing_rowids, **kwargs)
     # NOTE(review): fetched_vals is iterated twice below; if getter_func ever
     # returned a one-shot iterator the cache-write loop would see nothing.
     for position, fetched in zip(missing_positions, fetched_vals):
         vals_list[position] = fetched  # Output write
     for rowid, fetched in zip(missing_rowids, fetched_vals):
         cache_[rowid] = fetched  # Cache write
示例#3
0
 def handle_cache_misses(ibs, getter_func, rowid_list, ismiss_list, vals_list, cache_, kwargs):
     """
     Fill in cache misses: fetch the missing values and store them twice.

     The rowids flagged in ``ismiss_list`` are passed to
     ``getter_func(ibs, rowids, **kwargs)``; every returned value is written
     into ``vals_list`` at the miss position and into ``cache_`` under its
     rowid. Nothing is returned; both containers are updated in place.
     """
     positions_to_fill = ut.list_where(ismiss_list)
     rowids_to_fetch = ut.compress(rowid_list, ismiss_list)
     # only the missing rowids are handed to the wrapped getter
     new_vals = getter_func(ibs, rowids_to_fetch, **kwargs)
     # Output write
     for pos, new_val in zip(positions_to_fill, new_vals):
         vals_list[pos] = new_val
     # Cache write
     for rowid, new_val in zip(rowids_to_fetch, new_vals):
         cache_[rowid] = new_val
示例#4
0
            def wrp_getter_cacher(ibs, rowid_list, **kwargs):
                """
                Wrapper function that caches rowid values in a dictionary

                Values are looked up in ``ibs.table_cache[tblname][colname]``
                under a hash of the kwargs named by ``cfgkeys``. Only rowids
                whose cached value is None are passed to ``getter_func``; their
                results are written into both the returned list and the cache.
                The returned list follows the order of ``rowid_list``.
                """
                # HACK: 'debug' only controls this wrapper; pop it so it is
                # neither hashed nor forwarded to getter_func
                debug_ = kwargs.pop('debug', False)
                # The config hash selects the per-configuration cache (None
                # when the getter declares no config keys)
                kwargs_hash = (
                    None if cfgkeys is None else
                    ut.get_dict_hashid([kwargs.get(key, None) for key in cfgkeys])
                )
                # Three cache levels: table -> column -> kwargs configuration
                cache_ = ibs.table_cache[tblname][colname][kwargs_hash]

                # A stored value of None cannot be told apart from a miss, so
                # None results get re-fetched on every call
                cached_vals = [cache_.get(rowid, None) for rowid in rowid_list]
                miss_flags = [val is None for val in cached_vals]
                if debug or debug_:
                    debug_cache_hits(miss_flags, rowid_list)
                # HACK: debug these getters by asserting the cached information
                # is correct
                if ASSERT_API_CACHE:
                    assert_cache_hits(ibs, miss_flags, rowid_list,
                                      kwargs_hash, **kwargs)
                if not any(miss_flags):
                    return cached_vals
                miss_positions = ut.list_where(miss_flags)
                miss_rowids = ut.compress(rowid_list, miss_flags)
                fetched_vals = getter_func(ibs, miss_rowids, **kwargs)
                for pos, val in zip(miss_positions, fetched_vals):
                    cached_vals[pos] = val  # Output write
                for rowid, val in zip(miss_rowids, fetched_vals):
                    cache_[rowid] = val  # Cache write
                return cached_vals
示例#5
0
            def wrp_getter_cacher(ibs, rowid_list, **kwargs):
                """
                Wrapper function that caches rowid values in a dictionary

                The cache used is ``ibs.table_cache[tblname][colname][h]``,
                where ``h`` hashes the kwargs listed in ``cfgkeys`` (or is None
                when there are none). Rowids with a cached value of None are
                fetched through ``getter_func`` and stored back into the cache.
                Returns one value per entry of ``rowid_list``, in order.
                """
                # HACK: strip the wrapper-only "debug" flag before hashing and
                # before forwarding kwargs to getter_func
                debug_ = kwargs.pop("debug", False)
                if cfgkeys is None:
                    kwargs_hash = None
                else:
                    cfg_vals = [kwargs.get(key, None) for key in cfgkeys]
                    kwargs_hash = ut.get_dict_hashid(cfg_vals)
                # table cache -> column cache -> per-config cache
                cache_ = ibs.table_cache[tblname][colname][kwargs_hash]

                # None doubles as the "not cached" marker
                out_vals = [cache_.get(rowid, None) for rowid in rowid_list]
                was_missed = [val is None for val in out_vals]
                if debug or debug_:
                    debug_cache_hits(was_missed, rowid_list)
                # HACK: debug these getters by asserting the cached information
                # is correct
                if ASSERT_API_CACHE:
                    assert_cache_hits(ibs, was_missed, rowid_list, kwargs_hash, **kwargs)
                if not any(was_missed):
                    return out_vals
                missed_idxs = ut.list_where(was_missed)
                missed_rowids = ut.compress(rowid_list, was_missed)
                new_vals = getter_func(ibs, missed_rowids, **kwargs)
                for idx, val in zip(missed_idxs, new_vals):
                    out_vals[idx] = val  # Output write
                for rowid, val in zip(missed_rowids, new_vals):
                    cache_[rowid] = val  # Cache write
                return out_vals
示例#6
0
    def hack_remove_pystuff(self):
        """
        Drop Python-source noise from the grep results held in ``self``.

        For every group of matched lines in ``self.found_lines_list`` this
        removes:
          * comment lines (stripped line starts with ``'# '``)
          * doctest lines (stripped line starts with ``'>>> '``)
          * command-line test invocations ``--test-<pat>``, ``--exec-<pat>``,
            ``--test-<mod>.<pat>`` and ``--exec-<mod>.<pat>``
          * function definitions ``def <pat>``
        where ``<pat>`` is the regex ``self.extended_regex_list[0]``.

        ``self.found_lines_list`` is replaced by the filtered groups, and the
        indices of groups left empty are then removed with ``del self[idxs]``
        (the enclosing class's ``__delitem__`` decides what that deletes).
        """
        import re
        # Hack of a method
        found_lines_list = self.found_lines_list
        if found_lines_list:
            pat = self.extended_regex_list[0]
            # Module prefix such as ``ibeis.`` in ``--exec-ibeis.func``.
            # The original class ``[a-zA-z]`` was a typo that also matched
            # the punctuation between 'Z' and 'a'; underscore is kept
            # explicitly. Raw string so the backslash-dot is a literal dot.
            module_prefix = r'[a-zA-Z_]*\.'
            remove_regexes = [
                re.compile('--test-' + pat),
                re.compile('--exec-' + pat),
                re.compile('--exec-' + module_prefix + pat),
                re.compile('--test-' + module_prefix + pat),
                re.compile('def ' + pat),
            ]
        else:
            # Nothing to filter; avoid touching extended_regex_list at all
            remove_regexes = []

        def _keep(line):
            # Comment / doctest checks look at the stripped line, the regex
            # checks search the raw line (same as the original passes).
            if line.strip().startswith(('# ', '>>> ')):
                return False
            return not any(regex.search(line) for regex in remove_regexes)

        self.found_lines_list = [
            [line for line in lines if _keep(line)]
            for lines in found_lines_list
        ]

        # compress self: drop result groups that no longer have any lines
        empty_idxs = [
            idx for idx, lines_ in enumerate(self.found_lines_list)
            if len(lines_) == 0
        ]
        del self[empty_idxs]
def add_parts(ibs,
              aid_list,
              bbox_list=None,
              theta_list=None,
              detect_confidence_list=None,
              notes_list=None,
              vert_list=None,
              part_uuid_list=None,
              viewpoint_list=None,
              quality_list=None,
              type_list=None,
              staged_uuid_list=None,
              staged_user_id_list=None,
              **kwargs):
    r"""
    Adds a part to each of the given annotations

    Exactly one of ``bbox_list`` / ``vert_list`` must be supplied; the other
    one is derived from it with ``vtool.geometry``. Optional per-part lists
    left as None are filled with the defaults noted below.

    Args:
        aid_list                 (list): annotation rowids to add part to
        bbox_list                (list): of [x, y, w, h] bounding boxes for each
                                         annotation (mutually exclusive with vert_list)
        theta_list               (list): orientations of parts (default 0.0)
        detect_confidence_list   (list): detection confidences (default 0.0)
        notes_list               (list): note strings (default '')
        vert_list                (list): alternative to bounding box
        part_uuid_list           (list): uuid.UUID per part (default: random uuid4)
        viewpoint_list           (list): viewpoints (default -1.0)
        quality_list             (list): accepted but currently NOT stored (no
                                         matching column is inserted)
        type_list                (list): part types (default const.UNKNOWN)
        staged_uuid_list         (list): staging uuids; a non-None entry sets the
                                         part's staged flag (default None)
        staged_user_id_list      (list): staging user identities (default None)

    Returns:
        list: part_rowid_list

    Raises:
        AssertionError: if not exactly one of bbox_list / vert_list is given,
            or if the bbox / vert / aid / theta lists differ in length
        ValueError: if any entry of part_uuid_list is not a uuid.UUID

    Ignore:
       detect_confidence_list = None
       notes_list = None
       part_uuid_list = None
       viewpoint_list = None
       quality_list = None
       type_list = None

    RESTful:
        Method: POST
        URL:    /api/part/
    """
    from vtool import geometry

    if ut.VERBOSE:
        logger.info('[ibs] adding parts')
    # Prepare the SQL input: exactly one geometry source may be supplied
    assert bool(bbox_list is None) != bool(
        vert_list is None
    ), 'must specify exactly one of bbox_list or vert_list'
    ut.assert_all_not_None(aid_list, 'aid_list')

    # Derive whichever geometry representation was not given
    if vert_list is None:
        vert_list = geometry.verts_list_from_bboxes_list(bbox_list)
    elif bbox_list is None:
        bbox_list = geometry.bboxes_from_vert_list(vert_list)

    if theta_list is None:
        theta_list = [0.0 for _ in range(len(aid_list))]

    len_bbox = len(bbox_list)
    len_vert = len(vert_list)
    len_aid = len(aid_list)
    len_theta = len(theta_list)
    try:
        assert len_vert == len_bbox, 'bbox and verts are not of same size'
        assert len_aid == len_bbox, 'bbox and aid are not of same size'
        assert len_aid == len_theta, 'theta and aid are not of same size'
    except AssertionError as ex:
        # key_list names local variables for printex to report; the comma
        # between 'len_bbox' and 'len_theta' was previously missing, which
        # silently concatenated them into one bogus name.
        ut.printex(ex,
                   key_list=['len_vert', 'len_aid', 'len_bbox',
                             'len_theta'])
        raise

    if len(aid_list) == 0:
        # nothing is being added
        logger.info('[ibs] WARNING: 0 parts are being added!')
        logger.info(ut.repr2(locals()))
        return []

    if detect_confidence_list is None:
        detect_confidence_list = [0.0 for _ in range(len(aid_list))]
    if notes_list is None:
        notes_list = ['' for _ in range(len(aid_list))]
    if viewpoint_list is None:
        viewpoint_list = [-1.0] * len(aid_list)
    if type_list is None:
        type_list = [const.UNKNOWN] * len(aid_list)

    nVert_list = [len(verts) for verts in vert_list]
    vertstr_list = [six.text_type(verts) for verts in vert_list]
    xtl_list, ytl_list, width_list, height_list = list(zip(*bbox_list))
    assert len(nVert_list) == len(vertstr_list)

    # Build random (uuid4) and unique PART ids when none are supplied
    if part_uuid_list is None:
        part_uuid_list = [uuid.uuid4() for _ in range(len(aid_list))]

    if staged_uuid_list is None:
        staged_uuid_list = [None] * len(aid_list)
    is_staged_list = [
        staged_uuid is not None for staged_uuid in staged_uuid_list
    ]
    if staged_user_id_list is None:
        staged_user_id_list = [None] * len(aid_list)

    # Define arguments to insert (order must match params_iter below)
    colnames = (
        'part_uuid',
        'annot_rowid',
        'part_xtl',
        'part_ytl',
        'part_width',
        'part_height',
        'part_theta',
        'part_num_verts',
        'part_verts',
        'part_viewpoint',
        'part_detect_confidence',
        'part_note',
        'part_type',
        'part_staged_flag',
        'part_staged_uuid',
        'part_staged_user_identity',
    )

    check_uuid_flags = [
        not isinstance(auuid, uuid.UUID) for auuid in part_uuid_list
    ]
    if any(check_uuid_flags):
        pos = ut.list_where(check_uuid_flags)
        raise ValueError('positions %r have malformated UUIDS' % (pos, ))

    # NOTE(review): zip truncates silently, so optional lists shorter than
    # aid_list would drop parts without an error.
    params_iter = list(
        zip(
            part_uuid_list,
            aid_list,
            xtl_list,
            ytl_list,
            width_list,
            height_list,
            theta_list,
            nVert_list,
            vertstr_list,
            viewpoint_list,
            detect_confidence_list,
            notes_list,
            type_list,
            is_staged_list,
            staged_uuid_list,
            staged_user_id_list,
        ))

    # Execute add PARTs SQL; part_uuid (param index 0) is the superkey
    superkey_paramx = (0, )
    get_rowid_from_superkey = ibs.get_part_rowids_from_uuid
    part_rowid_list = ibs.db.add_cleanly(const.PART_TABLE, colnames,
                                         params_iter, get_rowid_from_superkey,
                                         superkey_paramx)
    return part_rowid_list
示例#8
0
def ingest_serengeti_mamal_cameratrap(species):
    """
    Downloads data from Serengeti dryad server

    Fetches the Snapshot Serengeti metadata CSVs from Dryad and selects the
    capture events whose consensus species matches ``species``. It then
    downloads those events' images, skipping files already present in
    ``<dbdir>/images``, and adds them to the IBEIS database
    ``<code>_Serengeti`` in the work directory.

    References:
        http://datadryad.org/resource/doi:10.5061/dryad.5pt92
        Swanson AB, Kosmala M, Lintott CJ, Simpson RJ, Smith A, Packer C (2015)
        Snapshot Serengeti, high-frequency annotated camera trap images of 40
        mammalian species in an African savanna. Scientific Data 2: 150026.
        http://dx.doi.org/10.1038/sdata.2015.26
        Swanson AB, Kosmala M, Lintott CJ, Simpson RJ, Smith A, Packer C (2015)
        Data from: Snapshot Serengeti, high-frequency annotated camera trap
        images of 40 mammalian species in an African savanna. Dryad Digital
        Repository. http://dx.doi.org/10.5061/dryad.5pt92

    Args:
        species (str or None): 'zebra_plains', 'cheetah', or None to take
            every capture event listed in the consensus data (code 'ALL').

    Returns:
        ibs: the IBEIS controller opened on the (possibly new) database

    Raises:
        NotImplementedError: for any other species value

    CommandLine:
        python -m ibeis.dbio.ingest_database --test-ingest_serengeti_mamal_cameratrap --species zebra_plains
        python -m ibeis.dbio.ingest_database --test-ingest_serengeti_mamal_cameratrap --species cheetah

    Example:
        >>> # SCRIPT
        >>> from ibeis.dbio.ingest_database import *  # NOQA
        >>> import ibeis
        >>> species = ut.get_argval('--species', type_=str, default=ibeis.const.TEST_SPECIES.ZEB_PLAIN)
        >>> # species = ut.get_argval('--species', type_=str, default='cheetah')
        >>> result = ingest_serengeti_mamal_cameratrap(species)
        >>> print(result)
    """
    import ibeis
    # Imported locally so this function does not rely on the module-level
    # imports providing them (basename in particular).
    from os.path import basename, join

    if species is None:
        code = 'ALL'
    elif species == 'zebra_plains':
        code = 'PZ'
    elif species == 'cheetah':
        code = 'CHTH'
    else:
        raise NotImplementedError()

    # Name of the species as it appears in the Serengeti metadata
    if species == 'zebra_plains':
        serengeti_sepcies = 'zebra'
    else:
        serengeti_sepcies = species

    print('species = %r' % (species,))
    print('serengeti_sepcies = %r' % (serengeti_sepcies,))

    dbname = code + '_Serengeti'
    print('dbname = %r' % (dbname,))
    dbdir = ut.ensuredir(join(ibeis.sysres.get_workdir(), dbname))
    print('dbdir = %r' % (dbdir,))
    image_dir = ut.ensuredir(join(dbdir, 'images'))

    base_url = 'http://datadryad.org/bitstream/handle/10255'
    all_images_url         = base_url + '/dryad.86392/all_images.csv'
    consensus_metadata_url = base_url + '/dryad.86348/consensus_data.csv'
    search_effort_url      = base_url + '/dryad.86347/search_effort.csv'
    gold_standard_url      = base_url + '/dryad.76010/gold_standard_data.csv'

    all_images_fpath         = ut.grab_file_url(all_images_url, download_dir=dbdir)
    consensus_metadata_fpath = ut.grab_file_url(consensus_metadata_url, download_dir=dbdir)
    search_effort_fpath      = ut.grab_file_url(search_effort_url, download_dir=dbdir)
    gold_standard_fpath      = ut.grab_file_url(gold_standard_url, download_dir=dbdir)

    print('all_images_fpath         = %r' % (all_images_fpath,))
    print('consensus_metadata_fpath = %r' % (consensus_metadata_fpath,))
    print('search_effort_fpath      = %r' % (search_effort_fpath,))
    print('gold_standard_fpath      = %r' % (gold_standard_fpath,))

    def read_csv(csv_fpath):
        # Naive CSV reader: splits on ',' and strips quotes / carriage
        # returns; it does not handle commas inside quoted fields.
        csv_text = ut.read_from(csv_fpath)
        csv_lines = csv_text.split('\n')
        print(ut.list_str(csv_lines[0:2]))
        csv_data = [[field.strip('"').strip('\r') for field in line.split(',')]
                    for line in csv_lines if len(line) > 0]
        csv_header = csv_data[0]
        csv_data = csv_data[1:]
        return csv_data, csv_header

    def download_image_urls(image_url_info_list):
        # Find ones that we already have
        print('Requested %d downloaded images' % (len(image_url_info_list),))
        full_gpath_list = [join(image_dir, basename(gpath)) for gpath in image_url_info_list]
        exists_list = [ut.checkpath(gpath) for gpath in full_gpath_list]
        image_url_info_list_ = ut.compress(image_url_info_list, ut.not_list(exists_list))
        print('Already have %d/%d downloaded images' % (
            len(image_url_info_list) - len(image_url_info_list_), len(image_url_info_list)))
        print('Need to download %d images' % (len(image_url_info_list_),))
        # Download the rest
        imgurl_prefix = 'https://snapshotserengeti.s3.msi.umn.edu/'
        image_url_list = [imgurl_prefix + suffix for suffix in image_url_info_list_]
        for img_url in ut.ProgressIter(image_url_list, lbl='Downloading image'):
            ut.grab_file_url(img_url, download_dir=image_dir)
        return full_gpath_list

    # Data contains information about which events have which animals.
    # (The gold-standard CSV, columns 0 and 2, is an alternative source.)
    species_class_csv_data, species_class_header = read_csv(consensus_metadata_fpath)
    species_class_eventid_list    = ut.get_list_column(species_class_csv_data, 0)
    species_class_species_list    = ut.get_list_column(species_class_csv_data, 7)

    print('serengeti_sepcies_hist = %s' %
          ut.dict_str(ut.dict_hist(species_class_species_list), key_order_metric='val'))

    if serengeti_sepcies is None:
        # 'ALL': every capture event, deduplicated in first-seen order
        # (an event with several species appears on several rows)
        seen_eventids = set()
        chosen_eventid_list = []
        for eventid in species_class_eventid_list:
            if eventid not in seen_eventids:
                seen_eventids.add(eventid)
                chosen_eventid_list.append(eventid)
        print('Number of chosen species:')
    else:
        # Find the events labeled with the requested species
        serengeti_sepcies_set = sorted(list(set(species_class_species_list)))
        assert serengeti_sepcies in serengeti_sepcies_set, 'not a known  seregeti species'
        species_class_chosen_idx_list = ut.list_where(
            [serengeti_sepcies == species_ for species_ in species_class_species_list])
        chosen_eventid_list = ut.take(species_class_eventid_list, species_class_chosen_idx_list)
        print('Number of chosen species:')
        print(' * len(species_class_chosen_idx_list) = %r' % (len(species_class_chosen_idx_list),))
    print(' * len(chosen_eventid_list) = %r' % (len(chosen_eventid_list),))

    # Read info about which events have which images
    images_csv_data, image_csv_header = read_csv(all_images_fpath)
    capture_event_id_list = ut.get_list_column(images_csv_data, 0)
    image_url_info_list = ut.get_list_column(images_csv_data, 1)
    # Group photos by eventid
    eventid_to_photos = ut.group_items(image_url_info_list, capture_event_id_list)

    # Filter to only chosens
    unflat_chosen_url_infos = ut.dict_take(eventid_to_photos, chosen_eventid_list)
    chosen_url_infos = ut.flatten(unflat_chosen_url_infos)
    chosen_path_list = download_image_urls(chosen_url_infos)

    ibs = ibeis.opendb(dbdir=dbdir, allow_newdir=True)
    gid_list_ = ibs.add_images(chosen_path_list, auto_localize=False)  # NOQA
    # TODO: automatically detect annotations for the added images
    return ibs