def test_cluster_already_filtered(filterable_cluster):
    """Check that filtering an already-filtered cluster raises."""
    # First pass: filter the good cluster and add the result to the session.
    with session_scope() as session:
        original = session.query(Cluster).first()
        fcluster = original.filter()
        session.add(fcluster)

    # Second pass: re-attach the filtered cluster and try to filter it again.
    with pytest.raises(AlreadyFiltered), session_scope() as session:
        fcluster = session.merge(fcluster)
        fcluster.filter()
def test_cluster_emptied(filterable_cluster):
    """Check `Cluster.filter()` returns `None` once quote sid=0 is spoiled."""
    # Move the second url timestamp of quote sid=0 to 81 days from now.
    with session_scope() as session:
        target = session.query(Quote).filter(Quote.sid == 0).one()
        new_timestamps = target.url_timestamps.copy()
        new_timestamps[1] = datetime.utcnow() + timedelta(days=81)
        # Reassign the whole list rather than mutating it in place.
        target.url_timestamps = new_timestamps

    # The cluster should now be dropped by filtering.
    with session_scope() as session:
        first_cluster = session.query(Cluster).first()
        assert first_cluster.filter() is None
def test_filter_clusters_emptied(filterable_cluster):
    """Check `filter_clusters()` keeps nothing once quote sid=0 is spoiled."""
    # Move the second url timestamp of quote sid=0 to 81 days from now.
    with session_scope() as session:
        target = session.query(Quote).filter(Quote.sid == 0).one()
        new_timestamps = target.url_timestamps.copy()
        new_timestamps[1] = datetime.utcnow() + timedelta(days=81)
        # Reassign the whole list rather than mutating it in place.
        target.url_timestamps = new_timestamps

    # Run the global filtering, then make sure no filtered cluster exists.
    filter_clusters()
    with session_scope() as session:
        filtered_clusters = session.query(Cluster)\
            .filter(Cluster.filtered.is_(True))
        assert filtered_clusters.count() == 0
def test_filter_clusters_limit(filterable_cluster):
    """Check the `limit` argument of `filter_clusters()`."""
    # With a limit of zero nothing should end up filtered.
    filter_clusters(limit=0)
    with session_scope() as session:
        filtered_clusters = session.query(Cluster)\
            .filter(Cluster.filtered.is_(True))
        assert filtered_clusters.count() == 0

    # With a limit of one, our cluster loses all its quotes but one (#0),
    # and is then kept.
    filter_clusters(limit=1)
    with session_scope() as session:
        filtered_clusters = session.query(Cluster)\
            .filter(Cluster.filtered.is_(True))
        fcluster = filtered_clusters.one()
        assert fcluster.size == 1
        assert fcluster.quotes.first().sid == 0
def load_mt_frequency_and_tokens():
    """Compute MemeTracker frequency codings and the list of available tokens.

    Go through all the filtered quotes stored in the database, accumulating
    word frequencies and the set of tokens encountered. Frequency codings
    are pickled to :data:`~.settings.FREQUENCY`, and the set of tokens is
    pickled to :data:`~.settings.TOKENS`.

    The MemeTracker dataset must have been loaded and filtered previously, or
    an exception will be raised (see :ref:`usage` or :mod:`.cli` for more
    about that). Progress is printed to stdout.

    Raises
    ------
    Exception
        If no filtered quotes are found in the database.

    """

    logger.info('Computing memetracker frequencies and token list')
    click.echo('Computing MemeTracker frequencies and token list...')

    # The 'frequency' feature tells us which quote attribute to count on
    # (tokens or lemmas).
    source_type, _ = SubstitutionFeaturesMixin.__features__['frequency']
    logger.info('Frequencies will be computed on %s', source_type)

    with session_scope() as session:
        filtered_ids = session.query(Quote.id)\
            .filter(Quote.filtered.is_(True))
        # Without filtered quotes there is nothing to count.
        if filtered_ids.count() == 0:
            raise Exception('Found no filtered quotes, aborting.')
        quote_ids = [quote_id for (quote_id,) in filtered_ids]

    # Accumulate frequencies and tokens, one session per quote.
    tokens = set()
    frequencies = defaultdict(int)
    for quote_id in ProgressBar()(quote_ids):
        with session_scope() as session:
            quote = session.query(Quote).get(quote_id)
            tokens.update(quote.tokens)
            for word in getattr(quote, source_type):
                frequencies[word] += quote.frequency

    # Pickle a plain dict instead of the defaultdict.
    frequencies = dict(frequencies)

    logger.debug('Saving memetracker frequencies to pickle')
    with open(settings.FREQUENCY, 'wb') as frequency_file:
        pickle.dump(frequencies, frequency_file)

    logger.debug('Saving memetracker token list to pickle')
    with open(settings.TOKENS, 'wb') as tokens_file:
        pickle.dump(tokens, tokens_file)

    click.secho('OK', fg='green', bold=True)
    logger.info('Done computing memetracker frequencies and token list')
def test_cluster(some_clusters):
    """Test base functionality of :class:`~.db.Cluster`."""

    # A freshly created cluster has no quotes and no urls.
    empty = Cluster()
    assert empty.size == 0
    assert empty.size_urls == 0
    assert empty.frequency == 0
    with pytest.raises(ValueError) as excinfo:
        empty.span
    assert 'No urls' in str(excinfo.value)
    assert empty.urls == []

    # Clusters coming from the database.
    with session_scope() as session:

        def cluster0():
            # Run the query again on each use, one query per assertion.
            return session.query(Cluster).filter_by(sid=0).one()

        assert session.query(Cluster).count() == 5
        assert session.query(Cluster.sid).all() == \
            [(sid,) for sid in some_clusters]
        assert cluster0().size == 0
        assert cluster0().size_urls == 0
        assert cluster0().frequency == 0
        with pytest.raises(ValueError) as excinfo:
            cluster0().span
        assert 'No urls' in str(excinfo.value)
        assert cluster0().urls == []

        expected_copy = '1\t0\tFalse\ttest'
        assert session.query(Cluster).get(1).format_copy() == expected_copy
def test_cluster_cascade_to_quotes(some_quotes):
    r"""Check deleting a :class:`~.db.Cluster` also deletes its
    :class:`~.db.Quote`\ s."""
    # The docstring is raw: `\ ` (reST markup) is not a valid Python escape
    # sequence and triggers a warning in non-raw string literals.

    with session_scope() as session:
        # Bulk-delete every cluster; no quote should be left afterwards.
        session.query(Cluster).delete()
        assert session.query(Quote).count() == 0
def test_quote_cascade_to_substitutions(some_substitutions):
    r"""Check deleting a :class:`~.db.Quote` also deletes its
    :class:`~.db.Substitution`\ s."""
    # The docstring is raw: `\ ` (reST markup) is not a valid Python escape
    # sequence and triggers a warning in non-raw string literals.

    with session_scope() as session:
        # Bulk-delete every quote; no substitution should be left afterwards.
        session.query(Quote).delete()
        assert session.query(Substitution).count() == 0
def _copy(string, table, columns):
    """Run a PostgreSQL COPY command to bulk-import data.

    COPY is one of the fastest ways to import data in bulk into PostgreSQL.
    The operation is executed through the raw psycopg2 :class:`cursor`
    object.

    Parameters
    ----------
    string : file-like object
        Contents of the data to import into the database, formatted for the
        COPY command (see `PostgreSQL's documentation
        <https://www.postgresql.org/docs/9.5/static/sql-copy.html>`_ for more
        details). Can be an :class:`io.StringIO` if you don't want to use a
        real file in the filesystem.
    table : str
        Name of the table into which the data is imported.
    columns : list of str
        List of the column names encoded in the `string` parameter. When
        `string` is produced using :meth:`Quote.format_copy` or
        :meth:`Cluster.format_copy` you can use the corresponding
        :attr:`Quote.format_copy_columns` or
        :attr:`Cluster.format_copy_columns` for this parameter.

    See Also
    --------
    save_by_copy, Quote.format_copy, Cluster.format_copy

    """

    # Rewind so that the data is read from its beginning.
    string.seek(0)
    with session_scope() as session:
        # Reach below the session for the underlying raw connection's cursor.
        raw_connection = session.connection().connection
        raw_cursor = raw_connection.cursor()
        raw_cursor.copy_from(string, table, columns=columns)
def test_cluster_too_long(filterable_cluster):
    """Check a cluster whose span is too long after filtering is dropped."""
    # Add a quote that is fine on its own, but lies too far from quote
    # sid=0, which makes the cluster span too long after quote filtering.
    with session_scope() as session:
        first_cluster = session.query(Cluster).first()
        far_quote = Quote(sid=5, string='a string with enough '
                          'words and no problems')
        far_url = Url(
            timestamp=datetime.utcnow() + timedelta(days=80, hours=1),
            frequency=2, url_type='M', url='some-url')
        far_quote.add_url(far_url)
        first_cluster.quotes.append(far_quote)

    # The cluster should now be dropped by filtering.
    with session_scope() as session:
        first_cluster = session.query(Cluster).first()
        assert first_cluster.filter() is None
def test_cluster_kept(filterable_cluster):
    """Check the fixture cluster is kept with its single good quote."""
    # All quotes get filtered out but one (#0), and the cluster is kept.
    with session_scope() as session:
        first_cluster = session.query(Cluster).first()
        kept = first_cluster.filter()
        assert kept.size == 1
        assert kept.quotes.first().sid == 0
def test_quote_add_url_sealed(some_quotes):
    r"""Check you can't add :class:`~.db.Url`\ s to a :class:`~.db.Quote`
    on which you've already accessed :class:`~.utils.cache`\ d attributes."""
    # The docstring is raw: `\ ` (reST markup) is not a valid Python escape
    # sequence and triggers a warning in non-raw string literals.

    with pytest.raises(SealedException):
        with session_scope() as session:
            u = Url(timestamp=datetime.utcnow(), frequency=1,
                    url_type='B', url='some url 1')
            q = session.query(Quote).filter_by(sid=0).one()
            # Access `size` and `urls` before adding the url; per the
            # docstring this is what makes `add_url` raise.
            assert q.size == len(q.urls)
            q.add_url(u)
def test_url(some_urls):
    """Test base functionality of :class:`~.db.Url`."""

    with session_scope() as session:
        quote0 = session.query(Quote).filter_by(sid=0).one()
        quote3 = session.query(Quote).filter_by(sid=3).one()
        first_day = datetime(year=2008, month=1, day=1)

        # Url timestamps as seen from the quotes.
        assert quote0.urls[0].timestamp == first_day
        assert quote3.urls[0].timestamp == first_day + timedelta(days=3)

        # Aggregates at the quote level.
        assert quote0.size == 2
        assert quote0.frequency == 4
        assert quote0.span == timedelta(days=10)

        # Occurrence indices of the urls inside their quotes.
        assert quote0.urls[0].occurrence == 0
        assert quote0.urls[1].occurrence == 1
        assert quote3.urls[0].occurrence == 0
        assert quote3.urls[1].occurrence == 1

        # Aggregates at the cluster level.
        cluster0 = session.query(Cluster).filter_by(sid=0).one()
        assert cluster0.size == 2
        assert cluster0.size_urls == 4
        assert cluster0.frequency == 8
        assert cluster0.span == timedelta(days=15)
        assert cluster0.urls[0].timestamp == first_day

        # Formatting for COPY, including escaping of quotes in the urls.
        assert quote0.format_copy() == \
            ('{}'.format(quote0.id) +
             '\t1\t0\tFalse\tSome quote to tokenize 0\t'
             '{2008-01-01 00:00:00, 2008-01-11 00:00:00}\t'
             '{2, 2}\t{B, B}\t'
             '{"Url with \\\\" and \' 0", "Url with \\\\" and \' 10"}')

    # Adding a url with url_type 'C' is expected to fail with a DataError.
    with pytest.raises(DataError), session_scope() as session:
        quote1 = session.query(Quote).filter_by(sid=1).one()
        quote1.add_url(Url(timestamp=datetime.now(), frequency=1,
                           url_type='C', url='some url'))
def some_substitutions(some_clusters, some_quotes, some_urls):
    """Fill the temporary database with a few substitutions.

    Builds on the clusters, quotes, and urls provided by the other fixtures
    and adds six more quotes and five substitutions, spread over two
    different substitution models (although their actual occurrences don't
    fit with those models). See the source code if you need details on the
    exact attributes.

    Returns the two models as a `(model1, model2)` tuple.

    """

    model1 = Model(Time.discrete, Source.majority, Past.last_bin, Durl.all, 1)
    model2 = Model(Time.discrete, Source.majority, Past.all, Durl.all, 1)

    with session_scope() as session:
        cluster0 = session.query(Cluster).filter_by(sid=0).one()
        cluster1 = session.query(Cluster).filter_by(sid=1).one()

        q10 = Quote(sid=10, cluster=cluster0,
                    string="Don't do it! I know I wouldn't")
        q11 = Quote(sid=11, cluster=cluster0, string="I know I hadn't")
        q12 = Quote(sid=12, cluster=cluster0, string="some string")
        q13 = Quote(sid=13, cluster=cluster0, string="some other string")
        q14 = Quote(sid=14, cluster=cluster1, string="some other string 2")
        q15 = Quote(sid=15, cluster=cluster1, string="some other string 3")
        for new_quote in (q10, q11, q12, q13, q14, q15):
            session.add(new_quote)

        # Each row is (source, destination, occurrence, start, position,
        # model); substitutions are created and added in this order.
        specs = [
            (q10, q11, 0, 5, 3, model1),
            # Same durl (destination, occurrence) as above, but different
            # source, different start and different destination position.
            (q12, q11, 0, 2, 2, model1),
            # Same destination but different occurrence (so different durl).
            (q13, q11, 1, 5, 3, model1),
            # Different destination altogether.
            (q13, q12, 0, 0, 1, model2),
            # Different cluster.
            (q14, q15, 0, 0, 1, model2),
        ]
        for source, destination, occurrence, start, position, model in specs:
            session.add(Substitution(source=source, destination=destination,
                                     occurrence=occurrence, start=start,
                                     position=position, model=model))

    return model1, model2
def filter_quote_offset():
    """Compute the offset added to the ids of filtered :class:`~.db.Quote`\\ s.

    The id of a filtered :class:`~.db.Quote` is the id of the
    :class:`~.db.Quote` it originates from, plus this offset. The offset is
    derived from the highest quote id currently in the database.

    The function is meant to be :func:`~.utils.memoized`, since it is called
    so often.

    """

    # Imported here rather than at module level (presumably to avoid a
    # circular import with the db module).
    from brainscopypaste.db import Quote

    with session_scope() as session:
        highest_id = session.query(func.max(Quote.id)).scalar()
        return _top_id(highest_id)
def some_clusters(tmpdb):
    """Fill the temporary database with five empty clusters.

    The clusters get sids 0 to 4 and `'test'` as their source. The range of
    sids is returned.

    """

    sids = range(5)
    with session_scope() as session:
        for sid in sids:
            session.add(Cluster(sid=sid, source='test'))
    return sids
def filterable_cluster(tmpdb):
    """Fill the temporary database with one cluster made of five quotes.

    Quote 0 is meant to survive filtering; quotes 1 to 4 each have one
    reason to be dropped (too few words, not English, too long a span, no
    urls).

    """

    def url_in(offset=timedelta(0)):
        # All urls are alike, except for their distance from now.
        return Url(timestamp=datetime.utcnow() + offset,
                   frequency=2, url_type='M', url='some-url')

    with session_scope() as session:
        cluster = Cluster(sid=0, source='test')
        cluster.quotes = [
            # Quote 0 is good.
            Quote(sid=0, string='a string with enough words and no problems'),
            # Quote 1 has not enough words.
            Quote(sid=1, string='not enough words here'),
            # Quote 2 is not in English.
            Quote(sid=2, string="ceci n'est pas de l'anglais "
                  "mais a assez de mots"),
            # Quote 3 spans too long.
            Quote(sid=3, string="a quote that spans too long"),
            # Quote 4 has no urls at all.
            Quote(sid=4, string="a quote without any urls")
        ]

        # Quote 0 is good: its two urls are just under 80 days apart.
        cluster.quotes[0].add_urls([
            url_in(),
            url_in(timedelta(days=80, hours=-1)),
        ])
        # Quote 1 has not enough words.
        cluster.quotes[1].add_urls([url_in()])
        # Quote 2 is not in English.
        cluster.quotes[2].add_urls([url_in()])
        # Quote 3 spans too long: its two urls are just over 80 days apart.
        cluster.quotes[3].add_urls([
            url_in(),
            url_in(timedelta(days=80, hours=1)),
        ])
        # Quote 4 gets no urls at all.

        # Save all this.
        session.add(cluster)
def drop_filtered(obj):
    """Drop filtered rows (Clusters, Quotes)."""

    click.secho('Dropping filtered rows will also drop any substitutions '
                'mined beforehand', bold=True)

    # Nothing happens unless the user confirms.
    confirmed = confirm('the filtered rows (clusters, quotes) and '
                        'any mined substitutions attached to them')
    if not confirmed:
        return

    logger.info('Dropping filtered rows (quotes and clusters) and '
                'substitutions from database')
    with session_scope() as session:
        click.secho('Dropping filtered rows and substitutions... ',
                    nl=False)
        # Only filtered clusters are deleted explicitly here.
        filtered_clusters = session.query(Cluster)\
            .filter(Cluster.filtered.is_(True))
        filtered_clusters.delete(synchronize_session=False)
    click.secho('OK', fg='green', bold=True)
    logger.info('Done dropping filtered rows and substitutions')
def _check(self):
    r"""Check the consistency of the database with `self._checks`.

    The original MemeTracker dataset specifies the number of quotes and
    frequency for each cluster, and the number of urls and frequency for
    each quote. This information is saved in `self._checks` during parsing.
    This method iterates through the whole database of saved
    :class:`~.db.Cluster`\ s and :class:`~.db.Quote`\ s to check that their
    counts correspond to what the MemeTracker dataset says (as stored in
    `self._checks`). Once all checks have passed, `self._checks` is reset to
    an empty dict.

    Raises
    ------
    ValueError
        If any count in the database differs from its specification in
        `self._checks`.

    """

    # `cluster_id` rather than `id`, to avoid shadowing the builtin.
    for cluster_id, check in ProgressBar()(self._checks.items()):
        logger.debug('Checking cluster #%s consistency', cluster_id)

        with session_scope() as session:
            # Check the cluster itself.
            cluster = session.query(Cluster).get(cluster_id)
            err_end = (' #{} does not match value'
                       ' in file').format(cluster.sid)
            if check['cluster']['size'] != cluster.size:
                raise ValueError("Cluster size" + err_end)
            if check['cluster']['frequency'] != cluster.frequency:
                raise ValueError("Cluster frequency" + err_end)

            # Check each quote; expected values are keyed by database id.
            for quote in cluster.quotes:
                quote_check = check['quotes'][quote.id]
                err_end = (' #{} does not match value'
                           ' in file').format(quote.sid)
                if quote_check['size'] != quote.size:
                    raise ValueError("Quote size" + err_end)
                if quote_check['frequency'] != quote.frequency:
                    raise ValueError("Quote frequency" + err_end)

    # Everything matched: the stored expectations are no longer needed.
    self._checks = {}
def some_quotes(some_clusters):
    """Fill the temporary database with ten quotes spread over the clusters.

    Quote `i` (sids 0 to 9) is attached to the cluster with sid `i % 5`.
    The range of quote sids is returned.

    """

    sids = range(10)
    with session_scope() as session:
        clusters = session.query(Cluster)
        # Insert quotes in reverse order to check ordering
        for sid in reversed(sids):
            parent = clusters.filter_by(sid=sid % 5).one()
            session.add(Quote(sid=sid, cluster=parent,
                              string='Some quote to tokenize {}'.format(sid)))
    return sids
def test_clone_cluster(some_urls):
    """Test cloning of a :class:`~.db.Cluster`."""

    with session_scope() as session:
        source_cluster = session.query(Cluster).get(1)

        # A plain clone copies everything except the id and the quotes.
        plain = source_cluster.clone()
        assert plain.id is None
        assert plain.sid == source_cluster.sid
        assert plain.filtered == source_cluster.filtered
        assert plain.source == source_cluster.source
        assert plain.quotes.all() == []

        # Keyword arguments override the corresponding attributes.
        custom = source_cluster.clone(id=500, filtered=True,
                                      source='another')
        assert custom.id == 500
        assert custom.id != source_cluster.id
        assert custom.sid == source_cluster.sid
        assert custom.filtered is True
        assert custom.filtered != source_cluster.filtered
        assert custom.source == 'another'
        assert custom.source != source_cluster.source
        assert custom.quotes.all() == []
def some_urls(some_clusters, some_quotes):
    """Fill the temporary database with twenty urls spread over the quotes.

    Url `i` (0 to 19) is attached to the quote with sid `i % 10`, and is
    timestamped `i` days after 2008-01-01. The range of url indices is
    returned.

    """

    ids = range(20)
    with session_scope() as session:
        quotes = session.query(Quote)
        # Insert urls in reverse order to check ordering
        for index in reversed(ids):
            parent = quotes.filter_by(sid=index % 10).one()
            timestamp = (datetime(year=2008, month=1, day=1)
                         + timedelta(days=index))
            parent.add_url(Url(timestamp=timestamp,
                               frequency=2, url_type='B',
                               url='Url with " and \' {}'.format(index)))
    return ids
def test_substitution(some_substitutions):
    """Test base functionality of :class:`~.db.Substitution`."""

    model1, model2 = some_substitutions

    with session_scope() as session:

        def count_with(model):
            # Number of stored substitutions mined with `model`.
            return session.query(Substitution)\
                .filter(Substitution.model == model).count()

        source_quote = session.query(Quote).filter_by(sid=10).one()
        destination_quote = session.query(Quote).filter_by(sid=11).one()

        assert source_quote.substitutions_source.count() == 1
        assert source_quote.substitutions_destination.count() == 0
        assert destination_quote.substitutions_source.count() == 0
        assert destination_quote.substitutions_destination.count() == 3

        # Check relationships for a single substitution.
        substitution = source_quote.substitutions_source.first()
        assert destination_quote.substitutions_destination\
            .order_by(Substitution.id).first() == substitution
        assert substitution.source == source_quote
        assert substitution.destination == destination_quote

        # Check linguistic variables.
        assert substitution.tokens == ('would', 'had')
        assert substitution.lemmas == ('would', 'have')
        assert substitution.tags == ('MD', 'VHD')

        # We can filter substitutions by mining model.
        assert count_with(model1) == 3
        assert count_with(model2) == 2

        # Models that differ from the stored ones match nothing.
        model3 = Model(Time.continuous, Source.majority, Past.last_bin,
                       Durl.all, 1)
        assert count_with(model3) == 0
        model4 = Model(Time.discrete, Source.majority, Past.last_bin,
                       Durl.all, 2)
        assert count_with(model4) == 0
def test_clone_quote(some_urls):
    """Test cloning of a :class:`~.db.Quote`."""

    def check_urls(original, clone):
        # The urls of the clone point back to the clone itself.
        for url in clone.urls:
            assert url.quote == clone
        # Urls are the same apart from parent quotes
        for original_url, clone_url in zip(original.urls, clone.urls):
            original_url.quote = None
            clone_url.quote = None
        assert clone.urls == original.urls

    with session_scope() as session:
        source_quote = session.query(Quote).get(1)

        # A plain clone copies everything except the id.
        plain = source_quote.clone()
        assert plain.id is None
        assert plain.cluster_id == source_quote.cluster_id
        assert plain.sid == source_quote.sid
        assert plain.filtered == source_quote.filtered
        assert plain.string == source_quote.string
        check_urls(source_quote, plain)

        # Keyword arguments override the corresponding attributes.
        custom = source_quote.clone(id=600, filtered=True, cluster_id=125,
                                    string='hello')
        assert custom.id == 600
        assert custom.id != source_quote.id
        assert custom.cluster_id == 125
        assert custom.cluster_id != source_quote.cluster_id
        assert custom.sid == source_quote.sid
        assert custom.filtered is True
        assert custom.filtered != source_quote.filtered
        assert custom.string == 'hello'
        assert custom.string != source_quote.string
        check_urls(source_quote, custom)
def filter_clusters(limit=None):
    r"""Filter the whole MemeTracker dataset by copying all valid
    :class:`~.db.Cluster`\ s and :class:`~.db.Quote`\ s and setting their
    `filtered` attributes to `True`.

    Iterate through all the MemeTracker :class:`~.db.Cluster`\ s, and filter
    each of them to see if it's worth keeping. If a :class:`~.db.Cluster` is
    to be kept, the function creates a copy of it and all of its kept
    :class:`~.db.Quote`\ s, marking them as filtered. Progress of this
    operation is printed to stdout.

    Once the operation finishes, a VACUUM and an ANALYZE operation are run on
    the database so that it recomputes its optimisations.

    Parameters
    ----------
    limit : int, optional
        If not `None`, stop filtering after `limit` clusters have been seen
        (useful for testing purposes).

    Raises
    ------
    AlreadyFiltered
        If there are already some filtered :class:`~.db.Cluster`\ s stored
        in the database (indicating another filtering operation has already
        been completed, or started and aborted).

    """

    from brainscopypaste.db import Session, Cluster, save_by_copy

    logger.info('Filtering memetracker clusters')
    if limit is not None:
        logger.info('Filtering is limited to %s clusters', limit)
    click.echo('Filtering all clusters{}...'
               .format('' if limit is None else ' (limit={})'.format(limit)))

    # Check this isn't already done.
    with session_scope() as session:
        if session.query(Cluster)\
                .filter(Cluster.filtered.is_(True)).count() > 0:
            raise AlreadyFiltered('There are already some filtered '
                                  'clusters, aborting.')

        query = session.query(Cluster.id)
        if limit is not None:
            query = query.limit(limit)
        # Rows are 1-tuples; unpack them without shadowing the `id` builtin.
        cluster_ids = [cluster_id for (cluster_id,) in query]

    logger.info('Got %s clusters to filter', len(cluster_ids))

    # Filter. Each cluster is handled in a session of its own; kept clusters
    # and their quotes are collected for a single bulk save below.
    objects = {'clusters': [], 'quotes': []}
    for cluster_id in ProgressBar()(cluster_ids):
        with session_scope() as session:
            cluster = session.query(Cluster).get(cluster_id)
            fcluster = cluster.filter()
            if fcluster is not None:
                logger.debug('Cluster #%s is kept with %s quotes',
                             cluster.sid, fcluster.size)
                objects['clusters'].append(fcluster)
                objects['quotes'].extend(fcluster.quotes)
            else:
                logger.debug('Cluster #%s is dropped', cluster.sid)

    click.secho('OK', fg='green', bold=True)
    logger.info('Kept %s clusters and %s quotes after filtering',
                len(objects['clusters']), len(objects['quotes']))

    # Save.
    logger.info('Saving filtered clusters to database')
    save_by_copy(**objects)

    # Vacuum analyze.
    logger.info('Vacuuming and analyzing database')
    click.echo('Vacuuming and analyzing... ', nl=False)
    execute_raw(Session.kw['bind'], 'VACUUM ANALYZE')
    click.secho('OK', fg='green', bold=True)
def test_quote(some_quotes):
    """Test base functionality of :class:`~.db.Quote`."""

    # A freshly created quote has no urls and no string.
    empty = Quote()
    assert empty.size == 0
    assert empty.frequency == 0
    with pytest.raises(ValueError) as excinfo:
        empty.span
    assert 'No urls' in str(excinfo.value)
    assert empty.urls == []
    with pytest.raises(ValueError) as excinfo:
        empty.tags
    assert 'No string' in str(excinfo.value)
    with pytest.raises(ValueError) as excinfo:
        empty.tokens
    assert 'No string' in str(excinfo.value)
    with pytest.raises(ValueError) as excinfo:
        empty.lemmas
    assert 'No string' in str(excinfo.value)

    # Quotes coming from the database.
    with session_scope() as session:

        def quote_with(sid):
            # Run the query again on each use, one query per assertion.
            return session.query(Quote).filter_by(sid=sid).one()

        def cluster_with(sid):
            # Same as above, for clusters.
            return session.query(Cluster).filter_by(sid=sid).one()

        assert session.query(Quote).count() == 10

        # Quotes are attached to the right clusters.
        assert quote_with(0).cluster.sid == 0
        assert quote_with(2).cluster.sid == 2
        assert quote_with(4).cluster.sid == 4
        assert quote_with(6).cluster.sid == 1

        # Linguistic attributes.
        assert quote_with(6).tokens == \
            ('some', 'quote', 'to', 'tokenize', '6')
        assert quote_with(6).tags == ('DT', 'NN', 'TO', 'VV', 'CD')
        assert quote_with(6).lemmas == \
            ('some', 'quote', 'to', 'tokenize', '6')

        # Navigating from clusters to quotes, and back.
        assert {quote.sid for quote in cluster_with(3).quotes} == {3, 8}
        q1 = quote_with(1)
        assert session.query(Cluster.sid)\
            .filter(Cluster.quotes.contains(q1)).one() == (1,)
        q7 = quote_with(7)
        assert session.query(Cluster.sid)\
            .filter(Cluster.quotes.contains(q7)).one() == (2,)

        # Quotes have no urls yet.
        assert quote_with(0).size == 0
        assert quote_with(0).frequency == 0
        with pytest.raises(ValueError) as excinfo:
            quote_with(0).span
        assert 'No urls' in str(excinfo.value)
        assert quote_with(0).urls == []

        # Clusters have quotes, but no urls either.
        assert cluster_with(3).size == 2
        assert cluster_with(3).size_urls == 0
        assert cluster_with(3).frequency == 0
        with pytest.raises(ValueError) as excinfo:
            cluster_with(3).span
        assert 'No urls' in str(excinfo.value)

        # Formatting for COPY, with empty url arrays.
        q0 = quote_with(0)
        assert q0.format_copy() == ('{}'.format(q0.id) +
                                    "\t1\t0\tFalse\tSome quote to "
                                    "tokenize 0\t{}\t{}\t{}\t{}")
def mine_substitutions_with_model(model, limit=None):
    """Mine all substitutions in the MemeTracker dataset conforming to `model`.

    Go through the filtered clusters of the MemeTracker dataset looking for
    all the substitutions that are considered valid by `model`, and save
    the results to the database.

    The MemeTracker dataset must have been loaded and filtered previously, or
    an exception will be raised (see :ref:`usage` or :mod:`.cli` for more
    about that). Progress is printed to stdout, along with the number of
    substitutions seen and the number of substitutions kept (i.e. validated
    by :meth:`SubstitutionValidatorMixin.validate`).

    Parameters
    ----------
    model : :class:`Model`
        The substitution model to use for mining.
    limit : int, optional
        If not `None` (default), mining will stop after `limit` clusters have
        been examined.

    Raises
    ------
    Exception
        If no filtered clusters are found in the database, or if there already
        are some substitutions from model `model` in the database.

    """

    from brainscopypaste.db import Cluster, Substitution

    logger.info('Mining clusters for substitutions')
    if limit is not None:
        logger.info('Mining is limited to %s clusters', limit)
    limit_note = '' if limit is None else ' (limit={})'.format(limit)
    click.echo('Mining clusters for substitutions with {}{}...'
               .format(model, limit_note))

    # Refuse to mine twice with the same model.
    with session_scope() as session:
        already_mined = session.query(Substitution)\
            .filter(Substitution.model == model).count()
        if already_mined != 0:
            raise Exception('The database already contains substitutions '
                            'mined with this model ({} - {} substitutions). '
                            'You should drop these before doing anything '
                            'else.'.format(model, already_mined))

    # Mining only works on filtered clusters.
    with session_scope() as session:
        filtered_clusters = session.query(Cluster)\
            .filter(Cluster.filtered.is_(True))
        if filtered_clusters.count() == 0:
            raise Exception('Found no filtered clusters, aborting.')

        id_query = session.query(Cluster.id)\
            .filter(Cluster.filtered.is_(True))
        if limit is not None:
            id_query = id_query.limit(limit)
        cluster_ids = [cluster_id for (cluster_id,) in id_query]

    logger.info('Got %s clusters to mine', len(cluster_ids))

    # Mine, one session per cluster. Each candidate substitution is either
    # committed or rolled back right after being validated.
    seen = 0
    kept = 0
    for cluster_id in ProgressBar()(cluster_ids):
        model.drop_caches()
        with session_scope() as session:
            cluster = session.query(Cluster).get(cluster_id)
            for candidate in cluster.substitutions(model):
                seen += 1
                if candidate.validate():
                    logger.debug('Found valid substitution in cluster #%s',
                                 cluster.sid)
                    kept += 1
                    session.commit()
                else:
                    logger.debug('Dropping substitution from cluster #%s',
                                 cluster.sid)
                    session.rollback()

    # Sanity check. This session business is tricky.
    with session_scope() as session:
        assert session.query(Substitution)\
            .filter(Substitution.model == model).count() == kept

    click.secho('OK', fg='green', bold=True)
    logger.info('Seen %s candidate substitutions, kept %s', seen, kept)
    click.echo('Seen {} candidate substitutions, kept {}.'.format(seen, kept))