Example #1
File: utils.py Project: nullaegy/pyFF
class LRUProxyDict(MutableMapping):

    def __init__(self, proxy, *args, **kwargs):
        self._proxy = proxy
        self._cache = LRUCache(**kwargs)

    def __contains__(self, item):
        return item in self._cache or item in self._proxy

    def __getitem__(self, item):
        if item is None:
            raise ValueError("None key")
        v = self._cache.get(item, None)
        if v is not None:
            return v
        v = self._proxy.get(item, None)
        if v is not None:
            self._cache[item] = v
        return v

    def __setitem__(self, key, value):
        self._proxy[key] = value
        self._cache[key] = value

    def __delitem__(self, key):
        self._proxy.pop(key, None)
        self._cache.pop(key, None)

    def __iter__(self):
        return self._proxy.__iter__()

    def __len__(self):
        return len(self._proxy)
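
A minimal usage sketch for the class above, assuming LRUCache comes from cachetools and the proxied object is any ordinary mapping (neither import is shown in the excerpt); note that missing keys come back as None rather than raising KeyError:

from cachetools import LRUCache   # assumed source of LRUCache

backing = {'a': 1, 'b': 2}                 # stands in for a slower proxied mapping
d = LRUProxyDict(backing, maxsize=2)       # extra kwargs are forwarded to LRUCache

print('a' in d)        # True -- checked against the cache, then the proxy
print(d['a'])          # 1 -- fetched from the proxy and now cached
d['c'] = 3             # writes go to both the proxy and the cache
del d['b']             # deletes tolerate keys missing from either layer
print(len(d), d['b'])  # 2 None -- length mirrors the proxy; misses return None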
Example #2
class ShorteningCache:
    def __init__(self):
        self.clear()
    def clear(self):  # Used pretty much only for setup and for tests
        self.cache = LRUCache(CACHE_SIZE)
    def put(self, url):
        self.cache[url.shortname] = url
    def get(self, shortname):
        return self.cache.get(shortname)
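
A short usage sketch; CACHE_SIZE, the LRUCache import, and a url object with a shortname attribute are all assumptions, since the excerpt does not define them:

from collections import namedtuple
from cachetools import LRUCache   # assumed source of LRUCache

CACHE_SIZE = 1024                                       # assumed constant
ShortURL = namedtuple('ShortURL', 'shortname target')   # hypothetical record type

cache = ShorteningCache()
cache.put(ShortURL('abc123', 'https://example.com/some/long/path'))
print(cache.get('abc123').target)   # https://example.com/some/long/path
print(cache.get('missing'))         # None -- LRUCache.get defaults to None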
Example #3
File: cache.py Project: kurtrwall/dorthy
class LRULocalBackend(CacheBackend):
    def __init__(self, arguments):
        maxsize = arguments.get("maxsize", 1024)
        ttl = arguments.get("ttl", None)
        if ttl:
            self.__cache = TTLCache(maxsize, ttl=ttl)
        else:
            self.__cache = LRUCache(maxsize)

    def get(self, key):
        return self.__cache.get(key, NO_VALUE)

    def set(self, key, value):
        self.__cache[key] = value

    def delete(self, key):
        del self.__cache[key]
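
A usage sketch for the backend above. NO_VALUE, the CacheBackend base class, and the cachetools imports are not part of the excerpt, so they are stated here as assumptions:

from cachetools import LRUCache, TTLCache   # assumed source of both cache classes

NO_VALUE = object()   # assumed miss sentinel (dogpile.cache-style)

# passing "ttl" selects a TTLCache; omit it to get a plain LRUCache instead
backend = LRULocalBackend({'maxsize': 256, 'ttl': 30})
backend.set('user:1', {'name': 'alice'})
print(backend.get('user:1'))                 # {'name': 'alice'}
print(backend.get('user:2') is NO_VALUE)     # True -- misses return the sentinel
backend.delete('user:1')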
Example #4
File: hpfeeds.py Project: Jc2k/ADBHoney
class HpfeedsLogger:
    """
    Log full complete sessions to hpfeeds.
    """
    def __init__(self, host, port, ident, secret, ssl=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.sessions = LRUCache(1000)
        self.session = ClientSession(host, port, ident, secret, ssl)
        self.exit_stack = AsyncExitStack()

    async def __aenter__(self):
        await self.exit_stack.enter_async_context(self.session)
        return super().__aenter__()

    def log(self, event):
        session_id = event['session_id']

        if event['type'] == 'adbhoney.session.connect':
            self.sessions[session_id] = {
                'src_ip': event['src_ip'],
                'src_port': event['src_port'],
                'dst_ip': event['dst_ip'],
                'dst_port': event['dst_port'],
                'sensor': event['sensor'],
                'shasum': [],
            }
            return

        session = self.sessions.get(session_id, {})

        if event['type'] == 'adbhoney.session.file_upload':
            session['shasum'].append(event['shasum'])

        elif event['type'] == 'adbhoney.session.closed':
            session.update({
                'closedmessage': event['closedmessage'],
                'duration': event['duration'],
            })

            try:
                self.publish('adbhoney', json.dumps(event))
            finally:
                self.sessions.pop(session_id, None)
Example #5
class LRUEngineCache(DataKeyedCache):
    """!
     An instance of DataKeyed cache using a least recently used (LRU) method
    """
    def __init__(self, size=1000):
        """!
        Constructor for LRUCache
        @type size: int
        @param size: maximum entries in the cache
        """

        self.cache = LRUCache(maxsize=size)

    def get_cache_value(self, cache_key):

        return self.cache.get(cache_key)

    def set_cache_value(self, key, value):
        self.cache[key] = value
Example #6
class ResolvedContext:
    """
    A cached context document, with a cache indexed by referencing active context.
    """
    def __init__(self, document):
        """
        Creates a ResolvedContext with caching for processed contexts
        relative to some other Active Context.
        """
        # processor-specific RDF parsers
        self.document = document
        self.cache = LRUCache(maxsize=MAX_ACTIVE_CONTEXTS)

    def get_processed(self, active_ctx):
        """
        Returns any processed context for this resolved context relative to an active context.
        """
        return self.cache.get(active_ctx['_uuid'])

    def set_processed(self, active_ctx, processed_ctx):
        """
        Sets any processed context for this resolved context relative to an active context.
        """
        self.cache[active_ctx['_uuid']] = processed_ctx
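
A small sketch of the class above; MAX_ACTIVE_CONTEXTS and the cachetools import are assumptions (the excerpt does not show them), and the active context is reduced to a bare dict carrying the '_uuid' key the cache relies on:

import uuid
from cachetools import LRUCache   # assumed source of LRUCache

MAX_ACTIVE_CONTEXTS = 10          # assumed constant

resolved = ResolvedContext({'@context': {'name': 'http://schema.org/name'}})
active_ctx = {'_uuid': str(uuid.uuid4())}   # each active context gets its own cache slot

print(resolved.get_processed(active_ctx))   # None -- nothing cached yet
resolved.set_processed(active_ctx, {'mappings': {'name': 'http://schema.org/name'}})
print(resolved.get_processed(active_ctx))   # the processed context comes back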
Example #7
class ProxyResolver(BaseResolver):
    _Record = recordclass('_Record', 'expire data')

    def __init__(self, upstreams, cache_size=None):
        super().__init__()
        self._upstreams = upstreams

        self._cache = LRUCache(cache_size) if cache_size else None
        self._cache_lock = RLock()

    def _query_cache(self, key):
        # Guard for the cache being disabled (no cache_size given), mirroring _add_to_cache below
        if self._cache is None:
            return []

        with self._cache_lock:
            records = self._cache.get(key, None)
            if records is None:
                records = []
            else:
                records = [record for record in records if record.expire > time()]
                if len(records):
                    records = records[1:] + [records[0]]
                    self._cache[key] = records
                else:
                    del self._cache[key]
            return records

    def _add_to_cache(self, key, record):
        if self._cache is None:
            return

        with self._cache_lock:
            records = self._cache.get(key, None)
            if records is None:
                self._cache[key] = [record]
            else:
                for erecord in records:
                    if erecord.data == record.data:
                        erecord.expire = record.expire
                        break
                else:
                    records.append(record)

    def _resolve_in_cache(self, questions, oq, oa, now):
        for q in questions:
            key = (q.qname, QTYPE.CNAME, q.qclass)
            cnames = self._query_cache(key)
            if len(cnames):
                recursive_questions = []
                for cname in cnames:
                    oa.add_answer(RR(ttl=max(cname.expire - now, 0), **cname.data))
                    recursive_questions.append(DNSQuestion(
                        qname=cname.data['rdata'].label,
                        qtype=q.qtype,
                        qclass=q.qclass
                    ))

                    self._resolve_in_cache(recursive_questions, oq, oa, now)
            else:
                if q.qtype != QTYPE.CNAME:
                    key = (q.qname, q.qtype, q.qclass)
                    record_list = self._query_cache(key)
                    if len(record_list):
                        for record in record_list:
                            oa.add_answer(RR(ttl=max(record.expire - now, 0), **record.data))
                    else:
                        oq.add_question(q)
                else:
                    oq.add_question(q)

    def resolve(self, request, handler):
        now = int(time())
        a = request.reply()

        uq = DNSRecord()
        self._resolve_in_cache(request.questions, uq, a, now)

        if len(uq.questions):
            for upstream in self._upstreams:
                try:
                    ua_pkt = uq.send(
                        str(upstream.address),
                        upstream.port,
                        upstream.tcp,
                        upstream.timeout,
                        upstream.ipv6
                    )
                    ua = DNSRecord.parse(ua_pkt)
                except Exception:
                    continue

                for rr in ua.rr:
                    key = (rr.rname, rr.rtype, rr.rclass)
                    cr = self._Record(now + rr.ttl, {
                        'rname': rr.rname,
                        'rtype': rr.rtype,
                        'rclass': rr.rclass,
                        'rdata': rr.rdata,
                    })
                    self._add_to_cache(key, cr)
                a.add_answer(*ua.rr)
                break
            else:
                raise IOError

        return a
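
A sketch that exercises only the private cache helpers above, with no DNS traffic; it assumes the excerpt's dependencies (dnslib's BaseResolver, recordclass, cachetools) are importable and uses illustrative key and record values:

from time import time

# an upstream list is required by __init__ but never touched by the cache helpers
resolver = ProxyResolver(upstreams=[], cache_size=256)

key = ('example.com.', 1, 1)                                  # (qname, qtype, qclass) -- illustrative
stale = resolver._Record(int(time()) - 60, {'rdata': 'old'})  # already expired
fresh = resolver._Record(int(time()) + 60, {'rdata': 'new'})

resolver._add_to_cache(key, stale)
resolver._add_to_cache(key, fresh)
print(resolver._query_cache(key))   # the expired record is dropped, the live one is kept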
Example #8
class MetricCassandraRepository(abstract_repository.AbstractCassandraRepository):
    def __init__(self):
        super(MetricCassandraRepository, self).__init__()

        self._lock = threading.RLock()

        LOG.debug("prepare cql statements...")

        self._measurement_insert_stmt = self._session.prepare(MEASUREMENT_INSERT_CQL)
        self._measurement_insert_stmt.is_idempotent = True

        self._measurement_update_stmt = self._session.prepare(MEASUREMENT_UPDATE_CQL)
        self._measurement_update_stmt.is_idempotent = True

        self._metric_insert_stmt = self._session.prepare(METRICS_INSERT_CQL)
        self._metric_insert_stmt.is_idempotent = True

        self._metric_update_stmt = self._session.prepare(METRICS_UPDATE_CQL)
        self._metric_update_stmt.is_idempotent = True

        self._dimension_stmt = self._session.prepare(DIMENSION_INSERT_CQL)
        self._dimension_stmt.is_idempotent = True

        self._dimension_metric_stmt = self._session.prepare(DIMENSION_METRIC_INSERT_CQL)
        self._dimension_metric_stmt.is_idempotent = True

        self._metric_dimension_stmt = self._session.prepare(METRIC_DIMENSION_INSERT_CQL)
        self._metric_dimension_stmt.is_idempotent = True

        self._retrieve_metric_dimension_stmt = self._session.prepare(RETRIEVE_METRIC_DIMENSION_CQL)

        self._metric_batch = MetricBatch(
            self._cluster.metadata,
            self._cluster.load_balancing_policy,
            self._max_batches)

        self._metric_id_cache = LRUCache(self._cache_size)
        self._dimension_cache = LRUCache(self._cache_size)
        self._metric_dimension_cache = LRUCache(self._cache_size)

        self._load_dimension_cache()
        self._load_metric_dimension_cache()

    def process_message(self, message):
        (dimensions, metric_name, region, tenant_id, time_stamp, value,
         value_meta) = parse_measurement_message(message)

        with self._lock:
            dim_names = []
            dim_list = []
            for name in sorted(dimensions.iterkeys()):
                dim_list.append('%s\t%s' % (name, dimensions[name]))
                dim_names.append(name)

            hash_string = '%s\0%s\0%s\0%s' % (region, tenant_id, metric_name, '\0'.join(dim_list))
            metric_id = hashlib.sha1(hash_string.encode('utf8')).hexdigest()

            metric = Metric(id=metric_id,
                            region=region,
                            tenant_id=tenant_id,
                            name=metric_name,
                            dimension_list=dim_list,
                            dimension_names=dim_names,
                            time_stamp=time_stamp,
                            value=value,
                            value_meta=json.dumps(value_meta, ensure_ascii=False))

            id_bytes = bytearray.fromhex(metric.id)
            if self._metric_id_cache.get(metric.id, None):
                measurement_bound_stmt = self._measurement_update_stmt.bind((self._retention,
                                                                             metric.value,
                                                                             metric.value_meta,
                                                                             id_bytes,
                                                                             metric.time_stamp))
                self._metric_batch.add_measurement_query(measurement_bound_stmt)

                metric_update_bound_stmt = self._metric_update_stmt.bind((self._retention,
                                                                          metric.time_stamp,
                                                                          metric.region,
                                                                          metric.tenant_id,
                                                                          metric.name,
                                                                          metric.dimension_list,
                                                                          metric.dimension_names))
                self._metric_batch.add_metric_query(metric_update_bound_stmt)

                return metric

            self._metric_id_cache[metric.id] = metric.id

            metric_insert_bound_stmt = self._metric_insert_stmt.bind((self._retention,
                                                                      id_bytes,
                                                                      metric.time_stamp,
                                                                      metric.time_stamp,
                                                                      metric.region,
                                                                      metric.tenant_id,
                                                                      metric.name,
                                                                      metric.dimension_list,
                                                                      metric.dimension_names))
            self._metric_batch.add_metric_query(metric_insert_bound_stmt)

            for dim in metric.dimension_list:
                (name, value) = dim.split('\t')
                dim_key = self._get_dimnesion_key(metric.region, metric.tenant_id, name, value)
                if not self._dimension_cache.get(dim_key, None):
                    dimension_bound_stmt = self._dimension_stmt.bind((metric.region,
                                                                      metric.tenant_id,
                                                                      name,
                                                                      value))
                    self._metric_batch.add_dimension_query(dimension_bound_stmt)
                    self._dimension_cache[dim_key] = dim_key

                metric_dim_key = self._get_metric_dimnesion_key(
                    metric.region, metric.tenant_id, metric.name, name, value)
                if not self._metric_dimension_cache.get(metric_dim_key, None):
                    dimension_metric_bound_stmt = self._dimension_metric_stmt.bind(
                        (metric.region, metric.tenant_id, name, value, metric.name))
                    self._metric_batch.add_dimension_metric_query(dimension_metric_bound_stmt)

                    metric_dimension_bound_stmt = self._metric_dimension_stmt.bind(
                        (metric.region, metric.tenant_id, metric.name, name, value))
                    self._metric_batch.add_metric_dimension_query(metric_dimension_bound_stmt)

                    self._metric_dimension_cache[metric_dim_key] = metric_dim_key

            measurement_insert_bound_stmt = self._measurement_insert_stmt.bind(
                (self._retention,
                 metric.value,
                 metric.value_meta,
                 metric.region,
                 metric.tenant_id,
                 metric.name,
                 metric.dimension_list,
                 id_bytes,
                 metric.time_stamp))
            self._metric_batch.add_measurement_query(measurement_insert_bound_stmt)

            return metric

    def write_batch(self, metrics):

        with self._lock:
            batch_list = self._metric_batch.get_all_batches()

            results = execute_concurrent(self._session, batch_list, raise_on_first_error=True)

            self._handle_results(results)

            self._metric_batch.clear()

            LOG.info("flushed %s metrics", len(metrics))

    @staticmethod
    def _handle_results(results):
        for (success, result) in results:
            if not success:
                raise result

    def _load_dimension_cache(self):

        rows = self._session.execute(RETRIEVE_DIMENSION_CQL)

        if not rows:
            return

        for row in rows:
            key = self._get_dimnesion_key(row.region, row.tenant_id, row.name, row.value)
            self._dimension_cache[key] = key

        LOG.info(
            "loaded %s dimension entries from database into cache." %
            self._dimension_cache.currsize)

    @staticmethod
    def _get_dimnesion_key(region, tenant_id, name, value):
        return '%s\0%s\0%s\0%s' % (region, tenant_id, name, value)

    def _load_metric_dimension_cache(self):
        qm = token_range_query_manager.TokenRangeQueryManager(RETRIEVE_METRIC_DIMENSION_CQL,
                                                              self._process_metric_dimension_query)

        token_ring = self._cluster.metadata.token_map.ring

        qm.query(token_ring)

    def _process_metric_dimension_query(self, rows):

        cnt = 0
        for row in rows:
            key = self._get_metric_dimnesion_key(
                row.region,
                row.tenant_id,
                row.metric_name,
                row.dimension_name,
                row.dimension_value)
            self._metric_dimension_cache[key] = key
            cnt += 1

        LOG.info("loaded %s metric dimension entries from database into cache." % cnt)
        LOG.info(
            "total loaded %s metric dimension entries in cache." %
            self._metric_dimension_cache.currsize)

    @staticmethod
    def _get_metric_dimnesion_key(region, tenant_id, metric_name, dimension_name, dimension_value):

        return '%s\0%s\0%s\0%s\0%s' % (region, tenant_id, metric_name,
                                       dimension_name, dimension_value)
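
The metric id used with _metric_id_cache above is just a SHA-1 over a null-delimited (region, tenant, name, dimensions) string; the hashing can be sketched standalone, with illustrative values and no Cassandra involved:

import hashlib

dimensions = {'hostname': 'web-01', 'service': 'api'}
dim_list = ['%s\t%s' % (name, dimensions[name]) for name in sorted(dimensions)]
hash_string = '%s\0%s\0%s\0%s' % ('region-1', 'tenant-a', 'cpu.idle', '\0'.join(dim_list))
metric_id = hashlib.sha1(hash_string.encode('utf8')).hexdigest()
print(metric_id)   # stable 40-char hex id; a hit in _metric_id_cache lets the repository skip re-inserting metric rows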
Example #9
class GMusicLibraryProvider(backend.LibraryProvider):
    root_directory = Ref.directory(uri='gmusic:directory', name='Google Music')

    def __init__(self, *args, **kwargs):
        super(GMusicLibraryProvider, self).__init__(*args, **kwargs)

        # tracks, albums, and artists here refer to what is explicitly
        # in our library.
        self.tracks = {}
        self.albums = {}
        self.artists = {}

        # aa_* caches are *only* used for temporary objects. Library
        # objects will never make it here.
        self.aa_artists = LRUCache(1024)
        self.aa_tracks = LRUCache(1024)
        self.aa_albums = LRUCache(1024)

        self._radio_stations_in_browse = (
            self.backend.config['gmusic']['radio_stations_in_browse'])
        self._radio_stations_count = (
            self.backend.config['gmusic']['radio_stations_count'])
        self._radio_tracks_count = (
            self.backend.config['gmusic']['radio_tracks_count'])

        self._top_tracks_count = (
            self.backend.config['gmusic']['top_tracks_count'])

        # Setup the root of library browsing.
        self._root = [
            Ref.directory(uri='gmusic:album', name='Albums'),
            Ref.directory(uri='gmusic:artist', name='Artists'),
            Ref.directory(uri='gmusic:track', name='Tracks')
        ]

        if self._radio_stations_in_browse:
            self._root.append(Ref.directory(uri='gmusic:radio',
                                            name='Radios'))

    @property
    def all_access(self):
        return self.backend.session.all_access

    def _browse_tracks(self):
        tracks = list(self.tracks.values())
        tracks.sort(key=lambda ref: ref.name)
        refs = []
        for track in tracks:
            refs.append(track_to_ref(track))
        return refs

    def _browse_albums(self):
        refs = []
        for album in self.albums.values():
            refs.append(album_to_ref(album))
        refs.sort(key=lambda ref: ref.name)
        return refs

    def _browse_album(self, uri):
        refs = []
        for track in self._lookup_album(uri):
            refs.append(track_to_ref(track, True))
        return refs

    def _browse_artists(self):
        refs = []
        for artist in self.artists.values():
            refs.append(artist_to_ref(artist))
        refs.sort(key=lambda ref: ref.name)
        return refs

    def _browse_artist(self, uri):
        refs = []
        for album in self._get_artist_albums(uri):
            refs.append(album_to_ref(album))
        refs.sort(key=lambda ref: ref.name)
        if len(refs) > 0:
            refs.insert(0, Ref.directory(uri=uri + ':all', name='All Tracks'))
            is_all_access = uri.startswith('gmusic:artist:A')
            if is_all_access:
                refs.insert(1, Ref.directory(uri=uri + ':top', name='Top Tracks'))
            return refs
        else:
            # Show all tracks if no album is available
            return self._browse_artist_all_tracks(uri)

    def _browse_artist_all_tracks(self, uri):
        artist_uri = ':'.join(uri.split(':')[:3])
        refs = []
        tracks = self._lookup_artist(artist_uri, True)
        for track in tracks:
            refs.append(track_to_ref(track))
        return refs

    def _browse_artist_top_tracks(self, uri):
        artist_uri = ':'.join(uri.split(':')[:3])
        refs = []
        tracks = self._get_artist_top_tracks(artist_uri)
        for track in tracks:
            refs.append(track_to_ref(track))
        return refs

    def _browse_radio_stations(self, uri):
        stations = self.backend.session.get_radio_stations(
            self._radio_stations_count)
        # create Ref objects
        refs = []
        for station in stations:
            refs.append(Ref.directory(uri='gmusic:radio:' + station['id'],
                                      name=station['name']))
        return refs

    def _browse_radio_station(self, uri):
        station_id = uri.split(':')[2]
        tracks = self.backend.session.get_station_tracks(
            station_id, self._radio_tracks_count)

        # create Ref objects
        refs = []
        for track in tracks:
            mopidy_track = self._to_mopidy_track(track)
            self.aa_tracks[mopidy_track.uri] = mopidy_track
            refs.append(track_to_ref(mopidy_track))
        return refs

    def browse(self, uri):
        logger.debug('browse: %s', str(uri))
        if not uri:
            return []
        if uri == self.root_directory.uri:
            return self._root

        parts = uri.split(':')

        # tracks
        if uri == 'gmusic:track':
            return self._browse_tracks()

        # albums
        if uri == 'gmusic:album':
            return self._browse_albums()

        # a single album
        # uri == 'gmusic:album:album_id'
        if len(parts) == 3 and parts[1] == 'album':
            return self._browse_album(uri)

        # artists
        if uri == 'gmusic:artist':
            return self._browse_artists()

        # a single artist
        # uri == 'gmusic:artist:artist_id'
        if len(parts) == 3 and parts[1] == 'artist':
            return self._browse_artist(uri)

        # all tracks of a single artist
        # uri == 'gmusic:artist:artist_id:all'
        if len(parts) == 4 and parts[1] == 'artist' and parts[3] == 'all':
            return self._browse_artist_all_tracks(uri)

        # top tracks of a single artist
        # uri == 'gmusic:artist:artist_id:top'
        if len(parts) == 4 and parts[1] == 'artist' and parts[3] == 'top':
            return self._browse_artist_top_tracks(uri)

        # all radio stations
        if uri == 'gmusic:radio':
            return self._browse_radio_stations(uri)

        # a single radio station
        # uri == 'gmusic:radio:station_id'
        if len(parts) == 3 and parts[1] == 'radio':
            return self._browse_radio_station(uri)

        logger.debug('Unknown uri for browse request: %s', uri)

        return []

    def lookup(self, uri):
        if uri.startswith('gmusic:track:'):
            return self._lookup_track(uri)
        elif uri.startswith('gmusic:album:'):
            return self._lookup_album(uri)
        elif uri.startswith('gmusic:artist:'):
            return self._lookup_artist(uri)
        else:
            return []

    def _lookup_track(self, uri):
        is_all_access = uri.startswith('gmusic:track:T')

        try:
            return [self.tracks[uri]]
        except KeyError:
            logger.debug('Track not a library track %r', uri)
            pass

        if is_all_access and self.all_access:
            track = self.aa_tracks.get(uri)
            if track:
                return [track]
            song = self.backend.session.get_track_info(uri.split(':')[2])
            if song is None:
                logger.warning('There is no song %r', uri)
                return []
            if 'artistId' not in song:
                logger.warning('Failed to lookup %r', uri)
                return []
            mopidy_track = self._to_mopidy_track(song)
            self.aa_tracks[mopidy_track.uri] = mopidy_track
            return [mopidy_track]
        else:
            return []

    def _lookup_album(self, uri):
        is_all_access = uri.startswith('gmusic:album:B')
        if self.all_access and is_all_access:
            tracks = self.aa_albums.get(uri)
            if tracks:
                return tracks
            album = self.backend.session.get_album_info(
                uri.split(':')[2], include_tracks=True)
            if album and album.get('tracks'):
                tracks = [self._to_mopidy_track(track)
                          for track in album['tracks']]
                for track in tracks:
                    self.aa_tracks[track.uri] = track
                tracks = sorted(tracks, key=lambda t: (t.disc_no, t.track_no))
                self.aa_albums[uri] = tracks
                return tracks

            logger.warning('Failed to lookup all access album %r: %r',
                           uri, album)

        # Even if the album has an all access ID, we need to look it
        # up here (as a fallback) because purchased tracks can have a
        # store ID, but only show up in your library.
        try:
            album = self.albums[uri]
        except KeyError:
            logger.debug('Failed to lookup %r', uri)
            return []

        tracks = self._find_exact(
            dict(album=album.name,
                 artist=[artist.name for artist in album.artists],
                 date=album.date)).tracks
        return sorted(tracks, key=lambda t: (t.disc_no, t.track_no))

    def _get_artist_top_tracks(self, uri):
        is_all_access = uri.startswith('gmusic:artist:A')
        artist_id = uri.split(':')[2]

        if not is_all_access:
            logger.debug("Top Tracks not available for non-all-access artists")
            return []

        artist_info = self.backend.session.get_artist_info(artist_id,
                                                           include_albums=False,
                                                           max_top_tracks=self._top_tracks_count,
                                                           max_rel_artist=0)
        top_tracks = []

        for track_dict in artist_info['topTracks']:
            top_tracks.append(self._to_mopidy_track(track_dict))

        return top_tracks

    def _get_artist_albums(self, uri):
        is_all_access = uri.startswith('gmusic:artist:A')

        artist_id = uri.split(':')[2]
        if is_all_access:
            # all access
            artist_infos = self.backend.session.get_artist_info(
                artist_id, max_top_tracks=0, max_rel_artist=0)
            if artist_infos is None or 'albums' not in artist_infos:
                return []
            albums = []
            for album in artist_infos['albums']:
                albums.append(
                    self._aa_search_album_to_mopidy_album({'album': album}))
            return albums
        elif self.all_access and artist_id in self.aa_artists:
            albums = self._get_artist_albums(
                'gmusic:artist:%s' % self.aa_artists[artist_id])
            if len(albums) > 0:
                return albums
            # else fall back to non aa albums
        if uri in self.artists:
            artist = self.artists[uri]
            return [album for album in self.albums.values()
                    if artist in album.artists]
        else:
            logger.debug('0 albums available for artist %r', uri)
            return []

    def _lookup_artist(self, uri, exact_match=False):
        def sorter(track):
            return (
                track.album.date,
                track.album.name,
                track.disc_no,
                track.track_no,
            )

        if self.all_access:
            try:
                all_access_id = self.aa_artists[uri.split(':')[2]]
                artist_infos = self.backend.session.get_artist_info(
                    all_access_id, max_top_tracks=0, max_rel_artist=0)
                if not artist_infos or not artist_infos['albums']:
                    logger.warning('Failed to lookup %r', artist_infos)
                tracks = [
                    self._lookup_album('gmusic:album:' + album['albumId'])
                    for album in artist_infos['albums']]
                tracks = reduce(lambda a, b: (a + b), tracks)
                return sorted(tracks, key=sorter)
            except KeyError:
                pass
        try:
            artist = self.artists[uri]
        except KeyError:
            logger.debug('Failed to lookup %r', uri)
            return []

        tracks = self._find_exact(
            dict(artist=artist.name)).tracks
        if exact_match:
            tracks = filter(lambda t: artist in t.artists, tracks)
        return sorted(tracks, key=sorter)

    def refresh(self, uri=None):
        self.tracks = {}
        self.albums = {}
        self.artists = {}

        album_tracks = {}
        for track in self.backend.session.get_all_songs():
            mopidy_track = self._to_mopidy_track(track)

            self.tracks[mopidy_track.uri] = mopidy_track
            self.albums[mopidy_track.album.uri] = mopidy_track.album

            # We don't care about the order because we're just using
            # this as a temporary variable to grab the proper album
            # artist out of the album.
            if mopidy_track.album.uri not in album_tracks:
                album_tracks[mopidy_track.album.uri] = []

            album_tracks[mopidy_track.album.uri].append(mopidy_track)

        # Yes, this is awful. No, I don't have a better solution. Yes,
        # I'm annoyed at Google for not providing album artist IDs.
        for album in self.albums.values():
            artist_found = False
            for album_artist in album.artists:
                for track in album_tracks[album.uri]:
                    for artist in track.artists:
                        if album_artist.name == artist.name:
                            artist_found = True
                            self.artists[artist.uri] = artist

            if not artist_found:
                for artist in album.artists:
                    self.artists[artist.uri] = artist

    def search(self, query=None, uris=None, exact=False):
        if exact:
            return self._find_exact(query=query, uris=uris)

        lib_tracks, lib_artists, lib_albums = self._search_library(query, uris)

        if query:
            aa_tracks, aa_artists, aa_albums = self._search(query, uris)
            for aa_artist in aa_artists:
                lib_artists.add(aa_artist)

            for aa_album in aa_albums:
                lib_albums.add(aa_album)

            lib_tracks = set(lib_tracks)

            for aa_track in aa_tracks:
                lib_tracks.add(aa_track)

        return SearchResult(uri='gmusic:search',
                            tracks=lib_tracks,
                            artists=lib_artists,
                            albums=lib_albums)

    def _find_exact(self, query=None, uris=None):
        # Find exact can only be done on gmusic library,
        # since one can't filter all access searches
        lib_tracks, lib_artists, lib_albums = self._search_library(query, uris)

        return SearchResult(uri='gmusic:search',
                            tracks=lib_tracks,
                            artists=lib_artists,
                            albums=lib_albums)

    def _search(self, query=None, uris=None):
        for (field, values) in query.iteritems():
            if not hasattr(values, '__iter__'):
                values = [values]

            # Since gmusic does not support search filters, just search for the
            # first 'searchable' filter
            if field in [
                    'track_name', 'album', 'artist', 'albumartist', 'any']:
                logger.info(
                    'Searching Google Play Music for: %s',
                    values[0])
                res = self.backend.session.search(values[0], max_results=50)
                if res is None:
                    return [], [], []

                albums = [
                    self._aa_search_album_to_mopidy_album(album_res)
                    for album_res in res['album_hits']]
                artists = [
                    self._aa_search_artist_to_mopidy_artist(artist_res)
                    for artist_res in res['artist_hits']]
                tracks = [
                    self._aa_search_track_to_mopidy_track(track_res)
                    for track_res in res['song_hits']]

                return tracks, artists, albums

        return [], [], []

    def _search_library(self, query=None, uris=None):
        if query is None:
            query = {}
        self._validate_query(query)
        result_tracks = self.tracks.values()

        for (field, values) in query.iteritems():
            if not hasattr(values, '__iter__'):
                values = [values]
            # FIXME this is bound to be slow for large libraries
            for value in values:
                if field == 'track_no':
                    q = self._convert_to_int(value)
                else:
                    q = value.strip().lower()

                def uri_filter(track):
                    return q in track.uri.lower()

                def track_name_filter(track):
                    return q in track.name.lower()

                def album_filter(track):
                    return q in getattr(track, 'album', Album()).name.lower()

                def artist_filter(track):
                    return (
                        any(q in a.name.lower() for a in track.artists) or
                        albumartist_filter(track))

                def albumartist_filter(track):
                    album_artists = getattr(track, 'album', Album()).artists
                    return any(q in a.name.lower() for a in album_artists)

                def track_no_filter(track):
                    return track.track_no == q

                def date_filter(track):
                    return track.date and track.date.startswith(q)

                def any_filter(track):
                    return any([
                        uri_filter(track),
                        track_name_filter(track),
                        album_filter(track),
                        artist_filter(track),
                        albumartist_filter(track),
                        date_filter(track),
                    ])

                if field == 'uri':
                    result_tracks = filter(uri_filter, result_tracks)
                elif field == 'track_name':
                    result_tracks = filter(track_name_filter, result_tracks)
                elif field == 'album':
                    result_tracks = filter(album_filter, result_tracks)
                elif field == 'artist':
                    result_tracks = filter(artist_filter, result_tracks)
                elif field == 'albumartist':
                    result_tracks = filter(albumartist_filter, result_tracks)
                elif field == 'track_no':
                    result_tracks = filter(track_no_filter, result_tracks)
                elif field == 'date':
                    result_tracks = filter(date_filter, result_tracks)
                elif field == 'any':
                    result_tracks = filter(any_filter, result_tracks)
                else:
                    raise LookupError('Invalid lookup field: %s' % field)

        result_artists = set()
        result_albums = set()
        for track in result_tracks:
            result_artists |= track.artists
            result_albums.add(track.album)

        return result_tracks, result_artists, result_albums

    def _validate_query(self, query):
        for (_, values) in query.iteritems():
            if not values:
                raise LookupError('Missing query')
            for value in values:
                if not value:
                    raise LookupError('Missing query')

    def _to_mopidy_track(self, song):
        track_id = song.get('id', song.get('nid'))
        if track_id is None:
            raise ValueError
        if track_id[0] != "T" and "-" not in track_id:
            track_id = "T"+track_id
        return Track(
            uri='gmusic:track:' + track_id,
            name=song['title'],
            artists=[self._to_mopidy_artist(song)],
            album=self._to_mopidy_album(song),
            track_no=song.get('trackNumber', 1),
            disc_no=song.get('discNumber', 1),
            date=unicode(song.get('year', 0)),
            length=int(song['durationMillis']),
            bitrate=320)

    def _to_mopidy_album(self, song):
        name = song.get('album', '')
        artist = self._to_mopidy_album_artist(song)
        date = unicode(song.get('year', 0))

        album_id = song.get('albumId')
        if album_id is None:
            album_id = create_id(artist.name + name + date)

        uri = 'gmusic:album:' + album_id
        images = get_images(song)
        return Album(
            uri=uri,
            name=name,
            artists=[artist],
            num_tracks=song.get('totalTrackCount'),
            num_discs=song.get('totalDiscCount'),
            date=date,
            images=images)

    def _to_mopidy_artist(self, song):
        name = song.get('artist', '')
        artist_id = song.get('artistId')
        if artist_id is not None:
            artist_id = artist_id[0]
        else:
            artist_id = create_id(name)
        uri = 'gmusic:artist:' + artist_id
        return Artist(uri=uri, name=name)

    def _to_mopidy_album_artist(self, song):
        name = song.get('albumArtist', '')
        if name.strip() == '':
            name = song.get('artist', '')
        uri = 'gmusic:artist:' + create_id(name)
        return Artist(uri=uri, name=name)

    def _aa_search_track_to_mopidy_track(self, search_track):
        track = search_track['track']

        aa_artist_id = create_id(track['artist'])
        if 'artistId' in track:
            aa_artist_id = track['artistId'][0]
        else:
            logger.warning('No artistId for Track %r', track)

        artist = Artist(
            uri='gmusic:artist:' + aa_artist_id,
            name=track['artist'])

        album = Album(
            uri='gmusic:album:' + track['albumId'],
            name=track['album'],
            artists=[artist],
            date=unicode(track.get('year', 0)))

        return Track(
            uri='gmusic:track:' + track['storeId'],
            name=track['title'],
            artists=[artist],
            album=album,
            track_no=track.get('trackNumber', 1),
            disc_no=track.get('discNumber', 1),
            date=unicode(track.get('year', 0)),
            length=int(track['durationMillis']),
            bitrate=320)

    def _aa_search_artist_to_mopidy_artist(self, search_artist):
        artist = search_artist['artist']
        uri = 'gmusic:artist:' + artist['artistId']
        return Artist(uri=uri, name=artist['name'])

    def _aa_search_album_to_mopidy_album(self, search_album):
        album = search_album['album']
        uri = 'gmusic:album:' + album['albumId']
        name = album['name']
        artist = self._aa_search_artist_album_to_mopidy_artist_album(album)
        date = unicode(album.get('year', 0))
        return Album(
            uri=uri,
            name=name,
            artists=[artist],
            date=date)

    def _aa_search_artist_album_to_mopidy_artist_album(self, album):
        name = album.get('albumArtist', '')
        if name.strip() == '':
            name = album.get('artist', '')
        uri = 'gmusic:artist:' + create_id(name)
        return Artist(uri=uri, name=name)

    def _convert_to_int(self, string):
        try:
            return int(string)
        except ValueError:
            return object()
Example #10
class GoogleHangoutsChatBackend(ErrBot):
    def __init__(self, config):
        super().__init__(config)
        identity = config.BOT_IDENTITY
        self.at_name = config.BOT_PREFIX
        self.creds_file = identity['GOOGLE_CREDS_FILE']
        self.gce_project = identity['GOOGLE_CLOUD_ENGINE_PROJECT']
        self.gce_topic = identity['GOOGLE_CLOUD_ENGINE_PUBSUB_TOPIC']
        self.gce_subscription = identity['GOOGLE_CLOUD_ENGINE_PUBSUB_SUBSCRIPTION']
        self.chat_api = GoogleHangoutsChatAPI(self.creds_file)
        self.bot_identifier = HangoutsChatUser(None, self.at_name, None, None)
        self.event_cache = LRUCache(1024)
        self.md = hangoutschat_markdown_converter()

    def _subscribe_to_pubsub_topic(self, project, topic_name, subscription_name, callback):
        subscriber = pubsub.SubscriberClient()
        subscription_name = 'projects/{}/subscriptions/{}'.format(project, subscription_name)
        log.info("Subscribed to {}".format(subscription_name))
        return subscriber.subscribe(subscription_name, callback=callback)

    def _event_cache_format_key(self, event_data):
        event_time = event_data.get('eventTime', 0)
        if event_time == 0:
            log.warning('Received 0 eventTime from message')

        event_type = event_data.get('type', 'NO_EVENT_TYPE_PROVIDED_BY_GHC')

        space_name = event_data.get('space', {}).get('name', '')

        # eventTime does not change for CARD_CLICKED events generated by multiple clicks on the same card
        # (it seems like it should, but it does not), so message.lastUpdateTime must also be part of the key to avoid missing events
        message_last_update_time = event_data.get('message', {}).get('lastUpdateTime', 'NO_MESSAGE_LASTUPDATETIME_PROVIDED_BY_GHC')

        return "{}{}{}{}".format(event_time, event_type, space_name, message_last_update_time)

    def _handle_event(self, event):
        try:
            data = json.loads(event.data.decode('utf-8'))
        except Exception:
            log.warning('Received malformed event: {}'.format(event.data))
            event.ack()
            return

        event.ack()
        # event.ack() may fail silently, so we should ensure our messages are somewhat idempotent
        event_key = self._event_cache_format_key(data)
        log.info("event received: %s", event_key)
        cached = self.event_cache.get(event_key)
        if cached is not None:
            return
        self.event_cache[event_key] = True

        # https://developers.google.com/chat/reference/message-formats/events
        # https://developers.google.com/chat/how-tos/cards-onclick#receiving_user_click_information
        event_type = data.get('type')
        if event_type == 'MESSAGE':
            return self.handle_event_MESSAGE(data)
        elif event_type == 'CARD_CLICKED':
            return self.handle_event_CARD_CLICKED(data)
        elif data.get('message', {}).get('text') is not None:
            # CARD_CLICKED events do not contain data.message.text field
            log.warn(f"Event type '{event_type}' received, handling as 'MESSAGE' type to support previous backend implementation."
                     "If your code relies on handling this event type as a 'MESSAGE' type, please update your code as this will eventually be deprecated."
                     "You will also need to add an event handler for the specific event type to this backend codebase.")
            return self.handle_event_MESSAGE(data)
        else:
            log.info(f"Unsupported event type '{event_type}' received")
            return

    # https://developers.google.com/chat/how-tos/cards-onclick
    def handle_event_CARD_CLICKED(self, data):
        action_method_name = data.get('action', {}).get('actionMethodName')
        log.info(f"'CARD_CLICKED' event with actionMethodName '{action_method_name}' received")
        action_params = {p['key']: p['value'] for p in data.get('action', {}).get('parameters', [])}

        # this can be extended to handle any arbitrary action_method_names
        if action_method_name == 'bot_command':
            command = action_params.get('command')
            command_args = action_params.get('command_args')
            if command is None:
                log.error("Required 'command' parameter missing for 'bot_command' actionMethodName")
                return
            MESSAGE_data = dict(data)
            MESSAGE_data['message']['text'] = f"{command} {command_args}"
            self.handle_event_MESSAGE(MESSAGE_data)
        else:
            log.info(f"Unsupported CARD_CLICKED event action method name '{action_method_name}' received")

    def handle_event_MESSAGE(self, data):
        # https://developers.google.com/chat/api/guides/message-formats/events#message
        sender_blob = data.get('message', {}).get('sender', {})
        sender = HangoutsChatUser(sender_blob.get('name', ''),
                                  sender_blob.get('displayName', ''),
                                  sender_blob.get('email', ''),
                                  sender_blob.get('type', ''))
        message_body = data['message'].get('text','')
        context = {
            'space_id': data['space']['name'],
            'thread_id': data['message']['thread']['name']
        }

        if 'attachment' in data['message']:
            context['attachment'] = data['message']['attachment']
        # pass httplib2.Http() authenticated handler to errbot. useful to download attachments
        context['downloader'] = self.chat_api._download

        msg = Message(body=message_body.strip(), frm=sender, extras=context)

        is_dm = data['message']['space']['type'] == 'DM'
        if is_dm:
            msg.to = self.bot_identifier

        self.callback_message(msg)

    def _split_message(self, text, maximum_message_length=GoogleHangoutsChatAPI.max_message_length):
        '''
        Splits a given string up into multiple strings all of length less than some maximum size

        Edge Case: We don't handle the case where one line is big enough for a whole message
        '''
        lines = text.split('\n')
        messages = []
        current_message = ''
        for line in lines:
            if len(current_message) + len(line) + 1 > maximum_message_length:
                messages.append(current_message)
                current_message = line + '\n'
            else:
                current_message += line + '\n'

        messages.append(current_message)
        return messages

    def prep_message_context(self, message):
        space_id = message.extras.get('space_id', None)
        thread_id = message.extras.get('thread_id', None)
        thread_key = message.extras.get('thread_key', None)
        return space_id, thread_id, thread_key

    def send_message(self, message):
        super(GoogleHangoutsChatBackend, self).send_message(message)
        log.info("Sending {}".format(message.body))
        convert_markdown = message.extras.get('markdown', True)
        space_id, thread_id, thread_key = self.prep_message_context(message)
        if not space_id:
            log.info(message.body)
            return
        mentions = message.extras.get('mentions', None)
        text = message.body
        if convert_markdown:
            text = self.md.convert(message.body)
        sub_messages = self._split_message(text)
        log.info("Split message into {} parts".format(len(sub_messages)))
        for sub_message in sub_messages:
            message_payload = {
                'text': sub_message
            }
            if mentions:
                message_payload['annotations'] = []
                for mention in mentions:
                    message_payload['annotations'].append(
                        {
                            "type": "USER_MENTION",
                            "startIndex": mention['start'],
                            "length": mention['length'],
                            "userMention": {
                                "user": {
                                    "name": mention['user_id'],
                                    "displayName": mention['display_name'],
                                    "type": "HUMAN"
                                },
                                "type": "ADD"
                            }
                        }
                    )
            if thread_id:
                message_payload['thread'] = {'name': thread_id}

            gc = self.chat_api.create_message(space_id, message_payload, thread_key)

            # errbot expects no return https://errbot.readthedocs.io/en/latest/errbot.core.html#errbot.core.ErrBot.send_message
            # but we need this in order to get the thread_id from a thread_key generated message

            return None if gc is None else {
                'space_id': gc.get('space', {}).get('name', ''),
                'thread_id': gc.get('thread', {}).get('name', ''),
                'thread_key': thread_key
            }

    # Legacy send_card signature.  This is being deprecated in favor of errbot upstream signature that matches other built-in plugins.
    def send_card_deprecated(self, cards, space_id, thread_id=None):
        log.info("Sending card")
        message_payload = {
            'cards': cards
        }
        if thread_id:
            message_payload['thread'] = {'name': thread_id}

        self.chat_api.create_message(space_id, message_payload)

    # Creates a message body following the card format described in google dev docs
    # https://developers.google.com/chat/reference/message-formats/cards
    def send_card(self, errbot_card: Card, space_id=None, thread_id=None):
        if not isinstance(errbot_card, Card):
            log.warning("deprecated signature of 'send_card' method called, recommend changing to current version that matches upstream signature.")
            return self.send_card_deprecated(errbot_card, space_id, thread_id)

        log.info(f"Sending card {errbot_card.title}...")

        if not errbot_card.title:
            raise MalformedCardError(errbot_card, "'title' field required")
        ghc_card = {
            "header": {
                "title": errbot_card.title,
            },
            "sections": []
        }

        if errbot_card.summary:
            ghc_card['header']['subtitle'] = errbot_card.summary
        if errbot_card.thumbnail:
            ghc_card['header']['imageUrl'] = errbot_card.thumbnail
        if errbot_card.link:
            raise MalformedCardError(errbot_card, "'link' field not supported, please use body field.")
        if errbot_card.fields:
            raise MalformedCardError(errbot_card, "'fields' field not supported, please use body field.")
        if errbot_card.image:
            raise MalformedCardError(errbot_card, "'image' field not supported, please use body field.")
        if errbot_card.color:
            log.debug("card 'color' field not supported.")

        if not errbot_card.body:
            raise MalformedCardError(errbot_card, "'body' field required")

        ghc_card['sections'] = json.loads(errbot_card.body)
        # Example of 'sections' body string:
        # https://developers.google.com/chat/reference/message-formats/cards
        #
        # [
        #     {
        #         "widgets": [
        #             {
        #                 "keyValue": {
        #                     "topLabel": "Order No.",
        #                     "content": "12345"
        #                 }
        #             },
        #         ]
        #     },
        #     {
        #         "header": "Location",
        #         "widgets": [
        #             {
        #                 "image": {
        #                     "imageUrl": "https://maps.googleapis.com/..."
        #                 }
        #             }
        #         ]
        #     },
        #     {
        #         "widgets": [
        #             {
        #                 "buttons": [
        #                     {
        #                         "textButton": {
        #                             "text": "OPEN ORDER",
        #                             "onClick": {
        #                                 "openLink": {
        #                                     "url": "https://example.com/orders/..."
        #                                 }
        #                             }
        #                         }
        #                     }
        #                 ]
        #             }
        #         ]
        #     }
        # ]

        message_payload = {
            'cards': [ghc_card]
        }

        space_id, thread_id, thread_key = self.prep_message_context(errbot_card.parent)
        if not space_id:
            log.info(f"No space_id for card titled '{errbot_card.title}', not sending.")
            return
        if thread_id:
            message_payload['thread'] = {'name': thread_id}

        self.chat_api.create_message(space_id, message_payload, thread_key)

    def serve_forever(self):
        subscription = self._subscribe_to_pubsub_topic(self.gce_project,
                                                       self.gce_topic,
                                                       self.gce_subscription,
                                                       self._handle_event)
        self.connect_callback()

        try:
            import time
            while True:
                time.sleep(10)
        except KeyboardInterrupt:
            log.info("Exiting")
        finally:
            self.disconnect_callback()
            self.shutdown()

    def build_identifier(self, strrep):
        return HangoutsChatUser(None, strrep, None, None)

    def build_reply(self, msg, text=None, private=False, threaded=False):
        response = Message(body=text, frm=msg.to, to=msg.frm, extras=msg.extras)
        return response

    def change_presence(self, status='online', message=''):
        return None

    @property
    def mode(self):
        return 'Google_Hangouts_Chat'

    def query_room(self, room):
        return HangoutsChatRoom(room, self.chat_api)

    def rooms(self):
        spaces = self.chat_api.get_spaces()
        rooms = ['{} ({})'.format(space['displayName'], space['name'])
                 for space in list(spaces) if space['type'] == 'ROOM']

        return rooms
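
The _split_message helper above is plain line-based chunking; the same loop can be sketched standalone (the 4096 limit is an assumption, the real maximum comes from GoogleHangoutsChatAPI.max_message_length), which also shows the docstring's caveat that one oversized line is never split further:

def split_message(text, maximum_message_length=4096):
    lines = text.split('\n')
    messages = []
    current_message = ''
    for line in lines:
        if len(current_message) + len(line) + 1 > maximum_message_length:
            messages.append(current_message)
            current_message = line + '\n'
        else:
            current_message += line + '\n'
    messages.append(current_message)
    return messages

parts = split_message('first\n' + 'x' * 30 + '\nlast', maximum_message_length=20)
print([len(p) for p in parts])   # [6, 31, 5] -- the 30-character line still exceeds the 20-char limit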
Example #11
class FileBlockchain(BlockchainBase):
    def __init__(
            self,
            *,
            base_directory,

            # Account root files
            account_root_files_subdir=DEFAULT_ACCOUNT_ROOT_FILE_SUBDIR,
            account_root_files_cache_size=128,
            account_root_files_storage_kwargs=None,

            # Blocks
            blocks_subdir=DEFAULT_BLOCKS_SUBDIR,
            block_chunk_size=DEFAULT_BLOCK_CHUNK_SIZE,
            blocks_cache_size=None,
            blocks_storage_kwargs=None,
            lock_filename='file.lock',
            **kwargs):
        if not os.path.isabs(base_directory):
            raise ValueError('base_directory must be an absolute path')

        snapshot_period_in_blocks = kwargs.setdefault(
            'snapshot_period_in_blocks', block_chunk_size)
        super().__init__(**kwargs)

        self.block_chunk_size = block_chunk_size

        account_root_files_directory = os.path.join(base_directory,
                                                    account_root_files_subdir)
        block_directory = os.path.join(base_directory, blocks_subdir)

        self.block_storage = PathOptimizedFileSystemStorage(
            base_path=block_directory, **(blocks_storage_kwargs or {}))
        self.account_root_files_storage = PathOptimizedFileSystemStorage(
            base_path=account_root_files_directory,
            **(account_root_files_storage_kwargs or {}))

        self.account_root_files_cache = LRUCache(account_root_files_cache_size)
        self.blocks_cache = LRUCache(
            # We do not really need to cache more than `snapshot_period_in_blocks` blocks
            # since we use the account root file as a base
            snapshot_period_in_blocks * 2
            if blocks_cache_size is None else blocks_cache_size)

        lock_file_path = os.path.join(base_directory, lock_filename)
        self.file_lock = filelock.FileLock(lock_file_path, timeout=0)

    # Account root files methods
    @lock_method(lock_attr='file_lock', exception=LOCKED_EXCEPTION)
    def add_blockchain_state(self, blockchain_state: BlockchainState):
        return super().add_blockchain_state(blockchain_state)

    @ensure_locked(lock_attr='file_lock', exception=EXPECTED_LOCK_EXCEPTION)
    def persist_blockchain_state(self, account_root_file: BlockchainState):
        storage = self.account_root_files_storage
        last_block_number = account_root_file.last_block_number

        file_path = get_account_root_filename(last_block_number)
        storage.save(file_path,
                     account_root_file.to_messagepack(),
                     is_final=True)

    def _load_account_root_file(self, file_path):
        cache = self.account_root_files_cache
        account_root_file = cache.get(file_path)
        if account_root_file is None:
            storage = self.account_root_files_storage
            assert storage.is_finalized(file_path)
            account_root_file = BlockchainState.from_messagepack(
                storage.load(file_path))
            cache[file_path] = account_root_file

        return account_root_file

    def _yield_blockchain_states(
            self, direction) -> Generator[BlockchainState, None, None]:
        assert direction in (1, -1)

        storage = self.account_root_files_storage
        for file_path in storage.list_directory(sort_direction=direction):
            yield self._load_account_root_file(file_path)

    def yield_blockchain_states(
            self) -> Generator[BlockchainState, None, None]:
        yield from self._yield_blockchain_states(1)

    def yield_blockchain_states_reversed(
            self) -> Generator[BlockchainState, None, None]:
        yield from self._yield_blockchain_states(-1)

    def get_account_root_file_count(self) -> int:
        storage = self.account_root_files_storage
        return ilen(storage.list_directory())

    # Blocks methods
    @lock_method(lock_attr='file_lock', exception=LOCKED_EXCEPTION)
    def add_block(self, block: Block, validate=True):
        return super().add_block(block, validate)

    @ensure_locked(lock_attr='file_lock', exception=EXPECTED_LOCK_EXCEPTION)
    def persist_block(self, block: Block):
        storage = self.block_storage
        block_chunk_size = self.block_chunk_size

        block_number = block.message.block_number
        chunk_number, offset = divmod(block_number, block_chunk_size)

        chunk_block_number_start = chunk_number * block_chunk_size

        if chunk_block_number_start == block_number:
            append_end = block_number
        else:
            assert chunk_block_number_start < block_number
            append_end = block_number - 1

        append_filename = get_block_chunk_filename(
            start=chunk_block_number_start, end=append_end)
        filename = get_block_chunk_filename(start=chunk_block_number_start,
                                            end=block_number)

        storage.append(append_filename, block.to_messagepack())

        if append_filename != filename:
            storage.move(append_filename, filename)

        if offset == block_chunk_size - 1:
            storage.finalize(filename)

    def yield_blocks(self) -> Generator[Block, None, None]:
        yield from self._yield_blocks(1)

    @timeit(verbose_args=True, is_method=True)
    def yield_blocks_reversed(self) -> Generator[Block, None, None]:
        yield from self._yield_blocks(-1)

    def yield_blocks_from(self,
                          block_number: int) -> Generator[Block, None, None]:
        for file_path in self._list_block_directory():
            start, end = get_start_end(file_path)
            if end < block_number:
                continue

            yield from self._yield_blocks_from_file_cached(file_path,
                                                           direction=1,
                                                           start=max(
                                                               start,
                                                               block_number))

    def get_block_by_number(self, block_number: int) -> Optional[Block]:
        block = self.blocks_cache.get(block_number)
        if block is not None:
            return block

        try:
            return next(self.yield_blocks_from(block_number))
        except StopIteration:
            return None

    def get_block_count(self) -> int:
        count = 0
        for file_path in self._list_block_directory():
            start, end = get_start_end(file_path)
            assert start is not None
            assert end is not None

            count += end - start + 1

        return count

    @timeit(verbose_args=True, is_method=True)
    def _yield_blocks(self, direction) -> Generator[Block, None, None]:
        assert direction in (1, -1)

        for file_path in self._list_block_directory(direction):
            yield from self._yield_blocks_from_file_cached(
                file_path, direction)

    def _yield_blocks_from_file_cached(self, file_path, direction, start=None):
        assert direction in (1, -1)

        file_start, file_end = get_start_end(file_path)
        if direction == 1:
            next_block_number = cache_start = file_start if start is None else start
            cache_end = file_end
        else:
            cache_start = file_start
            next_block_number = cache_end = file_end if start is None else start

        for block in self._yield_blocks_from_cache(cache_start, cache_end,
                                                   direction):
            assert next_block_number == block.message.block_number
            next_block_number += direction
            yield block

        if file_start <= next_block_number <= file_end:
            yield from self._yield_blocks_from_file(file_path,
                                                    direction,
                                                    start=next_block_number)

    def _yield_blocks_from_file(self, file_path, direction, start=None):
        assert direction in (1, -1)
        storage = self.block_storage

        unpacker = msgpack.Unpacker()
        unpacker.feed(storage.load(file_path))
        if direction == -1:
            unpacker = always_reversible(unpacker)

        for block_compact_dict in unpacker:
            block = Block.from_compact_dict(block_compact_dict)
            block_number = block.message.block_number
            # TODO(dmu) HIGH: Implement a better skip
            if start is not None:
                if direction == 1 and block_number < start:
                    continue
                elif direction == -1 and block_number > start:
                    continue

            self.blocks_cache[block_number] = block
            yield block

    def _yield_blocks_from_cache(self, start_block_number, end_block_number,
                                 direction):
        assert direction in (1, -1)

        iter_ = range(start_block_number, end_block_number + 1)
        if direction == -1:
            iter_ = always_reversible(iter_)

        for block_number in iter_:
            block = self.blocks_cache.get(block_number)
            if block is None:
                break

            yield block

    def _list_block_directory(self, direction=1):
        storage = self.block_storage
        yield from storage.list_directory(sort_direction=direction)
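
A minimal sketch of the chunk arithmetic persist_block uses above, with concrete numbers. The get_block_chunk_filename below is a hypothetical stand-in for the helper the example imports; only the divmod/rename/finalize logic mirrors the listing.

def get_block_chunk_filename(start, end):
    # Hypothetical filename format; the real helper may differ.
    return f'{start:020d}-{end:020d}-block-chunk.msgpack'

def chunk_names_for_block(block_number, block_chunk_size=100):
    chunk_number, offset = divmod(block_number, block_chunk_size)
    chunk_start = chunk_number * block_chunk_size

    # The file appended to covers the blocks already written in this chunk...
    append_end = chunk_start if chunk_start == block_number else block_number - 1
    append_filename = get_block_chunk_filename(chunk_start, append_end)
    # ...and is then renamed so its name also covers the newly appended block.
    filename = get_block_chunk_filename(chunk_start, block_number)

    # The chunk is finalized only when its last slot is filled.
    finalized = offset == block_chunk_size - 1
    return append_filename, filename, finalized

# With a chunk size of 100, block 105 lands in the chunk starting at 100: the file
# named for 100..104 is appended to and renamed to cover 100..105. Only block 199
# (offset == 99) would finalize this chunk.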
Example #12
0
class GMusicLibraryProvider(backend.LibraryProvider):
    root_directory = Ref.directory(uri='gmusic:directory', name='Google Music')

    def __init__(self, *args, **kwargs):
        super(GMusicLibraryProvider, self).__init__(*args, **kwargs)

        # tracks, albums, and artists here refer to what is explicitly
        # in our library.
        self.tracks = {}
        self.albums = {}
        self.artists = {}

        # aa_* caches are *only* used for temporary objects. Library
        # objects will never make it here.
        self.aa_artists = LRUCache(1024)
        self.aa_tracks = LRUCache(1024)
        self.aa_albums = LRUCache(1024)

        self._radio_stations_in_browse = (
            self.backend.config['gmusic']['radio_stations_in_browse'])
        self._radio_stations_count = (
            self.backend.config['gmusic']['radio_stations_count'])
        self._radio_tracks_count = (
            self.backend.config['gmusic']['radio_tracks_count'])

        self._top_tracks_count = (
            self.backend.config['gmusic']['top_tracks_count'])

        # Setup the root of library browsing.
        self._root = [
            Ref.directory(uri='gmusic:album', name='Albums'),
            Ref.directory(uri='gmusic:artist', name='Artists'),
            Ref.directory(uri='gmusic:track', name='Tracks')
        ]

        if self._radio_stations_in_browse:
            self._root.append(Ref.directory(uri='gmusic:radio', name='Radios'))

    @property
    def all_access(self):
        return self.backend.session.all_access

    def _browse_tracks(self):
        tracks = list(self.tracks.values())
        tracks.sort(key=lambda ref: ref.name)
        refs = []
        for track in tracks:
            refs.append(track_to_ref(track))
        return refs

    def _browse_albums(self):
        refs = []
        for album in self.albums.values():
            refs.append(album_to_ref(album))
        refs.sort(key=lambda ref: ref.name)
        return refs

    def _browse_album(self, uri):
        refs = []
        for track in self._lookup_album(uri):
            refs.append(track_to_ref(track, True))
        return refs

    def _browse_artists(self):
        refs = []
        for artist in self.artists.values():
            refs.append(artist_to_ref(artist))
        refs.sort(key=lambda ref: ref.name)
        return refs

    def _browse_artist(self, uri):
        refs = []
        for album in self._get_artist_albums(uri):
            refs.append(album_to_ref(album))
            refs.sort(key=lambda ref: ref.name)
        if len(refs) > 0:
            refs.insert(0, Ref.directory(uri=uri + ':all', name='All Tracks'))
            is_all_access = uri.startswith('gmusic:artist:A')
            if is_all_access:
                refs.insert(1,
                            Ref.directory(uri=uri + ':top', name='Top Tracks'))
            return refs
        else:
            # Show all tracks if no album is available
            return self._browse_artist_all_tracks(uri)

    def _browse_artist_all_tracks(self, uri):
        artist_uri = ':'.join(uri.split(':')[:3])
        refs = []
        tracks = self._lookup_artist(artist_uri, True)
        for track in tracks:
            refs.append(track_to_ref(track))
        return refs

    def _browse_artist_top_tracks(self, uri):
        artist_uri = ':'.join(uri.split(':')[:3])
        refs = []
        tracks = self._get_artist_top_tracks(artist_uri)
        for track in tracks:
            refs.append(track_to_ref(track))
        return refs

    def _browse_radio_stations(self, uri):
        stations = self.backend.session.get_radio_stations(
            self._radio_stations_count)
        # create Ref objects
        refs = []
        for station in stations:
            refs.append(
                Ref.directory(uri='gmusic:radio:' + station['id'],
                              name=station['name']))
        return refs

    def _browse_radio_station(self, uri):
        station_id = uri.split(':')[2]
        tracks = self.backend.session.get_station_tracks(
            station_id, self._radio_tracks_count)

        # create Ref objects
        refs = []
        for track in tracks:
            mopidy_track = self._to_mopidy_track(track)
            self.aa_tracks[mopidy_track.uri] = mopidy_track
            refs.append(track_to_ref(mopidy_track))
        return refs

    def browse(self, uri):
        logger.debug('browse: %s', str(uri))
        if not uri:
            return []
        if uri == self.root_directory.uri:
            return self._root

        parts = uri.split(':')

        # tracks
        if uri == 'gmusic:track':
            return self._browse_tracks()

        # albums
        if uri == 'gmusic:album':
            return self._browse_albums()

        # a single album
        # uri == 'gmusic:album:album_id'
        if len(parts) == 3 and parts[1] == 'album':
            return self._browse_album(uri)

        # artists
        if uri == 'gmusic:artist':
            return self._browse_artists()

        # a single artist
        # uri == 'gmusic:artist:artist_id'
        if len(parts) == 3 and parts[1] == 'artist':
            return self._browse_artist(uri)

        # all tracks of a single artist
        # uri == 'gmusic:artist:artist_id:all'
        if len(parts) == 4 and parts[1] == 'artist' and parts[3] == 'all':
            return self._browse_artist_all_tracks(uri)

        # top tracks of a single artist
        # uri == 'gmusic:artist:artist_id:top'
        if len(parts) == 4 and parts[1] == 'artist' and parts[3] == 'top':
            return self._browse_artist_top_tracks(uri)

        # all radio stations
        if uri == 'gmusic:radio':
            return self._browse_radio_stations(uri)

        # a single radio station
        # uri == 'gmusic:radio:station_id'
        if len(parts) == 3 and parts[1] == 'radio':
            return self._browse_radio_station(uri)

        logger.debug('Unknown uri for browse request: %s', uri)

        return []

    def lookup(self, uri):
        if uri.startswith('gmusic:track:'):
            return self._lookup_track(uri)
        elif uri.startswith('gmusic:album:'):
            return self._lookup_album(uri)
        elif uri.startswith('gmusic:artist:'):
            return self._lookup_artist(uri)
        else:
            return []

    def _lookup_track(self, uri):
        is_all_access = uri.startswith('gmusic:track:T')

        try:
            return [self.tracks[uri]]
        except KeyError:
            logger.debug('Track not a library track %r', uri)
            pass

        if is_all_access and self.all_access:
            track = self.aa_tracks.get(uri)
            if track:
                return [track]
            song = self.backend.session.get_track_info(uri.split(':')[2])
            if song is None:
                logger.warning('There is no song %r', uri)
                return []
            if 'artistId' not in song:
                logger.warning('Failed to lookup %r', uri)
                return []
            mopidy_track = self._to_mopidy_track(song)
            self.aa_tracks[mopidy_track.uri] = mopidy_track
            return [mopidy_track]
        else:
            return []

    def _lookup_album(self, uri):
        is_all_access = uri.startswith('gmusic:album:B')
        if self.all_access and is_all_access:
            tracks = self.aa_albums.get(uri)
            if tracks:
                return tracks
            album = self.backend.session.get_album_info(uri.split(':')[2],
                                                        include_tracks=True)
            if album and album.get('tracks'):
                tracks = [
                    self._to_mopidy_track(track) for track in album['tracks']
                ]
                for track in tracks:
                    self.aa_tracks[track.uri] = track
                tracks = sorted(tracks, key=lambda t: (t.disc_no, t.track_no))
                self.aa_albums[uri] = tracks
                return tracks

            logger.warning('Failed to lookup all access album %r: %r', uri,
                           album)

        # Even if the album has an all access ID, we need to look it
        # up here (as a fallback) because purchased tracks can have a
        # store ID, but only show up in your library.
        try:
            album = self.albums[uri]
        except KeyError:
            logger.debug('Failed to lookup %r', uri)
            return []

        tracks = self._find_exact(
            dict(album=album.name,
                 artist=[artist.name for artist in album.artists],
                 date=album.date)).tracks
        return sorted(tracks, key=lambda t: (t.disc_no, t.track_no))

    def _get_artist_top_tracks(self, uri):
        is_all_access = uri.startswith('gmusic:artist:A')
        artist_id = uri.split(':')[2]

        if not is_all_access:
            logger.debug("Top Tracks not available for non-all-access artists")
            return []

        artist_info = self.backend.session.get_artist_info(
            artist_id,
            include_albums=False,
            max_top_tracks=self._top_tracks_count,
            max_rel_artist=0)
        top_tracks = []

        for track_dict in artist_info['topTracks']:
            top_tracks.append(self._to_mopidy_track(track_dict))

        return top_tracks

    def _get_artist_albums(self, uri):
        is_all_access = uri.startswith('gmusic:artist:A')

        artist_id = uri.split(':')[2]
        if is_all_access:
            # all access
            artist_infos = self.backend.session.get_artist_info(
                artist_id, max_top_tracks=0, max_rel_artist=0)
            if artist_infos is None or 'albums' not in artist_infos:
                return []
            albums = []
            for album in artist_infos['albums']:
                albums.append(
                    self._aa_search_album_to_mopidy_album({'album': album}))
            return albums
        elif self.all_access and artist_id in self.aa_artists:
            albums = self._get_artist_albums('gmusic:artist:%s' %
                                             self.aa_artists[artist_id])
            if len(albums) > 0:
                return albums
            # else fall back to non aa albums
        if uri in self.artists:
            artist = self.artists[uri]
            return [
                album for album in self.albums.values()
                if artist in album.artists
            ]
        else:
            logger.debug('0 albums available for artist %r', uri)
            return []

    def _lookup_artist(self, uri, exact_match=False):
        def sorter(track):
            return (
                track.album.date,
                track.album.name,
                track.disc_no,
                track.track_no,
            )

        if self.all_access:
            try:
                all_access_id = self.aa_artists[uri.split(':')[2]]
                artist_infos = self.backend.session.get_artist_info(
                    all_access_id, max_top_tracks=0, max_rel_artist=0)
                if not artist_infos or not artist_infos['albums']:
                    logger.warning('Failed to lookup %r', artist_infos)
                tracks = [
                    self._lookup_album('gmusic:album:' + album['albumId'])
                    for album in artist_infos['albums']
                ]
                tracks = reduce(lambda a, b: (a + b), tracks)
                return sorted(tracks, key=sorter)
            except KeyError:
                pass
        try:
            artist = self.artists[uri]
        except KeyError:
            logger.debug('Failed to lookup %r', uri)
            return []

        tracks = self._find_exact(dict(artist=artist.name)).tracks
        if exact_match:
            tracks = filter(lambda t: artist in t.artists, tracks)
        return sorted(tracks, key=sorter)

    def refresh(self, uri=None):
        self.tracks = {}
        self.albums = {}
        self.artists = {}

        album_tracks = {}
        for track in self.backend.session.get_all_songs():
            mopidy_track = self._to_mopidy_track(track)

            self.tracks[mopidy_track.uri] = mopidy_track
            self.albums[mopidy_track.album.uri] = mopidy_track.album

            # We don't care about the order because we're just using
            # this as a temporary variable to grab the proper album
            # artist out of the album.
            if mopidy_track.album.uri not in album_tracks:
                album_tracks[mopidy_track.album.uri] = []

            album_tracks[mopidy_track.album.uri].append(mopidy_track)

        # Yes, this is awful. No, I don't have a better solution. Yes,
        # I'm annoyed at Google for not providing album artist IDs.
        for album in self.albums.values():
            artist_found = False
            for album_artist in album.artists:
                for track in album_tracks[album.uri]:
                    for artist in track.artists:
                        if album_artist.name == artist.name:
                            artist_found = True
                            self.artists[artist.uri] = artist

            if not artist_found:
                for artist in album.artists:
                    self.artists[artist.uri] = artist

    def search(self, query=None, uris=None, exact=False):
        if exact:
            return self._find_exact(query=query, uris=uris)

        lib_tracks, lib_artists, lib_albums = self._search_library(query, uris)

        if query:
            aa_tracks, aa_artists, aa_albums = self._search(query, uris)
            for aa_artist in aa_artists:
                lib_artists.add(aa_artist)

            for aa_album in aa_albums:
                lib_albums.add(aa_album)

            lib_tracks = set(lib_tracks)

            for aa_track in aa_tracks:
                lib_tracks.add(aa_track)

        return SearchResult(uri='gmusic:search',
                            tracks=lib_tracks,
                            artists=lib_artists,
                            albums=lib_albums)

    def _find_exact(self, query=None, uris=None):
        # Find exact can only be done on gmusic library,
        # since one can't filter all access searches
        lib_tracks, lib_artists, lib_albums = self._search_library(query, uris)

        return SearchResult(uri='gmusic:search',
                            tracks=lib_tracks,
                            artists=lib_artists,
                            albums=lib_albums)

    def _search(self, query=None, uris=None):
        for (field, values) in query.iteritems():
            if not hasattr(values, '__iter__'):
                values = [values]

            # Since gmusic does not support search filters, just search for the
            # first 'searchable' filter
            if field in [
                    'track_name', 'album', 'artist', 'albumartist', 'any'
            ]:
                logger.info('Searching Google Play Music for: %s', values[0])
                res = self.backend.session.search(values[0], max_results=50)
                if res is None:
                    return [], [], []

                albums = [
                    self._aa_search_album_to_mopidy_album(album_res)
                    for album_res in res['album_hits']
                ]
                artists = [
                    self._aa_search_artist_to_mopidy_artist(artist_res)
                    for artist_res in res['artist_hits']
                ]
                tracks = [
                    self._aa_search_track_to_mopidy_track(track_res)
                    for track_res in res['song_hits']
                ]

                return tracks, artists, albums

        return [], [], []

    def _search_library(self, query=None, uris=None):
        if query is None:
            query = {}
        self._validate_query(query)
        result_tracks = self.tracks.values()

        for (field, values) in query.iteritems():
            if not hasattr(values, '__iter__'):
                values = [values]
            # FIXME this is bound to be slow for large libraries
            for value in values:
                if field == 'track_no':
                    q = self._convert_to_int(value)
                else:
                    q = value.strip().lower()

                def uri_filter(track):
                    return q in track.uri.lower()

                def track_name_filter(track):
                    return q in track.name.lower()

                def album_filter(track):
                    return q in getattr(track, 'album', Album()).name.lower()

                def artist_filter(track):
                    return (any(q in a.name.lower() for a in track.artists)
                            or albumartist_filter(track))

                def albumartist_filter(track):
                    album_artists = getattr(track, 'album', Album()).artists
                    return any(q in a.name.lower() for a in album_artists)

                def track_no_filter(track):
                    return track.track_no == q

                def date_filter(track):
                    return track.date and track.date.startswith(q)

                def any_filter(track):
                    return any([
                        uri_filter(track),
                        track_name_filter(track),
                        album_filter(track),
                        artist_filter(track),
                        albumartist_filter(track),
                        date_filter(track),
                    ])

                if field == 'uri':
                    result_tracks = filter(uri_filter, result_tracks)
                elif field == 'track_name':
                    result_tracks = filter(track_name_filter, result_tracks)
                elif field == 'album':
                    result_tracks = filter(album_filter, result_tracks)
                elif field == 'artist':
                    result_tracks = filter(artist_filter, result_tracks)
                elif field == 'albumartist':
                    result_tracks = filter(albumartist_filter, result_tracks)
                elif field == 'track_no':
                    result_tracks = filter(track_no_filter, result_tracks)
                elif field == 'date':
                    result_tracks = filter(date_filter, result_tracks)
                elif field == 'any':
                    result_tracks = filter(any_filter, result_tracks)
                else:
                    raise LookupError('Invalid lookup field: %s' % field)

        result_artists = set()
        result_albums = set()
        for track in result_tracks:
            result_artists |= track.artists
            result_albums.add(track.album)

        return result_tracks, result_artists, result_albums

    def _validate_query(self, query):
        for (_, values) in query.iteritems():
            if not values:
                raise LookupError('Missing query')
            for value in values:
                if not value:
                    raise LookupError('Missing query')

    def _to_mopidy_track(self, song):
        track_id = song.get('id', song.get('nid'))
        if track_id is None:
            raise ValueError
        if track_id[0] != "T" and "-" not in track_id:
            track_id = "T" + track_id
        return Track(uri='gmusic:track:' + track_id,
                     name=song['title'],
                     artists=[self._to_mopidy_artist(song)],
                     album=self._to_mopidy_album(song),
                     track_no=song.get('trackNumber', 1),
                     disc_no=song.get('discNumber', 1),
                     date=unicode(song.get('year', 0)),
                     length=int(song['durationMillis']),
                     bitrate=320)

    def _to_mopidy_album(self, song):
        name = song.get('album', '')
        artist = self._to_mopidy_album_artist(song)
        date = unicode(song.get('year', 0))

        album_id = song.get('albumId')
        if album_id is None:
            album_id = create_id(artist.name + name + date)

        uri = 'gmusic:album:' + album_id
        images = get_images(song)
        return Album(uri=uri,
                     name=name,
                     artists=[artist],
                     num_tracks=song.get('totalTrackCount'),
                     num_discs=song.get('totalDiscCount'),
                     date=date,
                     images=images)

    def _to_mopidy_artist(self, song):
        name = song.get('artist', '')
        artist_id = song.get('artistId')
        if artist_id is not None:
            artist_id = artist_id[0]
        else:
            artist_id = create_id(name)
        uri = 'gmusic:artist:' + artist_id
        return Artist(uri=uri, name=name)

    def _to_mopidy_album_artist(self, song):
        name = song.get('albumArtist', '')
        if name.strip() == '':
            name = song.get('artist', '')
        uri = 'gmusic:artist:' + create_id(name)
        return Artist(uri=uri, name=name)

    def _aa_search_track_to_mopidy_track(self, search_track):
        track = search_track['track']

        aa_artist_id = create_id(track['artist'])
        if 'artistId' in track:
            aa_artist_id = track['artistId'][0]
        else:
            logger.warning('No artistId for Track %r', track)

        artist = Artist(uri='gmusic:artist:' + aa_artist_id,
                        name=track['artist'])

        album = Album(uri='gmusic:album:' + track['albumId'],
                      name=track['album'],
                      artists=[artist],
                      date=unicode(track.get('year', 0)))

        return Track(uri='gmusic:track:' + track['storeId'],
                     name=track['title'],
                     artists=[artist],
                     album=album,
                     track_no=track.get('trackNumber', 1),
                     disc_no=track.get('discNumber', 1),
                     date=unicode(track.get('year', 0)),
                     length=int(track['durationMillis']),
                     bitrate=320)

    def _aa_search_artist_to_mopidy_artist(self, search_artist):
        artist = search_artist['artist']
        uri = 'gmusic:artist:' + artist['artistId']
        return Artist(uri=uri, name=artist['name'])

    def _aa_search_album_to_mopidy_album(self, search_album):
        album = search_album['album']
        uri = 'gmusic:album:' + album['albumId']
        name = album['name']
        artist = self._aa_search_artist_album_to_mopidy_artist_album(album)
        date = unicode(album.get('year', 0))
        return Album(uri=uri, name=name, artists=[artist], date=date)

    def _aa_search_artist_album_to_mopidy_artist_album(self, album):
        name = album.get('albumArtist', '')
        if name.strip() == '':
            name = album.get('artist', '')
        uri = 'gmusic:artist:' + album.get('artistId')[0]
        return Artist(uri=uri, name=name)

    def _convert_to_int(self, string):
        try:
            return int(string)
        except ValueError:
            return object()
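
The aa_* caches above rely on LRUCache's bounded size to hold only recently touched all-access objects. A quick, self-contained illustration of that eviction behaviour, assuming the cachetools implementation that the LRUCache(1024) constructor calls above point to:

from cachetools import LRUCache

cache = LRUCache(maxsize=2)
cache['gmusic:track:T1'] = 'track one'
cache['gmusic:track:T2'] = 'track two'

_ = cache['gmusic:track:T1']              # touching T1 makes T2 the least recently used
cache['gmusic:track:T3'] = 'track three'  # a third insert evicts T2

assert 'gmusic:track:T2' not in cache
assert set(cache) == {'gmusic:track:T1', 'gmusic:track:T3'}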
Example #13
0
class CompHetsUnit(FunctionUnit):
    @staticmethod
    def makeIt(ds_h, descr, before=None, after=None):
        unit_h = CompHetsUnit(ds_h, descr)
        ds_h.getEvalSpace()._insertUnit(unit_h, before=before, after=after)

    def __init__(self, ds_h, descr):
        FunctionUnit.__init__(self,
                              ds_h.getEvalSpace(),
                              descr,
                              sub_kind="comp-hets",
                              parameters=["approx", "state"])
        self.mZygSupport = ds_h.getZygositySupport()
        self.mOpCache = LRUCache(
            AnfisaConfig.configOption("comp-hets.cache.size"))

    def _buildTrioRequest(self, trio_info, approx_mode, actual_condition):
        id_base, id_father, id_mother = trio_info[1:]
        c_rq = [[1, {
            "1": [id_base, id_father],
            "0": [id_mother]
        }], [1, {
            "1": [id_base, id_mother],
            "0": [id_father]
        }]]
        return self.mZygSupport.makeCompoundRequest(approx_mode,
                                                    actual_condition, c_rq,
                                                    self.getName())

    def buildConditions(self, approx_mode, actual_condition):
        ret_handle = dict()
        for trio_info in self.mZygSupport.getTrioSeq():
            ret_handle[trio_info[0]] = self._buildTrioRequest(
                trio_info, approx_mode, actual_condition)
        return ret_handle

    def iterComplexCriteria(self, context, variants=None):
        if context is None:
            return
        trio_dict = context["trio-dict"]
        for trio_id, _, _, _ in self.mZygSupport.getTrioSeq():
            if variants is not None and trio_id not in variants:
                continue
            trio_crit = trio_dict[trio_id]
            if trio_crit is not None:
                yield trio_id, trio_crit

    def makeInfoStat(self, eval_h, point_no):
        ret_handle = self.prepareStat()
        ret_handle["trio-variants"] = [
            trio_info[0] for trio_info in self.mZygSupport.getTrioSeq()
        ]
        ret_handle["approx-modes"] = self.mZygSupport.getApproxInfo()
        ret_handle["labels"] = eval_h.getLabelPoints(point_no)
        return ret_handle

    def _locateContext(self, parameters, eval_h, point_no=None):
        if parameters.get("state"):
            actual_condition = eval_h.getLabelCondition(
                parameters["state"], point_no)
            if actual_condition is None:
                return None, ("State label %s not defined" %
                              parameters["state"])
        else:
            actual_condition = eval_h.getActualCondition(point_no)
        approx_mode = self.mZygSupport.normalizeApprox(
            parameters.get("approx"))
        if approx_mode is False:
            return None, "Improper approx mode %s" % parameters["approx"]

        build_id = approx_mode + '|' + actual_condition.hashCode()
        with self.getEvalSpace().getDS():
            context = self.mOpCache.get(build_id)
        if context is None:
            context = {
                "approx": approx_mode,
                "trio-dict": self.buildConditions(approx_mode,
                                                  actual_condition)
            }
            with self.getEvalSpace().getDS():
                self.mOpCache[build_id] = context
        if None in context["trio-dict"].values():
            return context, "Too heavy condition"
        return context, None

    def locateContext(self, cond_data, eval_h):
        point_no, _ = eval_h.locateCondData(cond_data)
        context, err_msg = self._locateContext(cond_data[4], eval_h, point_no)
        if err_msg:
            eval_h.operationError(cond_data, err_msg)
        return context

    def validateArgs(self, parameters):
        if (parameters.get("state")
                and not isinstance(parameters["state"], str)):
            return "Bad state parameter"
        if (parameters.get("approx")
                and not isinstance(parameters["approx"], str)):
            return "Bad approx parameter"
        return None

    def makeParamStat(self, condition, parameters, eval_h, point_no):
        context, err_msg = self._locateContext(parameters, eval_h, point_no)
        ret_handle = self.prepareStat()
        ret_handle.update(parameters)
        if err_msg:
            ret_handle["err"] = err_msg
        else:
            self.collectComplexStat(
                ret_handle, condition, context,
                self.mZygSupport.getGeneUnit(
                    context.get("approx")).isDetailed())

        return ret_handle
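
_locateContext above memoizes the heavy per-trio "context" objects under a composite key: the approx mode joined with the hash of the active condition. A stripped-down sketch of that pattern, with the expensive buildConditions step stubbed out and the lock/dataset handling of the original omitted (build_trio_conditions is a hypothetical placeholder):

from cachetools import LRUCache

_op_cache = LRUCache(maxsize=128)

def build_trio_conditions(approx_mode, condition_hash):
    ...  # stands in for the heavy compound zygosity requests

def locate_context(approx_mode, condition_hash):
    build_id = approx_mode + '|' + condition_hash
    context = _op_cache.get(build_id)
    if context is None:
        context = {
            'approx': approx_mode,
            'trio-dict': build_trio_conditions(approx_mode, condition_hash),
        }
        _op_cache[build_id] = context
    return context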
Example #14
0
class radAnonimizer(Gdcmanon.AnonWrapper):
    malenames = [
        'ANASTAZY', 'BONIFACY', 'CECYL', 'DOBROMIR', 'EUSTACHY', 'FABIUSZ',
        'GERWAZY', 'HIACYNT', 'IDZI', 'JACENTY', 'KASJAN', 'LAMBERT', 'MAKARY',
        'NEPOMOUCEN', 'ONUFRY', 'PONCYLIUSZ', 'RAJMUND', 'SERWACY', 'TEOBALD',
        'URBAN', 'WIT', 'ZACHARIASZ', 'LETO', 'TWOFLOWER', 'RUMCAJS',
        'HAVELOCK', 'ELVIS', 'RAJMUND', 'BILL', 'CLIVE', 'BUTCH', 'PETER',
        "ALEXANDRE", "ALFRED", "ARNOLD", "AURELIUSZ", "BARNABA", "BEN",
        "DMYTRO", "DYMITR", "ELIJAH", "EMMANUEL", "FINN", "FRANCIS", "GEORGE",
        "GROMOSLAW", "HORACY", "JOHANNES", "JOSEPH", "KONSTANTYN", "LUBOMIR",
        "LUCA", "MASON", "NICO", "RAFAEL", "RICHARD", "ROBIN", "ROGER",
        "SALOMON", "SASZA", "SIEMOWIT", "STANISLAW", "TIM", "TIMOTHY", "TYMUR",
        "VOLODYMYR", "YASIN", "YASSIN", "ZACHAR", "ZAHAR", "AMIN", "ANGELO",
        "BORIS", "BRANDON", "CASPIAN", "DENNIS", "DIONIZY", "EGOR",
        "EUZEBIUSZ", "FABIO", "HENRIK", "ILYA", "ISMAEL", "JASPER", "JAYDEN",
        "JURIJ", "KONSTANTIN", "KSAVIER", "LENART", "LEONID", "LESLAW", "LEVI",
        "LUDWIG", "MARCELLO", "MAXIME", "MAXIMILIANO", "MELCHIOR", "MILOSZ",
        "MODEST", "NATANEL", "NIKLAS", "OLEH", "OLEKSII", "OMAR", "PAUL",
        "RAJAN", "RODRIGO", "RUDOLF", "SAM", "THIAGO", "TOMMY", "TRISTAN",
        "VITALI", "XAVERY", "ZAWISZA", "ABDULLAH", "ADNAN", "AIDAN", "AIDEN",
        "ALEJANDRO", "ALEKSANDR", "ALEKSIEJ", "ALESSIO", "ANDREAS", "ANDREI",
        "ANDRIY", "ANTONIUSZ", "ARYAN", "AUGUSTIN", "BALTAZAR", "BJORN",
        "CHARLES", "CRISTIAN", "CZCIBOR", "DANIL", "DEXTER", "DMITRIJ",
        "ELLIOT", "ENES", "ENZO", "ESTEBAN", "FREDERICK", "GERALD", "GIOVANNI",
        "GUSTAV", "HAMZA", "HERBERT", "IBRAHIM", "IDZI", "ILIAN", "ILIAS",
        "ISAAC", "JAROSLAW", "JEGOR", "JOSEF", "JOZUE", "JUSTIN", "KENAN",
        "KIRILL", "KRISTIAN", "KSAVERY", "LIONEL", "MALIK", "MARCELL", "MARIO",
        "MARSEL", "MASSIMO", "MATEI", "MATTHIAS", "MATVIY", "MIKE", "MIKHAIL",
        "MILOSLAW", "MIRAN", "MIROSLAW", "MUSA", "NATANAEL", "NESTOR", "NHAT",
        "OKTAWIUSZ", "ORLANDO", "PAVLO", "RADOSZ", "RAJMUND", "RAMI", "ROMEO",
        "S***N", "SWIATOSLAW", "TIAGO", "TIMON", "TOBY", "TOMMASO", "UMAR",
        "VALENTINO", "VITALII", "VLADIMIR", "WALERY", "WIESLAW", "WITALIJ",
        "WLADYSLAW", "XANDER", "YAROSLAV", "ZAC", "ZACK", "ZBYSZKO", "AAYAN",
        "ADRIEN", "AKIM", "ALBANO", "ALEC", "ALEKSEJ", "ALEN", "ALEXANDROS",
        "ALEXY", "ALPER", "ALVARO", "AMINE", "ANDRE", "ANGEL", "ARCHIBALD",
        "ARIS", "ARMANDO", "ARSENIUSZ", "ARTEMIJ", "ASEN", "ASHER", "ATANAZY",
        "ATHARVA", "ATTILA", "AYAN", "AYAZ", "BOZYDAR", "CHARLIE", "COLLIN",
        "CONAN", "CRISTIANO", "CZAREK", "DAGMAR", "DANG", "DANTE", "DARIAN",
        "DARIO", "DASTIN", "DAVIDE", "DAVINCI", "DAVIT", "DEMIAN", "DEMJAN",
        "DENIEL", "DENIZ", "DRAGOMIR", "EDUARD", "EINAR", "EKAM", "ELI",
        "ELIOTT", "ELYAS", "EREN", "EUSTACHY", "EZEL", "FARES", "FERENC",
        "FRANCISCO", "FRANCO", "FRANKIE", "FRANKO", "GABRIELE", "GAGIK",
        "GAWEL", "GERALT", "GIA", "GIUSEPPE", "GLEB", "GOR", "GORAN", "GWIDON",
        "HAI", "HARIS", "HARUN", "HAYK", "HENDRIK", "HOANG", "HOANG", "HOANG",
        "IDRIS", "ILJA", "ILLIA", "IMRAN", "JAMAL", "JANUARY", "JAROSLAV",
        "JAVIER", "JOAN", "JOHANN", "JONAS", "JOSE", "JOSZKO", "JOZEF", "JUAN",
        "JUDA", "JULIEN", "JULIO", "JULIUS", "JUNHAO", "JURII", "KAMIL",
        "KEITH", "KENZO", "KEREM", "KHALID", "KIRIL", "KORAY", "KOSTEK",
        "KOSTIANTYN", "LARGO", "LEO", "LEON", "LINUS", "LIO", "LIVIAN",
        "LLOYD", "LONGIN", "LOTAR", "LUCIAN", "MACIEK", "MALAKAI", "MARCELIN",
        "MARCELINO", "MATEUSH", "MATHEO", "MATIAS", "MATTIA", "MATVEJ",
        "MATVEY", "MATWIEJ", "MATWIJ", "MAURICE", "MAXIMILIEN", "MAXIMILLIAN",
        "MAXWELL", "MERGEN", "MICHEL", "MIHAIL", "MIKEL", "MINH", "MINH",
        "MINH", "MIROSLAV", "MOHAMED", "MORGAN", "MUSTAFA", "MYKOLA", "MYRON",
        "NAREK", "NAWOJ", "NAZARII", "NIKOLAI", "NORMAN", "OCTAVIAN", "ORHAN",
        "OSTAP", "OTTO", "PAULO", "PAVEL", "PAWEL", "PEDRO", "PHAN",
        "PHILIPPE", "QUENTIN", "RAPHAEL", "RAUL", "RAVI", "REMI", "REYANSH",
        "RICARDO", "ROBERTO", "ROHAN", "RONALD", "RUDRA", "RUFIN", "RUSLAN",
        "SALAH", "SALVADOR", "SAMSON", "SELIM", "SELIM", "SEMIR", "SERGIO",
        "SERHII", "SHER", "SHIVANSH", "SINAN", "SLAWOJ", "SOBIESLAW",
        "STANISLAS", "STANLEY", "STEVEN", "SULEIMAN", "SVEN", "SVIATOSLAV",
        "SCIBOR", "TADEJ", "TARAS", "TEOMAN", "THEO", "TIGRAN", "TIMO",
        "TIMOFEI", "TOM", "TOMEK", "TOMIR", "TONI", "TRAIAN", "TUAN",
        "VENIAMIN", "WADIM", "WALENTY", "WALTER", "WITEK", "WITOSZ",
        "WLADIMIR", "WOJMIR", "WOLFGANG", "YAKUB", "YAREMA", "YEHOR", "YURI",
        "YUVAAN", "ZAID", "ZAKARIYA", "ZAYAN", "ZBYSZEK"
    ]
    femalenames = [
        'ALFREDA', 'BIANKA', 'CECYLIA', 'DELFINA', 'EUFEMIA', 'FILOMENA',
        'GERTRUDA', 'IMELDA', 'JOLENTA', 'KLOTYLDA', 'LUTGARDA', 'MECHTYLDA',
        'NIKODEMA', 'ODYLIA', 'OKTAWIA', 'PELAGIA', 'RUTA', 'SCHOLASTYKA',
        'TEKLA', 'UMA', 'WERIDIANA', 'ZYTA', 'ELEANOR', 'ARWEN', 'PERSEFONA',
        'NIKODEZJA', 'ROZA', 'STOKROTKA', 'CIRI', 'ALEXANDRA', "ELA", "GRACJA",
        "GRACJANA", "INGRID", "IVANKA", "JENNIFER", "JESIKA", "KAMELIA",
        "KATERYNA", "KRISTINA", "LATIKA", "LEOKADIA", "LINA", "LINDA", "MAIA",
        "MANUELA", "MARIIA", "MARYAM", "MELODY", "MICHAELA", "MIKA", "NAWOJKA",
        "PAMELA", "PETRA", "ROSE", "SCARLETT", "SLAWA", "SOLOMIIA", "SUSANNA",
        "TETIANA", "TINA", "ULIANA", "VALENTINA", "YASMINA", "AIDA", "ALANA",
        "ALDONA", "ALENA", "ANATOLIA", "ATHENA", "AUDREY", "BERNADETTA",
        "BOGUSLAWA", "CHANEL", "CYNTIA", "ELEANOR", "ELSA", "EMILI", "EVELINA",
        "GEORGIA", "GIULIA", "JAGIENKA", "JULIE", "KAYLA", "LEONIE", "LIANA",
        "MADELEINE", "MALENA", "MARIAM", "MARIANA", "MELA", "MILLA", "MILA",
        "MOLLY", "NATASHA", "NELLI", "OLENA", "RAISA", "ROSALIA", "SAMANTHA",
        "SAVANNAH", "SELIN", "SUSANNE", "TESSA", "VERA", "VERONICA",
        "WALENTYNA", "YASMINE", "YEVA", "ZARINA", "ZOEY", "ADEL", "AILA",
        "ALEKSA", "ALISHA", "ALYA", "AMBER", "ANABELLA", "ANGELICA",
        "ANNABELLE", "ASYA", "AYLIN", "BERNADETA", "BOGDANA", "BRONISLAWA",
        "CARLOTTA", "CATTLEYA", "CIRILLA", "DAMROKA", "DANIELLA", "DINA",
        "DONATA", "ELINA", "ELISE", "ELLIE", "ESMERALDA", "FABIOLA", "FLAWIA",
        "FREJA", "GRACE", "ILIA", "INDIA", "ISABEL", "ISLA", "ISMENA", "IVA",
        "IVY", "KAJRA", "KAMILLA", "KATRINA", "LAJLA", "LAURENCJA", "LEIA",
        "LEONA", "LENA", "LILIAN", "LILLIAN", "MAGNOLIA", "MARGARET",
        "MARGARYTA", "MARISA", "MASZA", "MELANI", "MIRELA", "NATALIIA",
        "NEYLA", "NIKOLETA", "NIKOLINA", "NILA", "OTOLIA", "PELAGIA", "PIA",
        "POLIANA", "PRISHA", "RACHELA", "ROSALIE", "ROZALINA", "RUBY", "RUT",
        "RUTA", "SASHA", "SELMA", "SOFIYA", "TAIDA", "VALERIIA", "VITTORIA",
        "VIVIAN", "WIERA", "XENIA", "AAHANA", "AALIYAH", "ADELIA", "AIMEE",
        "AISZA", "AJSZA", "ALESSANDRA", "ALESSIA", "ALEXIA", "ALISSIA",
        "ALITA", "ALVIRA", "ALYSSA", "AMAIA", "AMALIA", "AMARACHI", "AMELA",
        "AMELIA", "ANABEL", "ANABELA", "ANABELL", "ANABELLE", "ANELIA",
        "ANETTA", "ANGEL", "ANIA", "ANYA", "ARLENA", "ARLETTA", "ARLO",
        "ARNIKA", "ASEL", "ASHLEY", "ASIYA", "ASTEJA", "AUGUSTYNA", "AURIKA",
        "AYESHA", "AYSE", "BAO", "CAMILA", "CANSU", "CARLA", "CAROLINA",
        "CAROLINE", "CELIA", "CELINE", "CLAIRE", "CLARISSA", "CYNTHIA", "DANA",
        "DARINA", "DORA", "ELENI", "ELIZABET", "ELMIRA", "EMA", "EMANUELA",
        "EMILIANA", "EMILIIA", "EMINE", "ERIN", "ESTELLE", "EULALIA",
        "FLORENTINA", "FRIDA", "GABI", "GABRYJELA", "GEMMA", "GIA", "HELEN",
        "HOAI", "ILARIA", "IMAN", "INES", "IRINA", "IRIS", "IRMA", "IVANNA",
        "IWA", "IWANKA", "JAGA", "JEWA", "JOZEFA", "JUDITH", "KAMILIA",
        "KARLA", "KATHRIN", "KHLOE", "KIM", "KLARYSA", "KLEOPATRA", "KORDELIA",
        "LANA", "LEJLA", "LENNA", "LETI", "LEA", "LILLIANA", "LIUBOV", "LIA",
        "LORA", "LOUISA", "LOUISE", "LUCIA", "LUCJA", "LUSI", "MADLEN",
        "MALVINA", "MARCELINE", "MARGOT", "MARIJA", "MARYNA", "MASHA", "MAURA",
        "MELEK", "MIGLENA", "MINH", "MIYA", "NADIN", "NADIYA", "NADJA",
        "NAILA", "NARE", "NELLIE", "NGOC", "NORA", "PHOEBE", "QIANYU",
        "RAMONA", "ROZA", "SABRINA", "SAIDA", "SAWA", "SCARLET", "SELINA",
        "SEMILIANA", "SERAFINA", "SIENNA", "SIMONE", "SOFI", "SOFII", "SOFIKO",
        "SOFIA", "SOLOMIA", "SOLOMIJA", "SORAYA", "SUMAYA", "SUSAN", "SUZAN",
        "SUZANNA", "SYNTIA", "SZYMON", "TALIA", "TALITA", "TEODOZJA",
        "TEOFILA", "THANH", "THEA", "THIEN", "TIANTIAN", "TUE", "ULJANA",
        "ULLA", "VANESA", "VARVARA", "VIKTORIE", "VIRA", "VIRGINIA",
        "VLADYSLAVA", "WARWARA", "WERA", "YELYZAVETA", "ZELIA", "ZLATOSLAVA",
        "ZORA", "ZORIA", "ZUZANA", "ZAKLINA"
    ]

    surnames = ['PODOLSKI', 'DUDA', 'WAWELSKI', 'MAZOWSKI', 'CHORTEKS', 'OLSZA', \
                'SMITH', 'SCHMIDT', 'KOWALSKI', 'KALVIS', 'KOVAC', 'FABER', \
                'NIEWIEM', 'KOVOTEPEC', 'TUMELO', 'HERRERO', 'SEPPA', 'HAMPO', \
                'SMID', 'IMANI', 'BAGGINS', 'TUK', 'GAMGEE', 'BRANDYBUCK', \
                'BOLGER', 'DELVING', 'BANKS', 'TOOK', 'ATREIDES', 'HARKONNEN', \
                'CORRINO', 'TUEK', 'KYNES', 'MAPES', 'DE VRIES', 'HALLECK', \
                'IDAHO', 'HAWAT', 'IRONFUNDERSON', 'VIMES', 'GARLIC', \
                'WETHERWAX', 'OGG', 'STOLAT', 'SPULDING', 'VETINARI', \
                'HERBATA', 'BIGOS', 'WIERZCHON', 'KOZLOWSKI', 'NUTTER']

    def __init__(self, anonymizer, dbfile):
        super().__init__(anonymizer)

        # cache name substitutions
        self.cache = LRUCache(maxsize=100)

        # handle DB
        self.dbfile = dbfile
        self.cursor = None
        self.conn = None
        self.reinitialize()

    def reinitialize(self):
        self.conn = sqlite3.connect(self.dbfile)
        self.cursor = self.conn.cursor()
        self.cursor.execute("CREATE TABLE IF NOT EXISTS mapping " \
                            "(anonID TEXT, fakeID TEXT)" )
        self.cursor.execute("CREATE INDEX IF NOT EXISTS ai ON mapping(anonID)")
        self.conn.commit()

    def process(self):

        # main part
        ret = self.anonymizer.RemovePrivateTags()
        if not ret:
            return False
        ret = self.anonymizer.BasicApplicationLevelConfidentialityProfile(True)
        if not ret:
            return False

        # reinvent PatientName
        f = self.anonymizer.GetFile()
        ds = f.GetDataSet()

        # these fields for some reason have a trailing space, so rstrip()
        pid = str((ds.GetDataElement(gdcm.Tag(0x10,
                                              0x20))).GetValue()).rstrip()
        psex = str((ds.GetDataElement(gdcm.Tag(0x10,
                                               0x40))).GetValue()).rstrip()

        #print(f"pid={pid}, psex={psex}")
        pName = self._findAlias(pid, psex)

        return self.anonymizer.Replace(gdcm.Tag(0x10, 0x10), pName)

    def _findAlias(self, pid, psex):

        # L1 cache for quick lookup of the most recent substitutions
        name = self.cache.get(pid)
        if name is not None:
            return name

        # then look it up in the DB
        self.cursor.execute("SELECT fakeID FROM mapping " \
                            "WHERE anonID = ?", (pid, ))
        ret = self.cursor.fetchone()
        if ret is None:
            # previously unseen
            name = radAnonimizer._generateName(psex)
            self.cursor.execute("INSERT INTO mapping VALUES (?,?)",
                                (pid, name))
            self.conn.commit()
        else:
            # seen
            name = ret[0]

        # one way or another: cache it
        self.cache[pid] = name
        return name

    @staticmethod
    def _generateName(sex):
        #print(f"sex=|{sex}|")
        if sex == 'F':
            Name = random.choice(radAnonimizer.femalenames)
        else:
            Name = random.choice(radAnonimizer.malenames)

        Surname = random.choice(radAnonimizer.surnames)
        PatientName = "%s^%s^^^" % (Surname, Name)
        return PatientName

    def finalize(self):
        self.conn.close()
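
_findAlias above layers a small LRUCache in front of the sqlite mapping table, so repeat patient IDs are answered without touching the database. A self-contained sketch of that two-tier lookup; the table layout mirrors the example, while AliasStore and generate_alias are hypothetical stand-ins:

import sqlite3
from cachetools import LRUCache

class AliasStore:
    def __init__(self, dbfile=':memory:', cache_size=100):
        self.conn = sqlite3.connect(dbfile)
        self.conn.execute('CREATE TABLE IF NOT EXISTS mapping (anonID TEXT, fakeID TEXT)')
        self.conn.execute('CREATE INDEX IF NOT EXISTS ai ON mapping(anonID)')
        self.conn.commit()
        self.cache = LRUCache(maxsize=cache_size)

    def find_alias(self, pid, generate_alias):
        # L1: recently seen patient IDs come straight from the LRU cache.
        name = self.cache.get(pid)
        if name is not None:
            return name
        # L2: fall back to the persistent mapping table.
        row = self.conn.execute('SELECT fakeID FROM mapping WHERE anonID = ?',
                                (pid,)).fetchone()
        if row is None:
            name = generate_alias()
            self.conn.execute('INSERT INTO mapping VALUES (?, ?)', (pid, name))
            self.conn.commit()
        else:
            name = row[0]
        self.cache[pid] = name
        return name

store = AliasStore()
store.find_alias('ANON-0001', lambda: 'KOWALSKI^JACENTY^^^')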
Example #15
0
        for transaction in find_transactions:
            amount = transaction.get('amount')

            if amount is not None:
                amount_value = amount.get('amount')
                if amount_value is not None:
                    amount_value = round(amount_value / 100)
                amount['amount'] = amount_value

            entry_id = transaction['metaData']['entryId']
            smart_picks_for_entries[entry_id] = json.dumps(amount)

        for entry in entries:
            if 'userId' in entry and bot_ids.get(entry['userId']) is None:
                default_pool = all_pool_ids.get(str(entry.get("poolId"))) or {}
                default_event = all_event_ids.get(str(
                    entry.get("eventId"))) or {}
                meta_data = all_remote_ids.get(
                    default_event.get("remoteId")) or all_remote_ids.get(
                        str(default_event.get("_id")))
                meta_data_str = None
                if meta_data is not None:
                    if meta_data.get('roundType') is None:
                        meta_data['roundType'] = 'N/A'
                    if meta_data.get('round') is None and default_event.get(
                            "name") is not None:
                        event_name_round = default_event.get(
                            "name").split()[-1].strip()
                        if event_name_round.isdigit():
                            meta_data['round'] = event_name_round
Example #16
0
class GMusicLibraryProvider(backend.LibraryProvider):
    root_directory = Ref.directory(uri="gmusic:directory",
                                   name="Google Play Music")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # tracks, albums, and artists here refer to what is explicitly
        # in our library.
        self.tracks = {}
        self.albums = {}
        self.artists = {}

        # aa_* caches are *only* used for temporary objects. Library
        # objects will never make it here.
        self.aa_artists = LRUCache(1024)
        self.aa_tracks = LRUCache(1024)
        self.aa_albums = LRUCache(1024)

        self._radio_stations_in_browse = self.backend.config["gmusic"][
            "radio_stations_in_browse"]
        self._radio_stations_count = self.backend.config["gmusic"][
            "radio_stations_count"]
        self._radio_tracks_count = self.backend.config["gmusic"][
            "radio_tracks_count"]

        self._top_tracks_count = self.backend.config["gmusic"][
            "top_tracks_count"]

        # Setup the root of library browsing.
        self._root = [
            Ref.directory(uri="gmusic:album", name="Albums"),
            Ref.directory(uri="gmusic:artist", name="Artists"),
            Ref.directory(uri="gmusic:track", name="Tracks"),
        ]

        if self._radio_stations_in_browse:
            self._root.append(Ref.directory(uri="gmusic:radio", name="Radios"))

    @property
    def all_access(self):
        return self.backend.session.all_access

    def _browse_tracks(self):
        tracks = list(self.tracks.values())
        tracks.sort(key=lambda ref: ref.name)
        refs = []
        for track in tracks:
            refs.append(track_to_ref(track))
        return refs

    def _browse_albums(self):
        refs = []
        for album in self.albums.values():
            refs.append(album_to_ref(album))
        refs.sort(key=lambda ref: ref.name)
        return refs

    def _browse_album(self, uri):
        refs = []
        for track in self._lookup_album(uri):
            refs.append(track_to_ref(track, True))
        return refs

    def _browse_artists(self):
        refs = []
        for artist in self.artists.values():
            refs.append(artist_to_ref(artist))
        refs.sort(key=lambda ref: ref.name)
        return refs

    def _browse_artist(self, uri):
        refs = []
        for album in self._get_artist_albums(uri):
            refs.append(album_to_ref(album))
            refs.sort(key=lambda ref: ref.name)
        if len(refs) > 0:
            refs.insert(0, Ref.directory(uri=uri + ":all", name="All Tracks"))
            is_all_access = uri.startswith("gmusic:artist:A")
            if is_all_access:
                refs.insert(1,
                            Ref.directory(uri=uri + ":top", name="Top Tracks"))
            return refs
        else:
            # Show all tracks if no album is available
            return self._browse_artist_all_tracks(uri)

    def _browse_artist_all_tracks(self, uri):
        artist_uri = ":".join(uri.split(":")[:3])
        refs = []
        tracks = self._lookup_artist(artist_uri, True)
        for track in tracks:
            refs.append(track_to_ref(track))
        return refs

    def _browse_artist_top_tracks(self, uri):
        artist_uri = ":".join(uri.split(":")[:3])
        refs = []
        tracks = self._get_artist_top_tracks(artist_uri)
        for track in tracks:
            refs.append(track_to_ref(track))
        return refs

    def _browse_radio_stations(self, uri):
        stations = self.backend.session.get_radio_stations(
            self._radio_stations_count)
        # create Ref objects
        refs = []
        for station in stations:
            refs.append(
                Ref.directory(uri="gmusic:radio:" + station["id"],
                              name=station["name"]))
        return refs

    def _browse_radio_station(self, uri):
        station_id = uri.split(":")[2]
        tracks = self.backend.session.get_station_tracks(
            station_id, self._radio_tracks_count)

        # create Ref objects
        refs = []
        for track in tracks:
            mopidy_track = self._to_mopidy_track(track)
            self.aa_tracks[mopidy_track.uri] = mopidy_track
            refs.append(track_to_ref(mopidy_track))
        return refs

    def browse(self, uri):
        logger.debug("browse: %s", str(uri))
        if not uri:
            return []
        if uri == self.root_directory.uri:
            return self._root

        parts = uri.split(":")

        # tracks
        if uri == "gmusic:track":
            return self._browse_tracks()

        # albums
        if uri == "gmusic:album":
            return self._browse_albums()

        # a single album
        # uri == 'gmusic:album:album_id'
        if len(parts) == 3 and parts[1] == "album":
            return self._browse_album(uri)

        # artists
        if uri == "gmusic:artist":
            return self._browse_artists()

        # a single artist
        # uri == 'gmusic:artist:artist_id'
        if len(parts) == 3 and parts[1] == "artist":
            return self._browse_artist(uri)

        # all tracks of a single artist
        # uri == 'gmusic:artist:artist_id:all'
        if len(parts) == 4 and parts[1] == "artist" and parts[3] == "all":
            return self._browse_artist_all_tracks(uri)

        # top tracks of a single artist
        # uri == 'gmusic:artist:artist_id:top'
        if len(parts) == 4 and parts[1] == "artist" and parts[3] == "top":
            return self._browse_artist_top_tracks(uri)

        # all radio stations
        if uri == "gmusic:radio":
            return self._browse_radio_stations(uri)

        # a single radio station
        # uri == 'gmusic:radio:station_id'
        if len(parts) == 3 and parts[1] == "radio":
            return self._browse_radio_station(uri)

        logger.debug("Unknown uri for browse request: %s", uri)

        return []

    def lookup(self, uri):
        if uri.startswith("gmusic:track:"):
            return self._lookup_track(uri)
        elif uri.startswith("gmusic:album:"):
            return self._lookup_album(uri)
        elif uri.startswith("gmusic:artist:"):
            return self._lookup_artist(uri)
        else:
            return []

    def _lookup_track(self, uri):
        is_all_access = uri.startswith("gmusic:track:T")

        try:
            return [self.tracks[uri]]
        except KeyError:
            logger.debug(f"Track {uri!r} is not a library track")
            pass

        if is_all_access and self.all_access:
            track = self.aa_tracks.get(uri)
            if track:
                return [track]
            track = self.backend.session.get_track_info(uri.split(":")[2])
            if track is None:
                logger.warning(f"Could not find track {uri!r}")
                return []
            if "artistId" not in track:
                logger.warning(f"Failed to lookup {uri!r}")
                return []
            mopidy_track = self._to_mopidy_track(track)
            self.aa_tracks[mopidy_track.uri] = mopidy_track
            return [mopidy_track]
        else:
            return []

    def _lookup_album(self, uri):
        is_all_access = uri.startswith("gmusic:album:B")
        if self.all_access and is_all_access:
            tracks = self.aa_albums.get(uri)
            if tracks:
                return tracks
            album = self.backend.session.get_album_info(uri.split(":")[2],
                                                        include_tracks=True)
            if album and album.get("tracks"):
                tracks = [
                    self._to_mopidy_track(track) for track in album["tracks"]
                ]
                for track in tracks:
                    self.aa_tracks[track.uri] = track
                tracks = sorted(tracks, key=lambda t: (t.disc_no, t.track_no))
                self.aa_albums[uri] = tracks
                return tracks

            logger.warning(f"Failed to lookup all access album {uri!r}")

        # Even if the album has an all access ID, we need to look it
        # up here (as a fallback) because purchased tracks can have a
        # store ID, but only show up in your library.
        try:
            album = self.albums[uri]
        except KeyError:
            logger.debug(f"Failed to lookup {uri!r}")
            return []

        tracks = self._find_exact(
            dict(
                album=album.name,
                artist=[artist.name for artist in album.artists],
                date=album.date,
            )).tracks
        return sorted(tracks, key=lambda t: (t.disc_no, t.track_no))

    def _get_artist_top_tracks(self, uri):
        is_all_access = uri.startswith("gmusic:artist:A")
        artist_id = uri.split(":")[2]

        if not is_all_access:
            logger.debug("Top Tracks not available for non-all-access artists")
            return []

        artist_info = self.backend.session.get_artist_info(
            artist_id,
            include_albums=False,
            max_top_tracks=self._top_tracks_count,
            max_rel_artist=0,
        )
        top_tracks = []

        for track_dict in artist_info["topTracks"]:
            top_tracks.append(self._to_mopidy_track(track_dict))

        return top_tracks

    def _get_artist_albums(self, uri):
        is_all_access = uri.startswith("gmusic:artist:A")

        artist_id = uri.split(":")[2]
        if is_all_access:
            # all access
            artist_infos = self.backend.session.get_artist_info(
                artist_id, max_top_tracks=0, max_rel_artist=0)
            if artist_infos is None or "albums" not in artist_infos:
                return []
            albums = []
            for album in artist_infos["albums"]:
                albums.append(
                    self._aa_search_album_to_mopidy_album({"album": album}))
            return albums
        elif self.all_access and artist_id in self.aa_artists:
            albums = self._get_artist_albums("gmusic:artist:%s" %
                                             self.aa_artists[artist_id])
            if len(albums) > 0:
                return albums
            # else fall back to non aa albums
        if uri in self.artists:
            artist = self.artists[uri]
            return [
                album for album in self.albums.values()
                if artist in album.artists
            ]
        else:
            logger.debug(f"No albums available for artist {uri!r}")
            return []

    def _lookup_artist(self, uri, exact_match=False):
        def sorter(track):
            return (
                track.album.date,
                track.album.name,
                track.disc_no,
                track.track_no,
            )

        if self.all_access:
            try:
                all_access_id = self.aa_artists[uri.split(":")[2]]
                artist_infos = self.backend.session.get_artist_info(
                    all_access_id, max_top_tracks=0, max_rel_artist=0)
                if not artist_infos or not artist_infos["albums"]:
                    logger.warning(f"Failed to lookup {artist_infos}!r")
                tracks = [
                    self._lookup_album("gmusic:album:" + album["albumId"])
                    for album in artist_infos["albums"]
                ]
                tracks = reduce(lambda a, b: (a + b), tracks)
                return sorted(tracks, key=sorter)
            except KeyError:
                pass
        try:
            artist = self.artists[uri]
        except KeyError:
            logger.debug(f"Failed to lookup {uri!r}")
            return []

        tracks = self._find_exact(dict(artist=artist.name)).tracks
        if exact_match:
            tracks = filter(lambda t: artist in t.artists, tracks)
        return sorted(tracks, key=sorter)

    def refresh(self, uri=None):
        logger.info("Refreshing library")

        self.tracks = {}
        self.albums = {}
        self.artists = {}

        album_tracks = {}
        for track in self.backend.session.get_all_songs():
            mopidy_track = self._to_mopidy_track(track)

            self.tracks[mopidy_track.uri] = mopidy_track
            self.albums[mopidy_track.album.uri] = mopidy_track.album

            # We don't care about the order because we're just using
            # this as a temporary variable to grab the proper album
            # artist out of the album.
            if mopidy_track.album.uri not in album_tracks:
                album_tracks[mopidy_track.album.uri] = []

            album_tracks[mopidy_track.album.uri].append(mopidy_track)

        # Yes, this is awful. No, I don't have a better solution. Yes,
        # I'm annoyed at Google for not providing album artist IDs.
        for album in self.albums.values():
            artist_found = False
            for album_artist in album.artists:
                for track in album_tracks[album.uri]:
                    for artist in track.artists:
                        if album_artist.name == artist.name:
                            artist_found = True
                            self.artists[artist.uri] = artist

            if not artist_found:
                for artist in album.artists:
                    self.artists[artist.uri] = artist

        logger.info("Loaded "
                    f"{len(self.artists)} artists, "
                    f"{len(self.albums)} albums, "
                    f"{len(self.tracks)} tracks from Google Play Music")

    def search(self, query=None, uris=None, exact=False):
        if exact:
            return self._find_exact(query=query, uris=uris)

        lib_tracks, lib_artists, lib_albums = self._search_library(query, uris)

        if query:
            aa_tracks, aa_artists, aa_albums = self._search(query, uris)
            for aa_artist in aa_artists:
                lib_artists.add(aa_artist)

            for aa_album in aa_albums:
                lib_albums.add(aa_album)

            lib_tracks = set(lib_tracks)

            for aa_track in aa_tracks:
                lib_tracks.add(aa_track)

        return SearchResult(
            uri="gmusic:search",
            tracks=lib_tracks,
            artists=lib_artists,
            albums=lib_albums,
        )

    def _find_exact(self, query=None, uris=None):
        # Find exact can only be done on gmusic library,
        # since one can't filter all access searches
        lib_tracks, lib_artists, lib_albums = self._search_library(query, uris)

        return SearchResult(
            uri="gmusic:search",
            tracks=lib_tracks,
            artists=lib_artists,
            albums=lib_albums,
        )

    def _search(self, query=None, uris=None):
        for (field, values) in query.items():
            # wrap single values (including bare strings) in a list
            if not isinstance(values, list):
                values = [values]

            # Since gmusic does not support search filters, just search for the
            # first 'searchable' filter
            if field in [
                    "track_name", "album", "artist", "albumartist", "any"
            ]:
                logger.info(f"Searching Google Play Music for: {values[0]}")
                res = self.backend.session.search(values[0], max_results=50)
                if res is None:
                    return [], [], []

                albums = [
                    self._aa_search_album_to_mopidy_album(album_res)
                    for album_res in res["album_hits"]
                ]
                artists = [
                    self._aa_search_artist_to_mopidy_artist(artist_res)
                    for artist_res in res["artist_hits"]
                ]
                tracks = [
                    self._aa_search_track_to_mopidy_track(track_res)
                    for track_res in res["song_hits"]
                ]

                return tracks, artists, albums

        return [], [], []

    def _search_library(self, query=None, uris=None):
        if query is None:
            query = {}
        self._validate_query(query)
        result_tracks = self.tracks.values()

        for (field, values) in query.items():
            if not isinstance(values, list):
                values = [values]
            # FIXME this is bound to be slow for large libraries
            for value in values:
                if field == "track_no":
                    q = self._convert_to_int(value)
                else:
                    q = value.strip().lower()

                def uri_filter(track):
                    return q in track.uri.lower()

                def track_name_filter(track):
                    return q in track.name.lower()

                def album_filter(track):
                    return q in getattr(track, "album", Album()).name.lower()

                def artist_filter(track):
                    return any(
                        q in a.name.lower()
                        for a in track.artists) or albumartist_filter(track)

                def albumartist_filter(track):
                    album_artists = getattr(track, "album", Album()).artists
                    return any(q in a.name.lower() for a in album_artists)

                def track_no_filter(track):
                    return track.track_no == q

                def date_filter(track):
                    return track.date and track.date.startswith(q)

                def any_filter(track):
                    return any([
                        uri_filter(track),
                        track_name_filter(track),
                        album_filter(track),
                        artist_filter(track),
                        albumartist_filter(track),
                        date_filter(track),
                    ])

                if field == "uri":
                    result_tracks = list(filter(uri_filter, result_tracks))
                elif field == "track_name":
                    result_tracks = list(
                        filter(track_name_filter, result_tracks))
                elif field == "album":
                    result_tracks = list(filter(album_filter, result_tracks))
                elif field == "artist":
                    result_tracks = list(filter(artist_filter, result_tracks))
                elif field == "albumartist":
                    result_tracks = list(
                        filter(albumartist_filter, result_tracks))
                elif field == "track_no":
                    result_tracks = list(filter(track_no_filter,
                                                result_tracks))
                elif field == "date":
                    result_tracks = list(filter(date_filter, result_tracks))
                elif field == "any":
                    result_tracks = list(filter(any_filter, result_tracks))
                else:
                    raise LookupError("Invalid lookup field: %s" % field)

        result_artists = set()
        result_albums = set()
        for track in result_tracks:
            result_artists |= track.artists
            result_albums.add(track.album)

        return result_tracks, result_artists, result_albums

    def _validate_query(self, query):
        for (_, values) in query.items():
            if not values:
                raise LookupError("Missing query")
            for value in values:
                if not value:
                    raise LookupError("Missing query")

    def _to_mopidy_track(self, song):
        track_id = song.get("id", song.get("nid"))
        if track_id is None:
            raise ValueError
        if track_id[0] != "T" and "-" not in track_id:
            track_id = "T" + track_id
        return Track(
            uri="gmusic:track:" + track_id,
            name=song["title"],
            artists=[self._to_mopidy_artist(song)],
            album=self._to_mopidy_album(song),
            track_no=song.get("trackNumber", 1),
            disc_no=song.get("discNumber", 1),
            date=str(song.get("year", 0)),
            length=int(song["durationMillis"]),
            bitrate=self.backend.config["gmusic"]["bitrate"],
        )

    def _to_mopidy_album(self, song):
        name = song.get("album", "")
        artist = self._to_mopidy_album_artist(song)
        date = str(song.get("year", 0))

        album_id = create_id(f"{artist.name}|{name}|{date}")

        uri = "gmusic:album:" + album_id
        return Album(
            uri=uri,
            name=name,
            artists=[artist],
            num_tracks=song.get("totalTrackCount"),
            num_discs=song.get("totalDiscCount"),
            date=date,
        )

    def _to_mopidy_artist(self, song):
        name = song.get("artist", "")
        artist_id = create_id(name)
        uri = "gmusic:artist:" + artist_id
        return Artist(uri=uri, name=name)

    def _to_mopidy_album_artist(self, song):
        name = song.get("albumArtist", "")
        if name.strip() == "":
            name = song.get("artist", "")
        uri = "gmusic:artist:" + create_id(name)
        return Artist(uri=uri, name=name)

    def _aa_search_track_to_mopidy_track(self, search_track):
        track = search_track["track"]

        aa_artist_id = create_id(track["artist"])
        if "artistId" in track:
            aa_artist_id = track["artistId"][0]
        else:
            logger.warning("No artistId for Track %r", track)

        artist = Artist(uri="gmusic:artist:" + aa_artist_id,
                        name=track["artist"])

        album = Album(
            uri="gmusic:album:" + track["albumId"],
            name=track["album"],
            artists=[artist],
            date=str(track.get("year", 0)),
        )

        return Track(
            uri="gmusic:track:" + track["storeId"],
            name=track["title"],
            artists=[artist],
            album=album,
            track_no=track.get("trackNumber", 1),
            disc_no=track.get("discNumber", 1),
            date=str(track.get("year", 0)),
            length=int(track["durationMillis"]),
            bitrate=self.backend.config["gmusic"]["bitrate"],
        )

    def _aa_search_artist_to_mopidy_artist(self, search_artist):
        artist = search_artist["artist"]
        uri = "gmusic:artist:" + artist["artistId"]
        return Artist(uri=uri, name=artist["name"])

    def _aa_search_album_to_mopidy_album(self, search_album):
        album = search_album["album"]
        uri = "gmusic:album:" + album["albumId"]
        name = album["name"]
        artist = self._aa_search_artist_album_to_mopidy_artist_album(album)
        date = str(album.get("year", 0))
        return Album(uri=uri, name=name, artists=[artist], date=date)

    def _aa_search_artist_album_to_mopidy_artist_album(self, album):
        name = album.get("albumArtist", "")
        if name.strip() == "":
            name = album.get("artist", "")
        uri = "gmusic:artist:" + create_id(name)
        return Artist(uri=uri, name=name)

    def _convert_to_int(self, string):
        try:
            return int(string)
        except ValueError:
            return object()
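The provider above keeps two plain dicts (aa_tracks, aa_albums) as unbounded, lazily filled caches in front of remote lookups. For comparison, here is a minimal standalone sketch of the same read-through pattern using a bounded cachetools LRUCache; the class name and the fetch callable are illustrative, not part of the provider above.

from cachetools import LRUCache

class ReadThroughCache:
    """Tiny read-through wrapper: consult the LRU cache first, fall back to an
    expensive fetch callable, and remember any non-None result."""

    def __init__(self, fetch, maxsize=512):
        self._fetch = fetch          # e.g. a remote lookup such as get_track_info
        self._cache = LRUCache(maxsize=maxsize)

    def get(self, key):
        value = self._cache.get(key)
        if value is None:
            value = self._fetch(key)
            if value is not None:
                self._cache[key] = value
        return value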
Example #17
0
class ImagenetModel:
    ''' A class for featurizing images using pre-trained neural nets '''
    def __init__(self,
                 include_top=False,
                 pooling=None,
                 n_channels=None,
                 cache_size=int(1e4),
                 model='inception_v3',
                 weights='imagenet',
                 cache_dir=None,
                 n_objects=None):

        self.include_top = include_top  # determines if used for classification or featurization, TODO separate into two classes?
        self.n_channels = n_channels
        self.n_objects = n_objects
        self.pooling = pooling

        self.failed_urls = set()

        # NOTE: set cache_dir to None to turn off caching
        if cache_dir:
            # create default cache path in the current file dir w/ filename specifying config
            config = [
                f'objects-{NUM_OBJECTS}' if include_top else 'features',
                str(cache_size), model, pooling if pooling else '',
                str(n_channels) if n_channels else ''
            ]
            # filter out empty strings and join w/ -
            config_str = '-'.join([c for c in config if c])
            cache_fname = f'imagenet-cache-{config_str}.pkl'
            self.cache_path = os.path.join(cache_dir, cache_fname)
            # TODO allow larger cache_size to still load from previous smaller caches
        else:
            self.cache_path = None

        if self.cache_path and os.path.isfile(self.cache_path):
            self.load_cache()
        else:
            self.cache = LRUCache(cache_size)

        if model == 'xception':
            self.model = xception.Xception(weights=weights,
                                           include_top=include_top,
                                           pooling=pooling)
            self.preprocess = xception.preprocess_input
            self.target_size = (299, 299)
            if include_top:
                self.decode = xception.decode_predictions
            else:
                self.output_dim = (n_channels if n_channels else
                                   2048) * (1 if pooling else 10**2)
        elif model == 'inception_v3':
            self.model = inception_v3.InceptionV3(weights=weights,
                                                  include_top=include_top,
                                                  pooling=pooling)
            self.preprocess = inception_v3.preprocess_input
            self.target_size = (299, 299)
            if include_top:
                self.decode = inception_v3.decode_predictions
            else:
                self.output_dim = (n_channels if n_channels else
                                   2048) * (1 if pooling else 8**2)
        elif model == 'mobilenet_v2':
            self.model = mobilenetv2.MobileNetV2(weights=weights,
                                                 include_top=include_top,
                                                 pooling=pooling)
            self.preprocess = mobilenetv2.preprocess_input
            self.target_size = (244, 244)
            if include_top:
                self.decode = mobilenetv2.decode_predictions
            else:
                self.output_dim = (n_channels if n_channels else
                                   1280) * (1 if pooling else 7**2)
        else:
            raise Exception('model option not implemented')

        # NOTE: we force the imagenet model to load in the same scope as the functions using it to avoid tensorflow weirdness
        self.model.predict(np.zeros((1, *self.target_size, 3)))
        logging.info('imagenet loaded')

    def save_cache(self, cache_path=None):
        ''' saves cache of image identifier (url or path) to image features at the given cache path '''
        logging.info('saving cache')
        cache_path = cache_path if cache_path else self.cache_path
        with open(cache_path, 'wb') as pkl_file:
            pickle.dump({
                'cache': self.cache,
                'failed_urls': self.failed_urls
            }, pkl_file)

    def load_cache(self, cache_path=None):
        ''' loads cache of image identifier (url or path) to image features '''
        cache_path = cache_path if cache_path else self.cache_path
        logging.info(f'loading cache from {cache_path}')
        if not os.path.isfile(cache_path):
            logging.error(f'cache file not present at: {cache_path}')
        else:
            with open(cache_path, 'rb') as pkl_file:
                pkl_data = pickle.load(pkl_file)
                self.cache = pkl_data['cache']
                self.failed_urls = pkl_data['failed_urls']

            logging.info(
                f'successfully loaded cache with {len(self.cache)} entries '
                f'and failed urls with {len(self.failed_urls)} entries')

    def get_objects_from_url(self, image_url, ignore_failed=True):
        ''' detects objects from image in a url, returns None if url download failed '''
        if image_url not in self.cache:
            # skip if we're ignoring previously failed urls
            if ignore_failed and image_url in self.failed_urls:
                return

            # download image and convert into numpy array
            image_array = image_array_from_url(image_url,
                                               target_size=self.target_size)
            if image_array is None:
                # if url request failed, add to failed set
                self.failed_urls.add(image_url)
                return

            # add a dim if needed
            if image_array.ndim == 3:
                image_array = image_array[None, :, :, :]

            # use the imagenet model to detect the objects in the image and add result to cache
            self.cache[image_url] = self.get_objects(image_array)

        # returned cached result
        return self.cache[image_url]

    def get_objects(self, image_array):
        ''' detects objects in image provided as an array '''
        logging.debug(f'recognizing objects')
        image_array = self.preprocess(image_array)
        objects = self.model.predict(image_array)
        objects = self.decode(objects, top=self.n_objects)[0]
        return {
            obj[1]: obj[2]
            for obj in objects
        }  # objects = [{'object': obj[1], 'score': obj[2]} for obj in objects]

    def get_features_from_paths(self, image_paths):
        ''' takes a list of image filepaths and returns the features resulting from applying the imagenet model to those images '''
        if self.include_top:
            raise Exception(
                'getting features from a classification model with include_top=True is currently not supported'
            )
        # TODO add caching for paths like urls
        images_array = np.array(
            [image_array_from_path(fpath, target_size=self.target_size)
             for fpath in image_paths])
        return self.get_features(images_array)

    def get_features_from_url(self, image_url):
        ''' attempt to download the image at the given url, then return the imagenet features if successful, and None if not '''
        if self.include_top:
            raise Exception(
                'getting features from a classification model with include_top=True is currently not supported'
            )

        if image_url not in self.cache:
            image_array = image_array_from_url(image_url,
                                               target_size=self.target_size)
            if image_array is None:
                self.failed_urls.add(image_url)
                return
            else:
                if image_array.ndim == 3:
                    image_array = image_array[None, :, :, :]
                self.cache[image_url] = self.get_features(image_array)

        return self.cache.get(image_url)

    def get_features_from_url_batch(self, image_urls, ignore_failed=True):
        ''' takes a list of image urls and returns the features resulting from applying the imagenet model to
        successfully downloaded images along with the urls that were successful. Cached values are used when available
        '''
        if self.include_top:
            raise Exception(
                'getting features from a classification model with include_top=True is currently not supported'
            )
        # split urls into new ones and ones that have cached results
        # NOTE: the cache partition below is currently disabled, so every url
        # is treated as new and the cache is only written, not read, here
        new_urls = image_urls
        cached_urls = []
        # new_urls, cached_urls = partition(lambda x: x in self.cache, image_urls, as_list=True)
        logging.info(
            f'getting image arrays from {len(image_urls)} urls; '
            f'{len(new_urls)} new urls and {len(cached_urls)} cached urls')
        if cached_urls:
            logging.debug(
                f'loading features for {len(cached_urls)} images from cache')
            if len(cached_urls) == 1:
                cached_image_features = self.cache[cached_urls[0]]
            else:
                cached_image_features = np.array(
                    [self.cache[url] for url in cached_urls])
            assert cached_image_features.ndim == 2

        # remove new urls known to fail
        if new_urls and ignore_failed:
            logging.debug(
                f'num new urls before dropping fails: {len(new_urls)}')
            new_urls = list(
                filter(lambda x: x not in self.failed_urls, new_urls))

        if new_urls:
            logging.debug(
                f'computing features for {len(new_urls)} images from urls')
            # attempt to download images and convert to constant-size arrays  # TODO what to do with failed urls, try again, cache failure?
            new_image_arrays = (image_array_from_url(
                url, target_size=self.target_size) for url in new_urls)

            # filter out unsuccessful image urls which output None
            failed_images, downloaded_images = partition(
                lambda x: x[1] is not None,
                zip(new_urls, new_image_arrays),
                as_list=True)

            logging.debug(f'found {len(failed_images)} failed url images')
            logging.info(
                f'successfully downloaded {len(downloaded_images)} url images')
            # add failed urls to list
            logging.debug('saving failed urls to failed set')
            self.failed_urls.update(pair[0] for pair in failed_images)
            # downloaded_images = [(url, img) for url, img in zip(new_urls, new_image_arrays) if img is not None]

            if downloaded_images:
                # unzip any successful url, img pairs and convert data types
                new_urls, new_image_arrays = zip(*downloaded_images)
                new_urls = list(new_urls)
                new_image_arrays = np.array(new_image_arrays)

                logging.debug(
                    f'getting features from image arrays with shape {new_image_arrays.shape}'
                )
                new_image_features = self.get_features(new_image_arrays)
                assert new_image_features.ndim == 2
                logging.debug(
                    f'got features array with shape {new_image_features.shape}'
                )
                # add new image features to cache
                logging.info('saving features to cache')

                self.cache.update(zip(new_urls, new_image_features))

        if cached_urls and new_urls and downloaded_images:
            logging.debug('cached and new')
            # combine results
            image_features = np.vstack(
                (cached_image_features, new_image_features))
            image_urls = cached_urls + new_urls
        elif cached_urls:
            logging.debug('cached')
            image_features = cached_image_features
            image_urls = cached_urls
        elif new_urls and downloaded_images:
            logging.debug('new')
            image_features = new_image_features
            image_urls = new_urls
        else:
            logging.debug('no new or cached urls')
            return np.array([[]]), []

        return image_features, image_urls

    def get_features(self, images_array):
        ''' takes a batch of images as a 4-d array and returns the (flattened) imagenet features for those images as a 2-d array '''
        if self.include_top:
            raise Exception(
                'getting features from a classification model with include_top=True is currently not supported'
            )

        if images_array.ndim != 4:
            raise Exception(
                'invalid input shape for images_array, expects a 4d array')

        # preprocess and compute image features
        logging.debug(f'preprocessing {images_array.shape[0]} images')
        images_array = self.preprocess(images_array)
        logging.debug(f'computing image features')
        image_features = self.model.predict(images_array)

        # if n_channels is specified, only keep that number of channels
        if self.n_channels:
            logging.debug(f'truncating to first {self.n_channels} channels')
            image_features = image_features.T[:self.n_channels].T

        # reshape output array by flattening each image into a vector of features
        shape = image_features.shape
        return image_features.reshape(shape[0], np.prod(shape[1:]))

    def predict(self, images_array):
        ''' alias for get_features to more closely match scikit-learn interface '''
        return self.get_features(images_array)

    def finetune(self,
                 image_paths,
                 labels,
                 pooling='avg',
                 nclasses=2,
                 batch_size=32,
                 top_layer_epochs=1,
                 frozen_layer_count=249,
                 all_layer_epochs=5,
                 class_weight=None,
                 optimizer='rmsprop'):
        ''' Finetunes the Imagenet model iteratively on a smaller set of images with (potentially) a smaller set of classes.
            First finetunes last layer then freezes bottom N layers and retrains the rest
        '''

        # preprocess images
        images_array = np.array([
            image_array_from_path(fpath, target_size=self.target_size)
            for fpath in image_paths
        ])
        logging.debug(f'preprocessing {images_array.shape[0]} images')
        if images_array.ndim != 4:
            raise Exception(
                'invalid input shape for images_array, expects a 4d array')
        images_array = self.preprocess(images_array)

        # transform labels to categorical variable
        labels = to_categorical(labels)

        # create new model for finetuned classification
        out = self.model.output
        if self.pooling is None:
            out = GlobalAveragePooling2D()(
                out) if pooling == 'avg' else GlobalMaxPooling2D()(out)
        dense = Dense(1024, activation='relu')(out)
        preds = Dense(nclasses, activation='softmax')(dense)
        self.finetune_model = Model(inputs=self.model.input, outputs=preds)

        # freeze all layers of the pre-trained base model so that only the
        # newly added dense layers are trained in this first pass
        for layer in self.model.layers:
            layer.trainable = False
        self.finetune_model.compile(optimizer=optimizer,
                                    loss='categorical_crossentropy')
        self.finetune_model.fit(images_array,
                                np.array(labels),
                                batch_size=batch_size,
                                epochs=top_layer_epochs,
                                class_weight=class_weight)

        # freeze bottom N convolutional layers, retrain top M-N layers (M = total number of layers)
        for layer in self.finetune_model.layers[:frozen_layer_count]:
            layer.trainable = False
        for layer in self.finetune_model.layers[frozen_layer_count:]:
            layer.trainable = True

        # use SGD and low learning rate to prevent catastrophic forgetting in these blocks
        self.finetune_model.compile(optimizer=SGD(lr=0.0001, momentum=0.9),
                                    loss='categorical_crossentropy')
        self.finetune_model.fit(images_array,
                                np.array(labels),
                                batch_size=batch_size,
                                epochs=all_layer_epochs,
                                class_weight=class_weight)

    def finetuned_predict(self, image_paths):
        ''' Uses the finetuned model to predict on images loaded from the given
        file paths. Returns an array of softmax prediction probabilities.
        '''

        # load and preprocess images
        images_array = np.array([
            image_array_from_path(fpath, target_size=self.target_size)
            for fpath in image_paths
        ])
        logging.debug(f'preprocessing {images_array.shape[0]} images')
        if images_array.ndim != 4:
            raise Exception(
                'invalid input shape for images_array, expects a 4d array')
        images_array = self.preprocess(images_array)

        return self.finetune_model.predict(images_array)
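A brief usage sketch for the featurizer above, assuming the module's helper functions (image_array_from_url, image_array_from_path) and the Keras/NumPy dependencies are importable; the URLs and cache directory are placeholders.

model = ImagenetModel(pooling='avg', cache_size=5000, cache_dir='/tmp')

urls = ['https://example.com/cat.jpg', 'https://example.com/dog.jpg']  # placeholders
features, ok_urls = model.get_features_from_url_batch(urls)  # downloads, featurizes, caches
print(features.shape, len(ok_urls))

# a later single-url call for the same image is answered from the LRU cache
if ok_urls:
    cached_features = model.get_features_from_url(ok_urls[0])

model.save_cache()  # persist the LRU cache and the failed-url set to disk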
Example #18
0
class ApiClientV1(object):
    """
    Thin client binding for API v1.
    """

    def __init__(self, base_uri):
        """
        Construct a new :class:`ClientV1`.

        Args:
            base_uri (str): Base URI of the MLStorage server, e.g.,
                "http://example.com".
        """
        base_uri = base_uri.rstrip('/')
        self._base_uri = base_uri
        self._storage_dir_cache = LRUCache(128)

    def _update_storage_dir_cache(self, doc):
        self._storage_dir_cache[doc['id']] = doc['storage_dir']

    @property
    def base_uri(self):
        """Get the base URI of the MLStorage server."""
        return self._base_uri

    def do_request(self, method, endpoint, **kwargs):
        """
        Do `method` request against given `endpoint`.

        Args:
            method (str): The HTTP request method.
            endpoint (str): The endpoint of the API, should start with a
                slash "/".  For example, "/_query".
            \\**kwargs: Arguments to be passed to :func:`requests.request`.

        Returns:
            The response object.
        """
        uri = self.base_uri + '/v1' + endpoint
        resp = requests.request(method, uri, **kwargs)
        if resp.status_code != 200:
            raise RuntimeError('HTTP error {}: {}'.format(
                resp.status_code, resp.text))
        return resp

    def query(self, filter=None, skip=0, limit=10):
        ret = self.do_request(
            'POST', '/_query?skip={}&limit={}'.format(skip, limit),
            json=filter or {}).json()
        for doc in ret:
            self._update_storage_dir_cache(doc)
        return ret

    def get(self, id):
        id = validate_experiment_id(id)
        ret = self.do_request('GET', '/_get/{}'.format(id)).json()
        self._update_storage_dir_cache(ret)
        return ret

    def heartbeat(self, id):
        id = validate_experiment_id(id)
        return self.do_request(
            'POST', '/_heartbeat/{}'.format(id), data=b'').json()

    def create(self, name, doc_fields=None):
        doc_fields = dict(doc_fields or ())
        doc_fields['name'] = name
        doc_fields = validate_experiment_doc(doc_fields)
        ret = self.do_request('POST', '/_create', json=doc_fields).json()
        self._update_storage_dir_cache(ret)
        return ret

    def update(self, id, doc_fields):
        id = validate_experiment_id(id)
        doc_fields = validate_experiment_doc(dict(doc_fields))
        ret = self.do_request(
            'POST', '/_update/{}'.format(id), json=doc_fields).json()
        self._update_storage_dir_cache(ret)
        return ret

    def set_finished(self, id, status, doc_fields):
        id = validate_experiment_id(id)
        doc_fields = dict(doc_fields or ())
        doc_fields['status'] = status
        doc_fields = validate_experiment_doc(doc_fields)
        ret = self.do_request(
            'POST', '/_set_finished/{}'.format(id), json=doc_fields).json()
        self._update_storage_dir_cache(ret)
        return ret

    def delete(self, id):
        id = validate_experiment_id(id)
        ret = self.do_request(
            'POST', '/_delete/{}'.format(id), data=b'').json()
        for i in ret:
            self._storage_dir_cache.pop(i, None)
        return ret

    def get_storage_dir(self, id):
        id = str(validate_experiment_id(id))
        storage_dir = self._storage_dir_cache.get(id, None)
        if storage_dir is None:
            doc = self.get(id)
            storage_dir = doc['storage_dir']
        return storage_dir

    def getfile(self, id, path):
        id = str(validate_experiment_id(id))
        path = validate_relpath(path)
        return self.do_request(
            'GET', '/_getfile/{}/{}'.format(id, path)).content
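A short usage sketch for the client above; the server address, experiment name and document fields are placeholders and assume a running MLStorage server that returns 'id' and 'storage_dir' fields.

client = ApiClientV1('http://localhost:8080')

doc = client.create('my-experiment', {'tags': ['demo']})
exp_id = doc['id']

# repeated storage-dir lookups for the same id are answered by the 128-entry LRU cache
storage_dir = client.get_storage_dir(exp_id)

client.set_finished(exp_id, 'COMPLETED', {'exit_code': 0})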
Example #19
0
class MLStorageClient(object):
    """
    Client binding for MLStorage Server API v1.
    """
    def __init__(self, uri: str):
        """
        Construct a new :class:`ClientV1`.

        Args:
            uri: Base URI of the MLStorage server, e.g., "http://example.com".
        """
        uri = uri.rstrip('/')
        self._uri = uri
        self._storage_dir_cache = LRUCache(128)

    def _update_storage_dir_cache(self, doc):
        self._storage_dir_cache[doc['_id']] = doc['storage_dir']

    @property
    def uri(self) -> str:
        """Get the base URI of the MLStorage server."""
        return self._uri

    def do_request(self,
                   method: str,
                   endpoint: str,
                   decode_json: bool = True,
                   **kwargs) -> Union[requests.Response, Any]:
        """
        Send request of HTTP `method` to given `endpoint`.

        Args:
            method: The HTTP request method.
            endpoint: The endpoint of the API, should start with a slash "/".
                For example, "/_query".
            decode_json: Whether or not to decode the response body as JSON?
            \\**kwargs: Arguments to be passed to :func:`requests.request`.

        Returns:
            The response object if ``decode_json = False``, or the decoded
            JSON object.
        """
        uri = f'{self.uri}/v1{endpoint}'
        if 'json' in kwargs:
            json_obj = kwargs.pop('json')
            json_str = json_dumps(json_obj)
            kwargs['data'] = json_str
            kwargs.setdefault('headers', {})
            kwargs['headers']['Content-Type'] = 'application/json'

        resp = requests.request(method, uri, **kwargs)
        resp.raise_for_status()

        if decode_json:
            content_type = resp.headers.get('content-type') or ''
            content_type = content_type.split(';', 1)[0]
            if content_type != 'application/json':
                raise IOError(f'The response from {uri} is not JSON: '
                              f'HTTP code is {resp.status_code}')
            resp = json_loads(resp.content)

        return resp

    def query(self,
              filter: Optional[FilterType] = None,
              sort: Optional[str] = None,
              skip: int = 0,
              limit: Optional[int] = None) -> List[DocumentType]:
        """
        Query experiment documents according to the `filter`.

        Args:
            filter: The filter dict.
            sort: Sort by which field, a string matching the pattern
                ``[+/-]<field>``.  "+" means ASC order, while "-" means
                DESC order.  For example, "start_time", "+start_time" and
                "-stop_time".
            skip: The number of records to skip.
            limit: The maximum number of records to retrieve.

        Returns:
            The documents of the matched experiments.
        """
        uri = f'/_query?skip={skip}'
        if sort is not None:
            uri += f'&sort={urlquote(sort)}'
        if limit is not None:
            uri += f'&limit={limit}'
        ret = self.do_request('POST', uri, json=filter or {})
        for doc in ret:
            self._update_storage_dir_cache(doc)
        return ret

    def get(self, id: IdType) -> DocumentType:
        """
        Get the document of an experiment by its `id`.

        Args:
            id: The id of the experiment.

        Returns:
            The document of the retrieved experiment.
        """
        ret = self.do_request('GET', f'/_get/{id}')
        self._update_storage_dir_cache(ret)
        return ret

    def heartbeat(self, id: IdType) -> None:
        """
        Send heartbeat packet for the experiment `id`.

        Args:
            id: The id of the experiment.
        """
        self.do_request('POST', f'/_heartbeat/{id}', data=b'')

    def create(self, doc_fields: DocumentType) -> DocumentType:
        """
        Create an experiment.

        Args:
            doc_fields: The document fields of the new experiment.

        Returns:
            The document of the created experiment.
        """
        doc_fields = dict(doc_fields)
        ret = self.do_request('POST', '/_create', json=doc_fields)
        self._update_storage_dir_cache(ret)
        return ret

    def update(self, id: IdType, doc_fields: DocumentType) -> DocumentType:
        """
        Update the document of an experiment.

        Args:
            id: ID of the experiment.
            doc_fields: The fields to be updated.

        Returns:
            The document of the updated experiment.
        """
        ret = self.do_request('POST', f'/_update/{id}', json=doc_fields)
        self._update_storage_dir_cache(ret)
        return ret

    def add_tags(self, id: IdType, tags: Iterable[str]) -> DocumentType:
        """
        Add tags to an experiment document.

        Args:
            id: ID of the experiment.
            tags: New tags to be added.

        Returns:
            The document of the updated experiment.
        """
        old_doc = self.get(id)
        new_tags = old_doc.get('tags', [])
        for tag in tags:
            if tag not in new_tags:
                new_tags.append(tag)
        return self.update(id, {'tags': new_tags})

    def delete(self, id: IdType) -> List[IdType]:
        """
        Delete an experiment.

        Args:
            id: ID of the experiment.

        Returns:
            List of deleted experiment IDs.
        """
        ret = self.do_request('POST', f'/_delete/{id}', data=b'')
        for i in ret:
            self._storage_dir_cache.pop(i, None)
        return ret

    def set_finished(
            self,
            id: IdType,
            status: str,
            doc_fields: Optional[DocumentType] = None) -> DocumentType:
        """
        Set the status of an experiment.

        Args:
            id: ID of the experiment.
            status: The new status, one of {"RUNNING", "COMPLETED", "FAILED"}.
            doc_fields: Optional new document fields to be set.

        Returns:
            The document of the updated experiment.
        """
        doc_fields = dict(doc_fields or ())
        doc_fields['status'] = status
        ret = self.do_request('POST', f'/_set_finished/{id}', json=doc_fields)
        self._update_storage_dir_cache(ret)
        return ret

    def get_storage_dir(self, id: IdType) -> str:
        """
        Get the storage directory of an experiment.

        Args:
            id: ID of the experiment.

        Returns:
            The storage directory of the experiment.
        """
        id = str(id)
        storage_dir = self._storage_dir_cache.get(id, None)
        if storage_dir is None:
            doc = self.get(id)
            storage_dir = doc['storage_dir']
        return storage_dir

    def get_file(self, id: IdType, path: str) -> bytes:
        """
        Get the content of a file in the storage directory of an experiment.

        Args:
            id: ID of the experiment.
            path: Relative path of the file.

        Returns:
            The file content.
        """
        id = str(id)
        path = normalize_relpath(path)
        return self.do_request('GET',
                               f'/_getfile/{id}/{path}',
                               decode_json=False).content
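A short usage sketch for the newer client above; the server URI and the filter are placeholders and assume a running MLStorage server whose documents carry '_id' and 'storage_dir' fields.

client = MLStorageClient('http://localhost:8080')

docs = client.query(filter={'status': 'COMPLETED'}, sort='-stop_time', limit=5)
for doc in docs:
    # storage_dir values for these documents are now held in the LRU cache
    print(doc['_id'], client.get_storage_dir(doc['_id']))

if docs:
    client.add_tags(docs[0]['_id'], ['baseline'])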
Example #20
0
class ShareUsageTracker(Tracker):
    """!
    The ShareUsageTracker is used by the ShareUsageElement to determine
    whether to put evidence into a bundle to be sent to the 51Degrees
    Share Usage service.
    """
    def __init__(self, size=100, interval=1000):
        """!
        Constructor for ShareUsageTracker
        @type size: int
        @param size: size of the share usage lru cache
        @type interval: int
        @param interval: how often to put evidence into the cache

        """

        self.interval = interval

        self.cache = LRUCache(maxsize=size)

    def match(self, key, value):
        """!
        Decide whether the evidence stored under `key` should be shared
        again: if the stored timestamp is older than the configured
        interval, refresh the cache entry and report a match.

        @type key : string
        @param key : cache key to run through the tracker
        @type value : float
        @param value : timestamp previously stored for the key
        @rtype bool
        @return result of tracking
        """

        difference = time.time() - value

        if difference > self.interval:
            self.set_cache_value(key, value)
            return True
        else:
            return False

    def get_cache_value(self, cachekey):
        """!
        Get data from the cache
        @type cachekey : string
        @param cachekey : The cache key to lookup
        @rtype mixed
        @return None, or the stored data
        """

        return self.cache.get(cachekey)

    def set_cache_value(self, cachekey, value=None):
        """!
        Place data in the cache
        @type cachekey : string
        @param cachekey : The cache key to store data under
        @type value : mixed
        @param value : Not used here, as the current time is stored instead
        """

        self.cache[cachekey] = time.time()
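The tracker above only stores timestamps; the actual throttling decision comes from combining get_cache_value, match and set_cache_value. A minimal sketch of how a caller (for example, a track method on the Tracker base class, which is assumed here and not shown) might wire them together, assuming the class above and its imports are available:

def should_share(tracker, evidence_key):
    # assumed driver logic; the real Tracker base class may differ
    stored = tracker.get_cache_value(evidence_key)
    if stored is None:
        # first sighting of this evidence: remember it and share it
        tracker.set_cache_value(evidence_key)
        return True
    # share again only once the stored timestamp is older than the interval
    return tracker.match(evidence_key, stored)

tracker = ShareUsageTracker(size=100, interval=1000)
print(should_share(tracker, 'header.user-agent:Mozilla/5.0'))  # True on first sighting
print(should_share(tracker, 'header.user-agent:Mozilla/5.0'))  # False until the interval elapses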
Example #21
0
class MetricCassandraRepository(
        abstract_repository.AbstractCassandraRepository):
    def __init__(self):
        super(MetricCassandraRepository, self).__init__()

        self._lock = threading.RLock()

        LOG.debug("prepare cql statements...")

        self._measurement_insert_stmt = self._session.prepare(
            MEASUREMENT_INSERT_CQL)
        self._measurement_insert_stmt.is_idempotent = True

        self._measurement_update_stmt = self._session.prepare(
            MEASUREMENT_UPDATE_CQL)
        self._measurement_update_stmt.is_idempotent = True

        self._metric_insert_stmt = self._session.prepare(METRICS_INSERT_CQL)
        self._metric_insert_stmt.is_idempotent = True

        self._metric_update_stmt = self._session.prepare(METRICS_UPDATE_CQL)
        self._metric_update_stmt.is_idempotent = True

        self._dimension_stmt = self._session.prepare(DIMENSION_INSERT_CQL)
        self._dimension_stmt.is_idempotent = True

        self._dimension_metric_stmt = self._session.prepare(
            DIMENSION_METRIC_INSERT_CQL)
        self._dimension_metric_stmt.is_idempotent = True

        self._metric_dimension_stmt = self._session.prepare(
            METRIC_DIMENSION_INSERT_CQL)
        self._metric_dimension_stmt.is_idempotent = True

        self._retrieve_metric_dimension_stmt = self._session.prepare(
            RETRIEVE_METRIC_DIMENSION_CQL)

        self._metric_batch = MetricBatch(self._cluster.metadata,
                                         self._cluster.load_balancing_policy,
                                         self._max_batches)

        self._metric_id_cache = LRUCache(self._cache_size)
        self._dimension_cache = LRUCache(self._cache_size)
        self._metric_dimension_cache = LRUCache(self._cache_size)

        self._load_dimension_cache()
        self._load_metric_dimension_cache()

    def process_message(self, message):
        (dimensions, metric_name, region, tenant_id, time_stamp, value,
         value_meta) = parse_measurement_message(message)

        with self._lock:
            dim_names = []
            dim_list = []
            for name in sorted(dimensions.keys()):
                dim_list.append('%s\t%s' % (name, dimensions[name]))
                dim_names.append(name)

            hash_string = '%s\0%s\0%s\0%s' % (region, tenant_id, metric_name,
                                              '\0'.join(dim_list))
            metric_id = hashlib.sha1(hash_string.encode('utf8')).hexdigest()

            metric = Metric(id=metric_id,
                            region=region,
                            tenant_id=tenant_id,
                            name=metric_name,
                            dimension_list=dim_list,
                            dimension_names=dim_names,
                            time_stamp=time_stamp,
                            value=value,
                            value_meta=json.dumps(value_meta,
                                                  ensure_ascii=False))

            id_bytes = bytearray.fromhex(metric.id)
            if self._metric_id_cache.get(metric.id, None):
                measurement_bound_stmt = self._measurement_update_stmt.bind(
                    (self._retention, metric.value, metric.value_meta,
                     id_bytes, metric.time_stamp))
                self._metric_batch.add_measurement_query(
                    measurement_bound_stmt)

                metric_update_bound_stmt = self._metric_update_stmt.bind(
                    (self._retention, id_bytes, metric.time_stamp,
                     metric.region, metric.tenant_id, metric.name,
                     metric.dimension_list, metric.dimension_names))
                self._metric_batch.add_metric_query(metric_update_bound_stmt)

                return metric

            self._metric_id_cache[metric.id] = metric.id

            metric_insert_bound_stmt = self._metric_insert_stmt.bind(
                (self._retention, id_bytes, metric.time_stamp,
                 metric.time_stamp, metric.region, metric.tenant_id,
                 metric.name, metric.dimension_list, metric.dimension_names))
            self._metric_batch.add_metric_query(metric_insert_bound_stmt)

            for dim in metric.dimension_list:
                (name, value) = dim.split('\t')
                dim_key = self._get_dimnesion_key(metric.region,
                                                  metric.tenant_id, name,
                                                  value)
                if not self._dimension_cache.get(dim_key, None):
                    dimension_bound_stmt = self._dimension_stmt.bind(
                        (metric.region, metric.tenant_id, name, value))
                    self._metric_batch.add_dimension_query(
                        dimension_bound_stmt)
                    self._dimension_cache[dim_key] = dim_key

                metric_dim_key = self._get_metric_dimnesion_key(
                    metric.region, metric.tenant_id, metric.name, name, value)
                if not self._metric_dimension_cache.get(metric_dim_key, None):
                    dimension_metric_bound_stmt = self._dimension_metric_stmt.bind(
                        (metric.region, metric.tenant_id, name, value,
                         metric.name))
                    self._metric_batch.add_dimension_metric_query(
                        dimension_metric_bound_stmt)

                    metric_dimension_bound_stmt = self._metric_dimension_stmt.bind(
                        (metric.region, metric.tenant_id, metric.name, name,
                         value))
                    self._metric_batch.add_metric_dimension_query(
                        metric_dimension_bound_stmt)

                    self._metric_dimension_cache[
                        metric_dim_key] = metric_dim_key

            measurement_insert_bound_stmt = self._measurement_insert_stmt.bind(
                (self._retention, metric.value, metric.value_meta,
                 metric.region, metric.tenant_id, metric.name,
                 metric.dimension_list, id_bytes, metric.time_stamp))
            self._metric_batch.add_measurement_query(
                measurement_insert_bound_stmt)

            return metric

    def write_batch(self, metrics):

        with self._lock:
            batch_list = self._metric_batch.get_all_batches()

            results = execute_concurrent(self._session,
                                         batch_list,
                                         raise_on_first_error=True)

            self._handle_results(results)

            self._metric_batch.clear()

            LOG.info("flushed %s metrics", len(metrics))

    @staticmethod
    def _handle_results(results):
        for (success, result) in results:
            if not success:
                raise result

    def _load_dimension_cache(self):

        rows = self._session.execute(RETRIEVE_DIMENSION_CQL)

        if not rows:
            return

        for row in rows:
            key = self._get_dimnesion_key(row.region, row.tenant_id, row.name,
                                          row.value)
            self._dimension_cache[key] = key

        LOG.info(
            "loaded %s dimension entries from database into cache." %
            self._dimension_cache.currsize)

    @staticmethod
    def _get_dimnesion_key(region, tenant_id, name, value):
        return '%s\0%s\0%s\0%s' % (region, tenant_id, name, value)

    def _load_metric_dimension_cache(self):
        qm = token_range_query_manager.TokenRangeQueryManager(
            RETRIEVE_METRIC_DIMENSION_CQL,
            self._process_metric_dimension_query)

        token_ring = self._cluster.metadata.token_map.ring

        qm.query(token_ring)

    def _process_metric_dimension_query(self, rows):

        cnt = 0
        for row in rows:
            key = self._get_metric_dimnesion_key(row.region, row.tenant_id,
                                                 row.metric_name,
                                                 row.dimension_name,
                                                 row.dimension_value)
            self._metric_dimension_cache[key] = key
            cnt += 1

        LOG.info(
            "loaded %s metric dimension entries from database into cache." %
            cnt)
        LOG.info("total loaded %s metric dimension entries in cache." %
                 self._metric_dimension_cache.currsize)

    @staticmethod
    def _get_metric_dimnesion_key(region, tenant_id, metric_name,
                                  dimension_name, dimension_value):

        return '%s\0%s\0%s\0%s\0%s' % (region, tenant_id, metric_name,
                                       dimension_name, dimension_value)
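The repository above uses its three LRU caches purely as bounded "already written" sets, storing each key as its own value so that idempotent inserts are only issued once per key while it stays in the cache. A standalone sketch of that deduplication pattern; the insert callable is a placeholder:

from cachetools import LRUCache

seen_dimensions = LRUCache(maxsize=100000)

def insert_dimension_once(key, do_insert):
    # issue the (idempotent) insert only the first time the key is seen,
    # within the bounds of the LRU cache
    if not seen_dimensions.get(key, None):
        do_insert(key)
        seen_dimensions[key] = key  # the key doubles as its own value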
Example #22
0
class CompoundRequestUnit(FunctionUnit):
    @staticmethod
    def makeIt(ds_h, descr, before=None, after=None):
        unit_h = CompoundRequestUnit(ds_h, descr)
        ds_h.getEvalSpace()._insertUnit(unit_h, before=before, after=after)

    def __init__(self, ds_h, descr):
        FunctionUnit.__init__(self,
                              ds_h.getEvalSpace(),
                              descr,
                              sub_kind="comp-request",
                              parameters=["request", "approx", "state"])
        self.mZygSupport = ds_h.getZygositySupport()
        self.mOpCache = LRUCache(
            AnfisaConfig.configOption("comp-hets.cache.size"))

    def iterComplexCriteria(self, context, variants=None):
        if context is None:
            return
        yield "True", context["crit"]

    def makeInfoStat(self, eval_h, point_no):
        ret_handle = self.prepareStat()
        ret_handle["approx-modes"] = self.mZygSupport.getApproxInfo()
        ret_handle["labels"] = eval_h.getLabelPoints(point_no)
        ret_handle["family"] = self.mZygSupport.getNames()
        ret_handle["affected"] = self.mZygSupport.getAffectedGroup()
        return ret_handle

    def _locateContext(self, parameters, eval_h, point_no=None):
        if parameters.get("state"):
            actual_condition = eval_h.getLabelCondition(
                parameters["state"], point_no)
            if actual_condition is None:
                return None, ("State label %s not defined" %
                              parameters["state"])
        else:
            actual_condition = eval_h.getActualCondition(point_no)
        approx_mode = self.mZygSupport.normalizeApprox(
            parameters.get("approx"))
        if approx_mode is False:
            return None, "Improper approx mode %s" % parameters["approx"]

        c_rq = parameters.get("request")
        if self.mZygSupport.emptyRequest(c_rq):
            return None, "Empty request"

        build_id = md5(
            bytes(json.dumps(c_rq, sort_keys=True) + approx_mode + '|' +
                  actual_condition.hashCode(),
                  encoding="utf-8")).hexdigest()
        with self.getEvalSpace().getDS():
            context = self.mOpCache.get(build_id)
        if context is None:
            context = {
                "approx":
                approx_mode,
                "crit":
                self.mZygSupport.makeCompoundRequest(approx_mode,
                                                     actual_condition, c_rq,
                                                     self.getName())
            }
            with self.getEvalSpace().getDS():
                self.mOpCache[build_id] = context
        if context["crit"] is None:
            return context, "Too heavy condition"
        return context, None

    def locateContext(self, cond_data, eval_h):
        point_no, _ = eval_h.locateCondData(cond_data)
        context, err_msg = self._locateContext(cond_data[4], eval_h, point_no)
        if err_msg:
            eval_h.operationError(cond_data, err_msg)
        return context

    def validateArgs(self, parameters):
        if not parameters.get("request"):
            return "Argument request is required"
        return self.mZygSupport.validateRequest(parameters["request"])

    def makeParamStat(self, condition, parameters, eval_h, point_no):
        context, err_msg = self._locateContext(parameters, eval_h, point_no)
        ret_handle = self.prepareStat()
        if err_msg:
            ret_handle["err"] = err_msg
        else:
            self.collectComplexStat(
                ret_handle, condition, context,
                self.mZygSupport.getGeneUnit(
                    context.get("approx")).isDetailed())
        ret_handle.update(parameters)
        return ret_handle
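The build_id above memoizes heavy compound-request evaluations: the request is serialized with sort_keys=True so that logically equal requests produce the same md5, and the resulting context is kept in a small LRU cache. A minimal sketch of that memoization pattern under the same assumptions (compute is a hypothetical callable standing in for makeCompoundRequest):

import json
from hashlib import md5

from cachetools import LRUCache

op_cache = LRUCache(maxsize=128)

def cached_request(request, approx_mode, condition_hash, compute):
    # Canonical JSON plus the evaluation context uniquely identifies the work
    build_id = md5(
        (json.dumps(request, sort_keys=True) + approx_mode + '|' +
         condition_hash).encode('utf-8')).hexdigest()
    context = op_cache.get(build_id)
    if context is None:
        context = {"approx": approx_mode,
                   "crit": compute(request, approx_mode)}
        op_cache[build_id] = context
    return context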
예제 #23
0
class Miner:
    def __init__(
        self,
        consensus_type: ConsensusType,
        create_block_async_func: Callable[..., Awaitable[Optional[Block]]],
        add_block_async_func: Callable[[Block], Awaitable[None]],
        get_mining_param_func: Callable[[], Dict[str, Any]],
        get_header_tip_func: Callable[[], Header],
        remote: bool = False,
        root_signer_private_key: Optional[KeyAPI.PrivateKey] = None,
    ):
        """Mining will happen on a subprocess managed by this class

        create_block_async_func: takes a coinbase address, returns a block (either RootBlock or MinorBlock)
        add_block_async_func: takes a block, adds it to the chain
        get_mining_param_func: takes no argument, returns the mining-specific params
        """
        self.consensus_type = consensus_type

        self.create_block_async_func = create_block_async_func
        self.add_block_async_func = add_block_async_func
        self.get_mining_param_func = get_mining_param_func
        self.get_header_tip_func = get_header_tip_func
        self.enabled = False
        self.process = None

        self.input_q = AioQueue()  # [(MiningWork, param dict)]
        self.output_q = AioQueue()  # [MiningResult]

        # header hash -> block under work
        # max size (tx max 258 bytes, gas limit 12m) ~= ((12m / 21000) * 258) * 128 = 18mb
        self.work_map = LRUCache(maxsize=128)

        if not remote and consensus_type != ConsensusType.POW_SIMULATE:
            Logger.warning("Mining locally, could be slow and error-prone")
        # remote miner specific attributes
        self.remote = remote
        # coinbase address -> header hash
        # key can be None, meaning default coinbase address from local config
        self.current_works = LRUCache(128)
        self.root_signer_private_key = root_signer_private_key

    def start(self):
        self.enabled = True
        self._mine_new_block_async()

    def is_enabled(self):
        return self.enabled

    def disable(self):
        """Stop the mining process if there is one"""
        if self.enabled and self.process:
            # end the mining process
            self.input_q.put((None, {}))
        self.enabled = False

    def _mine_new_block_async(self):
        async def handle_mined_block():
            while True:
                res = await self.output_q.coro_get()  # type: MiningResult
                if not res:
                    return  # empty result means ending
                # start mining before processing and propagating mined block
                self._mine_new_block_async()
                block = self.work_map[res.header_hash]
                block.header.nonce = res.nonce
                block.header.mixhash = res.mixhash
                del self.work_map[res.header_hash]
                self._track(block)
                try:
                    # FIXME: Root block should include latest minor block headers while it's being mined
                    # This is a hack to get the latest minor block included since testnet does not check difficulty
                    if self.consensus_type == ConsensusType.POW_SIMULATE:
                        block = await self.create_block_async_func(
                            Address.create_empty_account())
                        block.header.nonce = random.randint(0, 2**32 - 1)
                        self._track(block)
                        self._log_status(block)
                    await self.add_block_async_func(block)
                except Exception:
                    Logger.error_exception()

        async def mine_new_block():
            """Get a new block and start mining.
            If a mining process has already been started, update the process to mine the new block.
            """
            block = await self.create_block_async_func(
                Address.create_empty_account())
            if not block:
                self.input_q.put((None, {}))
                return
            mining_params = self.get_mining_param_func()
            mining_params["consensus_type"] = self.consensus_type
            # handle mining simulation's timing
            if "target_block_time" in mining_params:
                target_block_time = mining_params["target_block_time"]
                mining_params["target_time"] = (
                    block.header.create_time +
                    self._get_block_time(block, target_block_time))
            work = MiningWork(
                block.header.get_hash_for_mining(),
                block.header.height,
                block.header.difficulty,
            )
            self.work_map[work.hash] = block
            if self.process:
                self.input_q.put((work, mining_params))
                return

            self.process = AioProcess(
                target=self.mine_loop,
                args=(work, mining_params, self.input_q, self.output_q),
            )
            self.process.start()
            await handle_mined_block()

        # no-op if enabled or mining remotely
        if not self.enabled or self.remote:
            return None
        return asyncio.ensure_future(mine_new_block())

    async def get_work(self,
                       coinbase_addr: Address,
                       now=None) -> (MiningWork, Block):
        if not self.remote:
            raise ValueError("Should only be used for remote miner")

        if now is None:  # clock open for mock
            now = time.time()

        block = None
        header_hash = self.current_works.get(coinbase_addr)
        if header_hash:
            block = self.work_map.get(header_hash)
        tip_hash = self.get_header_tip_func().get_hash()
        if (not block  # no work cache
                or block.header.hash_prev_block != tip_hash  # cache outdated
                or now - block.header.create_time > 10  # stale
            ):
            block = await self.create_block_async_func(coinbase_addr,
                                                       retry=False)
            if not block:
                raise RuntimeError("Failed to create block")
            header_hash = block.header.get_hash_for_mining()
            self.current_works[coinbase_addr] = header_hash
            self.work_map[header_hash] = block

        header = block.header
        return (
            MiningWork(header_hash, header.height, header.difficulty),
            copy.deepcopy(block),
        )

    async def submit_work(
        self,
        header_hash: bytes,
        nonce: int,
        mixhash: bytes,
        signature: Optional[bytes] = None,
    ) -> bool:
        if not self.remote:
            raise ValueError("Should only be used for remote miner")

        if header_hash not in self.work_map:
            return False
        # this copy is necessary since there might be multiple submissions concurrently
        block = copy.deepcopy(self.work_map[header_hash])
        header = block.header

        # reject if tip updated
        tip_hash = self.get_header_tip_func().get_hash()
        if header.hash_prev_block != tip_hash:
            del self.work_map[header_hash]
            return False

        header.nonce, header.mixhash = nonce, mixhash
        # sign using the root_signer_private_key
        if self.root_signer_private_key and isinstance(block, RootBlock):
            header.sign_with_private_key(self.root_signer_private_key)

        # remote sign as a guardian
        if isinstance(block, RootBlock) and signature is not None:
            header.signature = signature

        try:
            await self.add_block_async_func(block)
            # a previous submission of the same work could have removed the key
            if header_hash in self.work_map:
                del self.work_map[header_hash]
            return True
        except Exception:
            Logger.error_exception()
            return False

    @staticmethod
    def mine_loop(
        work: Optional[MiningWork],
        mining_params: Dict,
        input_q: Queue,
        output_q: Queue,
        debug=False,
    ):
        consensus_to_mining_algo = {
            ConsensusType.POW_SIMULATE: Simulate,
            ConsensusType.POW_ETHASH: Ethash,
            ConsensusType.POW_QKCHASH: Qkchash,
            ConsensusType.POW_DOUBLESHA256: DoubleSHA256,
        }
        progress = {}

        def debug_log(msg: str, prob: float):
            if debug and random.random() < prob:
                print(msg)

        try:
            # outer loop for mining forever
            while True:
                # empty work means termination
                if not work:
                    output_q.put(None)
                    return

                debug_log("outer mining loop", 0.1)
                consensus_type = mining_params["consensus_type"]
                mining_algo_gen = consensus_to_mining_algo[consensus_type]
                mining_algo = mining_algo_gen(work, **mining_params)
                # progress tracking if mining param contains shard info
                if "full_shard_id" in mining_params:
                    full_shard_id = mining_params["full_shard_id"]
                    # skip blocks with height lower or equal
                    if (full_shard_id in progress
                            and progress[full_shard_id] >= work.height):
                        # get newer work and restart mining
                        debug_log("stale work, try to get new one", 1.0)
                        work, mining_params = input_q.get(block=True)
                        continue

                rounds = mining_params.get("rounds", 100)
                start_nonce = random.randint(0, MAX_NONCE)
                # inner loop for iterating nonce
                while True:
                    if start_nonce > MAX_NONCE:
                        start_nonce = 0
                    end_nonce = min(start_nonce + rounds, MAX_NONCE + 1)
                    res = mining_algo.mine(start_nonce,
                                           end_nonce)  # [start, end)
                    debug_log("one round of mining", 0.01)
                    if res:
                        debug_log("mining success", 1.0)
                        output_q.put(res)
                        if "full_shard_id" in mining_params:
                            progress[
                                mining_params["full_shard_id"]] = work.height
                        work, mining_params = input_q.get(block=True)
                        break  # break inner loop to refresh mining params
                    # no result for mining, check if new work arrives
                    # if yes, discard current work and restart
                    try:
                        work, mining_params = input_q.get_nowait()
                        break  # break inner loop to refresh mining params
                    except QueueEmpty:
                        debug_log("empty queue", 0.1)
                        pass
                    # update param and keep mining
                    start_nonce += rounds
        except:
            from sys import exc_info

            exc_type, exc_obj, exc_trace = exc_info()
            print("exc_type", exc_type)
            print("exc_obj", exc_obj)
            print("exc_trace", exc_trace)

    @staticmethod
    def _track(block: Block):
        """Post-process block to track block propagation latency"""
        tracking_data = json.loads(block.tracking_data.decode("utf-8"))
        tracking_data["mined"] = time_ms()
        block.tracking_data = json.dumps(tracking_data).encode("utf-8")

    @staticmethod
    def _log_status(block: Block):
        is_root = isinstance(block, RootBlock)
        full_shard_id = "R" if is_root else block.header.branch.get_full_shard_id(
        )
        count = len(block.minor_block_header_list) if is_root else len(
            block.tx_list)
        elapsed = time.time() - block.header.create_time
        Logger.info_every_sec(
            "[{}] {} [{}] ({:.2f}) {}".format(
                full_shard_id,
                block.header.height,
                count,
                elapsed,
                block.header.get_hash().hex(),
            ),
            60,
        )

    @staticmethod
    def _get_block_time(block: Block, target_block_time) -> float:
        if isinstance(block, MinorBlock):
            # Adjust the target block time to compensate computation time
            gas_used_ratio = block.meta.evm_gas_used / block.header.evm_gas_limit
            target_block_time = target_block_time * (1 - gas_used_ratio * 0.4)
            Logger.debug("[{}] target block time {:.2f}".format(
                block.header.branch.get_full_shard_id(), target_block_time))
        return numpy.random.exponential(target_block_time)
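For the remote-mining path, work_map and current_works above bound memory with LRU eviction: get_work records the block under its mining hash, and submit_work only accepts hashes that are still in the map. A minimal sketch of that bookkeeping with blocks reduced to plain dicts (register_work and submit are hypothetical helpers, not Miner methods):

from cachetools import LRUCache

work_map = LRUCache(maxsize=128)      # header hash -> block under work

def register_work(header_hash, block):
    work_map[header_hash] = block     # stale entries are evicted automatically

def submit(header_hash, nonce, mixhash):
    if header_hash not in work_map:
        return False                  # unknown or evicted work is rejected
    block = work_map.pop(header_hash)
    block["nonce"], block["mixhash"] = nonce, mixhash
    return True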
예제 #24
0
class GoogleHangoutsChatBackend(ErrBot):
    def __init__(self, config):
        super().__init__(config)
        identity = config.BOT_IDENTITY
        self.at_name = config.BOT_PREFIX
        self.creds_file = identity['GOOGLE_CREDS_FILE']
        self.gce_project = identity['GOOGLE_CLOUD_ENGINE_PROJECT']
        self.gce_topic = identity['GOOGLE_CLOUD_ENGINE_PUBSUB_TOPIC']
        self.gce_subscription = identity[
            'GOOGLE_CLOUD_ENGINE_PUBSUB_SUBSCRIPTION']
        self.chat_api = GoogleHangoutsChatAPI(self.creds_file)
        self.bot_identifier = HangoutsChatUser(None, self.at_name, None, None)
        self.message_cache = LRUCache(1024)
        self.md = hangoutschat_markdown_converter()

    def _subscribe_to_pubsub_topic(self, project, topic_name,
                                   subscription_name, callback):
        subscriber = pubsub.SubscriberClient()
        subscription_name = 'projects/{}/subscriptions/{}'.format(
            project, subscription_name)
        log.info("Subscribed to {}".format(subscription_name))
        return subscriber.subscribe(subscription_name, callback=callback)

    def _handle_message(self, message):
        try:
            data = json.loads(message.data.decode('utf-8'))
        except Exception:
            log.warning('Received malformed message: {}'.format(message.data))
            message.ack()
            return

        if not data.get('message') or not data.get('message', {}).get('text'):
            message.ack()
            return
        sender_blob = data['message']['sender']
        sender = HangoutsChatUser(sender_blob['name'],
                                  sender_blob['displayName'],
                                  sender_blob['email'], sender_blob['type'])
        message_body = data['message']['text']
        message.ack()
        # message.ack() may fail silently, so we should ensure our messages are somewhat idempotent
        time = data.get('eventTime', 0)
        if time == 0:
            log.warning('Received 0 eventTime from message')

        send_name = sender_blob.get('name', '')
        thread_name = data.get('message', {}).get('thread', {}).get('name', '')
        body_length = len(message_body)
        message_id = "{}{}{}{}".format(time, send_name, thread_name,
                                       body_length)
        cached = self.message_cache.get(message_id)
        if cached is not None:
            return
        self.message_cache[message_id] = True

        context = {
            'space_id': data['space']['name'],
            'thread_id': data['message']['thread']['name']
        }
        msg = Message(body=message_body.strip(), frm=sender, extras=context)
        is_dm = data['message']['space']['type'] == 'DM'
        if is_dm:
            msg.to = self.bot_identifier
        self.callback_message(msg)

    def _split_message(
            self,
            text,
            maximum_message_length=GoogleHangoutsChatAPI.max_message_length):
        '''
        Splits a given string into multiple strings, each shorter than the maximum message length.

        Edge case: a single line that already exceeds the maximum length is not split further.
        '''
        lines = text.split('\n')
        messages = []
        current_message = ''
        for line in lines:
            if len(current_message) + len(line) + 1 > maximum_message_length:
                messages.append(current_message)
                current_message = line + '\n'
            else:
                current_message += line + '\n'

        messages.append(current_message)
        return messages

    def send_message(self, message):
        super(GoogleHangoutsChatBackend, self).send_message(message)
        log.info("Sending {}".format(message.body))
        space_id = message.extras.get('space_id', None)
        convert_markdown = message.extras.get('markdown', True)
        if not space_id:
            log.info(message.body)
            return
        thread_id = message.extras.get('thread_id', None)
        thread_key = message.extras.get('thread_key', None)
        mentions = message.extras.get('mentions', None)
        text = message.body
        if convert_markdown:
            text = self.md.convert(message.body)
        sub_messages = self._split_message(text)
        log.info("Split message into {} parts".format(len(sub_messages)))
        for message in sub_messages:
            message_payload = {'text': message}
            if mentions:
                message_payload['annotations'] = []
                for mention in mentions:
                    message_payload['annotations'].append({
                        "type":
                        "USER_MENTION",
                        "startIndex":
                        mention['start'],
                        "length":
                        mention['length'],
                        "userMention": {
                            "user": {
                                "name": mention['user_id'],
                                "displayName": mention['display_name'],
                                "type": "HUMAN"
                            },
                            "type": "ADD"
                        }
                    })
            if thread_id:
                message_payload['thread'] = {'name': thread_id}

            self.chat_api.create_message(space_id, message_payload, thread_key)

    def send_card(self, cards, space_id, thread_id=None):
        log.info("Sending card")
        message_payload = {'cards': cards}
        if thread_id:
            message_payload['thread'] = {'name': thread_id}

        self.chat_api.create_message(space_id, message_payload)

    def serve_forever(self):
        subscription = self._subscribe_to_pubsub_topic(self.gce_project,
                                                       self.gce_topic,
                                                       self.gce_subscription,
                                                       self._handle_message)
        self.connect_callback()

        try:
            import time
            while True:
                time.sleep(10)
        except KeyboardInterrupt:
            log.info("Exiting")
        finally:
            self.disconnect_callback()
            self.shutdown()

    def build_identifier(self, strrep):
        return HangoutsChatUser(None, strrep, None, None)

    def build_reply(self, msg, text=None, private=False, threaded=False):
        response = Message(body=text,
                           frm=msg.to,
                           to=msg.frm,
                           extras=msg.extras)
        return response

    def change_presence(self, status='online', message=''):
        return None

    @property
    def mode(self):
        return 'Google_Hangouts_Chat'

    def query_room(self, room):
        return HangoutsChatRoom(room, self.chat_api)

    def rooms(self):
        spaces = self.chat_api.get_spaces()
        rooms = [
            '{} ({})'.format(space['displayName'], space['name'])
            for space in list(spaces) if space['type'] == 'ROOM'
        ]

        return rooms
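Because Pub/Sub delivery is at-least-once and message.ack() can fail silently, _handle_message above derives an idempotency key from the event time, sender, thread, and body length, and checks it against an LRU cache before dispatching. A minimal sketch of that dedup check (already_handled is a hypothetical helper, not a backend method):

from cachetools import LRUCache

message_cache = LRUCache(maxsize=1024)

def already_handled(event_time, sender_name, thread_name, body):
    # Redeliveries of the same event collapse onto the same composite key
    message_id = "{}{}{}{}".format(event_time, sender_name, thread_name,
                                   len(body))
    if message_cache.get(message_id) is not None:
        return True
    message_cache[message_id] = True
    return False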
예제 #25
0
class IpfsRSAAgent:
    """
    IPFS RSA Agent

    :param rsaExecutor: RSA Executor object
    :param pubKeyPem: PEM-encoded public key
    :param privKeyPath: Path to private RSA key
    """
    def __init__(self,
                 rsaExecutor,
                 pubKeyPem,
                 privKeyPath,
                 privKeyPassword=None):
        self.rsaExec = rsaExecutor
        self.pubKeyPem = pubKeyPem
        self._pubKeyCidCached = None
        self.privKeyPath = privKeyPath
        self._privKeyCache = LRUCache(4)

    def debug(self, msg):
        log.debug('RSA Agent: {0}'.format(msg))

    @property
    def pubKeyCidCached(self):
        return self._pubKeyCidCached

    @ipfsOp
    async def pubKeyCid(self, ipfsop):
        if self.pubKeyCidCached and cidValid(self.pubKeyCidCached):
            return self.pubKeyCidCached

        try:
            entry = await ipfsop.addBytes(self.pubKeyPem)
            self._pubKeyCidCached = entry['Hash']
            return self.pubKeyCidCached
        except Exception as err:
            self.debug(f'Cannot import pubkey: {err}')

    async def privJwk(self):
        try:
            privKey = await self._privateKey()
            pem = privKey.export_key(pkcs=8)
            key = jwk.JWK()
            key.import_from_pem(pem)
            return key
        except Exception as err:
            self.debug(f'Cannot create priv JWK key: {err}')
            return None

    async def jwsToken(self, payload: str):
        try:
            jwk = await self.privJwk()
            token = jws.JWS(payload.encode('utf-8'))
            token.add_signature(jwk, None, json_encode({"alg": "RS256"}),
                                json_encode({"kid": jwk.thumbprint()}))
            return token
        except Exception as err:
            self.debug(f'Cannot create JWS token: {err}')

    async def jwsTokenObj(self, payload: str):
        token = await self.jwsToken(payload)
        if token:
            return orjson.loads(token.serialize())

    async def jwtCreate(self, claims, alg='RS256'):
        try:
            jwk = await self.privJwk()
            token = jwt.JWT(header={"alg": alg}, claims=claims)
            token.make_signed_token(jwk)
            return token.serialize()
        except Exception as err:
            self.debug(f'Cannot create JWT: {err}')

    async def encrypt(self, data, pubKey, sessionKey=None, cacheKey=False):
        return await self.rsaExec.encryptData(
            data if isinstance(data, BytesIO) else BytesIO(data),
            pubKey,
            sessionKey=sessionKey,
            cacheKey=cacheKey)

    async def decrypt(self, data):
        return await self.rsaExec.decryptData(BytesIO(data), await
                                              self._privateKey())

    @ipfsOp
    async def storeSelf(self, op, data, offline=False, wrap=False):
        """
        Encrypt some data with our pubkey and store it in IPFS

        Returns the IPFS entry (returned by 'add') of the file

        :param bytes data: data to encrypt
        :param bool offline: offline mode (no announce)

        :rtype: dict
        """
        try:
            encrypted = await self.encrypt(data, self.pubKeyPem)
            if encrypted is None:
                return

            entry = await op.addBytes(encrypted, offline=offline, wrap=wrap)
            if entry:
                self.debug('storeSelf: encoded to {0}'.format(entry['Hash']))
                return entry
        except aioipfs.APIError as err:
            self.debug('IPFS error {}'.format(err.message))

    @ipfsOp
    async def encryptToMfs(self, op, data, mfsPath):
        try:
            encrypted = await self.encrypt(data, self.pubKeyPem)
            if not encrypted:
                return False
            return await op.filesWrite(mfsPath,
                                       encrypted,
                                       create=True,
                                       truncate=True)
        except aioipfs.APIError as err:
            self.debug('IPFS error {}'.format(err.message))

    @ipfsOp
    async def encryptJsonToMfs(self, op, obj, mfsPath):
        try:
            return await self.encryptToMfs(orjson.dumps(obj), mfsPath)
        except aioipfs.APIError as err:
            self.debug('IPFS error {}'.format(err.message))

    @ipfsOp
    async def decryptIpfsObject(self, op, data):
        privKey = await self._privateKey()
        try:
            decrypted = await self.rsaExec.decryptData(BytesIO(data), privKey)
            if decrypted:
                self.debug('RSA: decrypted {0} bytes'.format(len(decrypted)))
                return decrypted
        except aioipfs.APIError as err:
            self.debug('decryptIpfsObject: IPFS error {}'.format(err.message))
        except Exception as e:
            self.debug('RSA: unknown error while decrypting {}'.format(str(e)))

    @ipfsOp
    async def decryptMfsFile(self, op, path):
        try:
            data = await op.client.files.read(path)
            if data is None:
                raise ValueError('Invalid file')
        except aioipfs.APIError as err:
            self.debug('decryptMfsFile failed for {0}, '
                       'IPFS error was {1}'.format(path, err.message))
        else:
            return await self.decryptIpfsObject(data)

    @ipfsOp
    async def decryptMfsJson(self, op, path):
        try:
            decrypted = await self.decryptMfsFile(path)
            if decrypted:
                return json.loads(decrypted.decode())
        except aioipfs.APIError as err:
            self.debug('decryptMfsJson failed for {0}, '
                       'IPFS error was {1}'.format(path, err.message))

    @ipfsOp
    async def pssSign(self, op, message):
        return await self.rsaExec.pssSign(message, await self._privateKey())

    @ipfsOp
    async def pssSignImport(self, op, message, pin=False):
        signed = await self.rsaExec.pssSign(message, await self._privateKey())

        if signed:
            try:
                entry = await op.addBytes(signed, pin=pin)
                return entry['Hash']
            except Exception:
                return None

    @ipfsOp
    async def pssSign64(self, op, message):
        """
        :rtype: str
        """

        signed = await self.pssSign(message)

        if isinstance(signed, bytes):
            return base64.b64encode(signed).decode()

    async def __rsaReadPrivateKeyUtf8(self):
        key = await asyncReadFile(self.privKeyPath, mode='rt')
        return key.encode('utf-8')

    async def __rsaReadPrivateKey(self):
        return await asyncReadFile(self.privKeyPath)

    async def privKeyUnlock(self, passphrase=None):
        key = await self.rsaExec.importKey(await
                                           asyncReadFile(self.privKeyPath),
                                           passphrase=passphrase)

        if key:
            self.debug('Private key unlock success, caching')
            self._privKeyCache[0] = key
            self.debug(f'Key cache size: {len(self._privKeyCache)}')
            return key

    async def _privateKey(self, key=0):
        pKey = self._privKeyCache.get(key)
        if pKey:
            return pKey

        return await self.rsaExec.importKey(await
                                            asyncReadFile(self.privKeyPath))
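The tiny LRU cache above avoids re-importing (and re-unlocking) the private key on every operation: privKeyUnlock stores the unlocked key in slot 0 and _privateKey reads it back before falling back to a fresh import. A minimal sketch of that pattern (import_key is a hypothetical loader standing in for rsaExec.importKey):

from cachetools import LRUCache

_priv_key_cache = LRUCache(maxsize=4)

def get_private_key(import_key, slot=0):
    cached = _priv_key_cache.get(slot)
    if cached is not None:
        return cached               # reuse the already-unlocked key
    key = import_key()              # expensive import / passphrase unlock
    _priv_key_cache[slot] = key
    return key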
예제 #26
0
class ImageDataGeneration:
    """It has functionality to create generators to feed data to keras.
    """
    valid_subsets = frozenbidict({
        'training': ImageDataSubset.Training,
        'validation': ImageDataSubset.Validation,
        'prediction': ImageDataSubset.Prediction
    })

    def __init__(self,
                 dataframe,
                 input_params,
                 image_generation_params,
                 transformer=None,
                 randomize=True):
        """It initializes the dataframe object.

        Arguments:
            dataframe {Pandas DataFrame} -- A pandas dataframe object with columnar data with image names and labels.
            input_params {A InputDataParameter object} -- An input parameter object.
            image_generation_params {A ImageGenerationParameters object} -- A training data parameter object.
            transformer {A ImageDataTransformation object} -- It is used to transform the image objects.
            randomize {boolean} -- It indicates randomization of the input dataframe.
        """
        #Required parameters
        self._dataframe = dataframe
        self._input_params = input_params
        self._image_generation_params = image_generation_params

        #Optional parameters
        self._transformer = transformer
        self._randomize = randomize

        #Caching
        self._image_cache = LRUCache(
            self._image_generation_params.image_cache_size)

        #Logging
        self._logger = logging.get_logger(__name__)

        #Metrics
        self._load_slice_metric = 'get_image_objects'

        #Create metrics
        Metric.create(self._load_slice_metric)

        #Compute the training and validation boundary using the validation split.
        boundary = int(
            ceil(
                len(self._dataframe) *
                (1. - self._image_generation_params.validation_split)))
        self._logger.info(
            "Validation split: {} Identified boundary: {}".format(
                self._image_generation_params.validation_split, boundary))

        #Split the dataframe into training and validation.
        self._main_df = self._dataframe.loc[:(boundary - 1), :]
        self._validation_df = self._dataframe.loc[boundary:, :].reset_index(
            drop=True)

        n_dataframe = len(self._dataframe)
        n_main_df = len(self._main_df)
        n_validation_df = len(self._validation_df)

        self._logger.info(
            "Dataframe size: {} main set size: {} validation size: {}".format(
                n_dataframe, n_main_df, n_validation_df))

    def _get_images(self, n_images):
        """It extracts the image names from the dataframe.

        Arguments:
            n_images {int} -- The number of distinct image names to collect from the dataframe.
        """
        df_size = len(self._main_df)
        loop_count = 0
        images = set()

        while len(images) <= n_images and loop_count < df_size:
            random_index = randrange(df_size)

            for image_col in self._image_generation_params.image_cols:
                images.add(self._main_df.loc[random_index, image_col])

            loop_count += 1

        return list(images)

    def fit(self, n_images):
        """It calculates statistics on the input dataset. These are used to perform transformation.

        Arguments:
            n_images {int} -- The number of images to sample for computing transformation statistics.
        """
        if n_images <= 0:
            raise ValueError(
                "Expected a positive integer for n_images. Got: {}".format(
                    n_images))

        #Input list for data fitting
        images = self._get_images(n_images)

        self._logger.info("%d images to use for data fitting", len(images))

        #Image objects
        img_objs_map = self._get_image_objects(images)
        img_objs = np.asarray(list(img_objs_map.values()))

        self._logger.info(
            "fit:: images: {} to the transformer to compute statistics".format(
                img_objs.shape))

        #Fit the data in the transformer
        self._transformer.fit(img_objs)

    def flow(self, subset='training'):
        """It creates an iterator to the input dataframe.

        Arguments:
            subset {string} -- A string selecting the training, validation, or prediction subset.
        """
        #Validate subset parameter
        if not ImageDataGeneration.valid_subsets.get(subset):
            raise ValueError("Valid values of subset are: {}".format(
                list(ImageDataGeneration.valid_subsets.keys())))

        #Qualified subset
        q_subset = ImageDataGeneration.valid_subsets[subset]

        #Dataframe placeholder
        dataframe = None

        #Pick the correct dataframe
        if q_subset == ImageDataSubset.Training or q_subset == ImageDataSubset.Prediction:
            dataframe = self._main_df
        elif q_subset == ImageDataSubset.Validation:
            dataframe = self._validation_df

        self._logger.info("flow:: subset: {} dataset size: {}".format(
            subset, len(dataframe)))

        return ImageDataIterator(self,
                                 dataframe,
                                 self._image_generation_params.batch_size,
                                 q_subset,
                                 randomize=self._randomize)

    def _load_subset_slice(self, df_slice, subset):
        """It loads the image objects and the labels for the data frame slice.

        Arguments:
            df_slice {A pandas.DataFrame object} -- A pandas DataFrame object containing input data and labels.

        Returns:
            {An object} -- A list of image objects in prediction phase. A tuple of image objects and their labels in training phase.
        """
        self._logger.info('Using subset: %s', subset)

        #Results placeholder
        results = None

        #Load the slice
        if subset == ImageDataSubset.Training or subset == ImageDataSubset.Validation:
            results = self._load_train_phase_slice(df_slice)
        elif subset == ImageDataSubset.Prediction:
            results = self._load_predict_phase_slice(df_slice)

        return results

    def _load_train_phase_slice(self, df_slice):
        """It loads the image objects and the labels for the data frame slice.

        Arguments:
            df_slice {A pandas.DataFrame object} -- A pandas DataFrame object containing input data and labels.

        Returns:
            (Numpy object, Numpy object) -- A tuple of input data and labels.
        """
        return self._load_slice(df_slice)

    def _load_predict_phase_slice(self, df_slice):
        """It loads the image objects for the data frame slice.

        Arguments:
            df_slice {A pandas.DataFrame object} -- A pandas DataFrame object containing input data and labels.

        Returns:
            {A list of numpy objects} -- The input image data only; labels are not returned.
        """
        images, _ = self._load_slice(df_slice)

        return images

    def _load_slice(self, df_slice):
        """It loads the image objects for the data frame slice.

        Arguments:
            df_slice {A pandas.DataFrame object} -- A pandas DataFrame object containing input data and labels.

        Returns:
            (Numpy object, Numpy object) -- A tuple of input data and labels.
        """
        #Calculate the number of classes
        num_classes = self._image_generation_params.num_classes

        #Process labels
        df_slice_y = df_slice[self._image_generation_params.label_col].values
        df_slice_y_categorical = to_categorical(
            df_slice_y,
            num_classes=num_classes) if num_classes > 2 else df_slice_y

        #Process image columns
        df_slice_x = []

        for x_col in self._image_generation_params.image_cols:
            images = df_slice[x_col].tolist()

            #Load images
            img_objs_map = self._get_image_objects(images)

            #Arrange them in the input order
            img_objs = [img_objs_map[image] for image in images]
            img_objs = np.asarray(img_objs)

            if x_col in self._image_generation_params.image_transform_cols:
                img_objs = self._apply_transformation(img_objs)

            df_slice_x.append(img_objs)

        return (df_slice_x, df_slice_y_categorical)

    def _get_image_objects(self, images):
        """It loads the image objects for the list of images.
        If the image is available, it is loaded from the cache.
        Otherwise, it is loaded from the disk.

        Arguments:
            images {[string]} -- A list of image names.
        """
        #Start recording time
        record_handle = Metric.start(self._load_slice_metric)

        img_objs = {}
        candidate_images = set(images)
        for image in candidate_images:
            #Get the image object for the current image from the cache.
            #Add to the dictionary, if it is not None.
            img_obj = self._image_cache.get(image)

            if img_obj is not None:
                img_objs[image] = img_obj

        #Create a list of missing images.
        cached_images = set(img_objs.keys())
        missing_images = [
            image for image in candidate_images if image not in cached_images
        ]

        self._logger.debug("Cached images: {} missing images: {}".format(
            cached_images, missing_images))

        #Load the missing image objects, and apply parameters.
        missing_img_objs = utils.imload(
            self._image_generation_params.dataset_location, missing_images,
            self._image_generation_params.input_shape[:2])
        missing_img_objs = self._apply_parameters(missing_img_objs)

        #Update the cache
        self._image_cache.update(zip(missing_images, missing_img_objs))

        #Update the image object dictionary with the missing image objects.
        for image, img_obj in zip(missing_images, missing_img_objs):
            img_objs[image] = img_obj

        #End recording time
        Metric.stop(record_handle, self._load_slice_metric)

        return img_objs

    def _apply_parameters(self, img_objs):
        """It processes image objects based on the input parameters.
        e.g. normalization, reshaping etc.

        Arguments:
            img_objs {numpy.ndarray} -- A numpy array of image objects.
        """
        if self._image_generation_params.normalize:
            img_objs = utils.normalize(img_objs)

        return img_objs

    def _apply_transformation(self, img_objs):
        transformed_objects = img_objs

        if self._transformer:
            transformed_objects = self._transformer.transform(img_objs)

        return transformed_objects
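The cache-then-load pattern in _get_image_objects above first serves whatever is already cached, loads only the missing images from disk, and then feeds the fresh objects back into the cache with update(). A minimal sketch of the same flow with disk loading stubbed out (load_from_disk is a hypothetical callable that returns objects in the order of the names it receives):

from cachetools import LRUCache

image_cache = LRUCache(maxsize=512)

def get_image_objects(names, load_from_disk):
    wanted = set(names)
    objs = {name: image_cache[name] for name in wanted if name in image_cache}
    missing = [name for name in wanted if name not in objs]
    loaded = load_from_disk(missing)          # load only what the cache lacks
    image_cache.update(zip(missing, loaded))  # warm the cache for next time
    objs.update(zip(missing, loaded))
    return objs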