def purged_chain_cover_txn(txn: LoggingTransaction) -> int:
    """Remove one batch of chain-cover rows that point at purged events.

    Walks ``event_auth_chains`` in ``event_id`` order, starting after
    ``current_event_id`` and reading at most ``batch_size`` rows (both taken
    from the enclosing scope). Any row whose event no longer exists in
    ``events`` is deleted from ``event_auth_chains``, and its
    (chain_id, sequence_number) pair is deleted from the ``origin_*`` side of
    ``event_auth_chain_links``. The last scanned event ID is stored as the
    progress of the ``purged_chain_cover`` background update.

    Returns:
        How many rows were scanned; 0 means there was nothing left to scan.
    """
    # The LEFT JOIN leaves e.event_id NULL when the chain entry refers to a
    # purged event, so the fourth column says whether the event still exists.
    sql = """ SELECT event_id, chain_id, sequence_number, e.event_id IS NOT NULL FROM event_auth_chains LEFT JOIN events AS e USING (event_id) WHERE event_id > ? ORDER BY event_auth_chains.event_id ASC LIMIT ? """
    txn.execute(sql, (current_event_id, batch_size))
    batch = txn.fetchall()

    if not batch:
        return 0

    # Rows whose backing event has been purged.
    orphaned = [row for row in batch if not row[3]]
    orphaned_event_ids = [(row[0],) for row in orphaned]
    orphaned_chain_positions = [(row[1], row[2]) for row in orphaned]

    txn.executemany(
        """ DELETE FROM event_auth_chains WHERE event_id = ? """,
        orphaned_event_ids,
    )

    # Matching target_* rows are left alone: target_chain_id has no index.
    # Purges are expected to be whole-room purges, whose links are removed
    # through the origin_* columns handled here.
    txn.executemany(
        """ DELETE FROM event_auth_chain_links WHERE origin_chain_id = ? AND origin_sequence_number = ? """,
        orphaned_chain_positions,
    )

    last_scanned_event_id = batch[-1][0]
    self.db_pool.updates._background_update_progress_txn(
        txn, "purged_chain_cover", {"current_event_id": last_scanned_event_id}
    )

    return len(batch)
def get_updated_account_data_for_user_txn(
    txn: LoggingTransaction,
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
    """Collect account data for ``user_id`` changed after ``stream_id``.

    Both names come from the enclosing scope.

    Returns:
        A pair ``(global_data, per_room_data)``:
        - ``global_data`` maps account data type to decoded content.
        - ``per_room_data`` maps room ID to a mapping of account data type to
          decoded content.
    """
    query_args = (user_id, stream_id)

    txn.execute(
        "SELECT account_data_type, content FROM account_data"
        " WHERE user_id = ? AND stream_id > ?",
        query_args,
    )
    global_account_data: Dict[str, JsonDict] = {}
    for data_type, raw_content in txn:
        global_account_data[data_type] = db_to_json(raw_content)

    txn.execute(
        "SELECT room_id, account_data_type, content FROM room_account_data"
        " WHERE user_id = ? AND stream_id > ?",
        query_args,
    )
    account_data_by_room: Dict[str, Dict[str, JsonDict]] = {}
    for room, data_type, raw_content in txn:
        account_data_by_room.setdefault(room, {})[data_type] = db_to_json(
            raw_content
        )

    return global_account_data, account_data_by_room
def get_all_updated_pushers_rows_txn(
    txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
    """Fetch pusher additions and deletions with stream IDs in (last_id, current_id].

    ``last_id``, ``current_id`` and ``limit`` come from the enclosing scope.
    ``pushers`` and ``deleted_pushers`` are each queried for at most ``limit``
    rows. The results are merged and sorted by stream ID.

    Returns:
        ``(updates, upper_bound, limited)``:
        - ``updates``: rows of ``(stream_id, (user_name, app_id, pushkey, deleted))``.
        - ``upper_bound``: the stream ID up to which ``updates`` is complete.
        - ``limited``: whether more rows may exist after ``upper_bound``.
    """
    sql = """
        SELECT id, user_name, app_id, pushkey
        FROM pushers
        WHERE ? < id AND id <= ?
        ORDER BY id ASC LIMIT ?
    """
    txn.execute(sql, (last_id, current_id, limit))
    updates = cast(
        List[Tuple[int, tuple]],
        [
            (stream_id, (user_name, app_id, pushkey, False))
            for stream_id, user_name, app_id, pushkey in txn
        ],
    )

    sql = """
        SELECT stream_id, user_id, app_id, pushkey
        FROM deleted_pushers
        WHERE ? < stream_id AND stream_id <= ?
        ORDER BY stream_id ASC LIMIT ?
    """
    txn.execute(sql, (last_id, current_id, limit))
    updates.extend(
        (stream_id, (user_name, app_id, pushkey, True))
        for stream_id, user_name, app_id, pushkey in txn
    )

    updates.sort()  # Sort so that they're ordered by stream id

    limited = False
    upper_bound = current_id
    if len(updates) >= limit:
        # Each query is capped at `limit` rows independently, so the merged
        # list may hold up to 2 * limit rows. If one table was truncated, the
        # other can contribute rows with larger stream IDs than unfetched rows
        # of the truncated table. Using the merged maximum as the bound would
        # then skip those unfetched rows for good.
        #
        # After cutting the sorted list back to `limit` entries, any
        # truncated table's last fetched row is at or above the new last
        # entry. So everything at or below the new bound has been fetched.
        updates = updates[:limit]
        limited = True
        upper_bound = updates[-1][0]

    return updates, upper_bound, limited
def _send_invalidation_to_replication(
    self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
) -> None:
    """Notifies replication that given cache has been invalidated.

    Note that this does *not* invalidate the cache locally. The invalidation
    is only recorded when the database engine is PostgreSQL; on other engines
    this is a no-op after the argument check.

    Args:
        txn: The transaction to record the invalidation in.
        cache_name: Name of the invalidated cache.
        keys: Entry to invalidate. If None will invalidate all.

    Raises:
        Exception: if asked to invalidate the whole current state cache.
    """
    if cache_name == CURRENT_STATE_CACHE_NAME and keys is None:
        raise Exception("Can't stream invalidate all with magic current state cache")

    if not isinstance(self.database_engine, PostgresEngine):
        return

    # The stream ID is allocated inside `txn` via get_next_txn, so it is only
    # taken at the point where it is actually used.
    stream_id = self._cache_id_gen.get_next_txn(txn)
    txn.call_after(self.hs.get_notifier().on_new_replication_data)

    self.db_pool.simple_insert_txn(
        txn,
        table="cache_invalidation_stream_by_instance",
        values={
            "stream_id": stream_id,
            "instance_name": self._instance_name,
            "cache_func": cache_name,
            "keys": None if keys is None else list(keys),
            "invalidation_ts": self._clock.time_msec(),
        },
    )
def _add_device_outbound_poke_to_stream_txn(
    self,
    txn: LoggingTransaction,
    user_id: str,
    device_ids: Collection[str],
    hosts: List[str],
    stream_ids: List[int],
    context: Dict[str, str],
) -> None:
    """Queue outbound device-list pokes for remote servers.

    Inserts one ``device_lists_outbound_pokes`` row per (host, device) pair,
    in host-major order. Each row consumes the next entry of ``stream_ids``.
    After the transaction, the federation stream cache is told that every
    host changed at the last stream ID.

    Args:
        txn: The database transaction.
        user_id: The user whose devices changed.
        device_ids: The devices that changed.
        hosts: Destination servers to poke.
        stream_ids: Stream IDs to assign to the new rows. Needs at least
            ``len(hosts) * len(device_ids)`` entries.
        context: Tracing context. Only stored for whitelisted homeservers;
            other destinations get ``"{}"``.
    """
    for host in hosts:
        txn.call_after(
            self._device_list_federation_stream_cache.entity_has_changed,
            host,
            stream_ids[-1],
        )

    now = self._clock.time_msec()
    stream_id_iter = iter(stream_ids)

    # The encoded context and the whitelist check depend only on the host, so
    # compute them once per host rather than once per (host, device) row.
    encoded_context = json_encoder.encode(context)
    context_by_host = {
        host: encoded_context if whitelisted_homeserver(host) else "{}"
        for host in hosts
    }

    self.db_pool.simple_insert_many_txn(
        txn,
        table="device_lists_outbound_pokes",
        values=[
            {
                "destination": destination,
                "stream_id": next(stream_id_iter),
                "user_id": user_id,
                "device_id": device_id,
                "sent": False,
                "ts": now,
                "opentracing_context": context_by_host[destination],
            }
            for destination in hosts
            for device_id in device_ids
        ],
    )
def _insert_graph_receipt_txn(
    self,
    txn: LoggingTransaction,
    room_id: str,
    receipt_type: str,
    user_id: str,
    event_ids: List[str],
    data: JsonDict,
) -> None:
    """Replace the graph receipt stored for (room_id, receipt_type, user_id).

    Any existing ``receipts_graph`` row for that key is deleted. A new row is
    then inserted holding the JSON-encoded ``event_ids`` and ``data``. Two
    receipt caches are invalidated once the transaction completes.
    """
    assert self._can_write_to_receipts

    txn.call_after(
        self._get_receipts_for_user_with_orderings.invalidate,
        (user_id, receipt_type),
    )
    # FIXME: This shouldn't invalidate the whole cache
    txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))

    receipt_key = {
        "room_id": room_id,
        "receipt_type": receipt_type,
        "user_id": user_id,
    }

    self.db_pool.simple_delete_txn(txn, table="receipts_graph", keyvalues=receipt_key)

    self.db_pool.simple_insert_txn(
        txn,
        table="receipts_graph",
        values={
            **receipt_key,
            "event_ids": json_encoder.encode(event_ids),
            "data": json_encoder.encode(data),
        },
    )
def _claim_e2e_one_time_key_simple(
    txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
) -> Optional[Tuple[str, str]]:
    """Claim OTK for device for DBs that don't support RETURNING.

    Selects one one-time key for the device and algorithm, deletes it, and
    invalidates (and streams) the ``count_e2e_one_time_keys`` cache entry for
    the device.

    Returns:
        ``("<algorithm>:<key_id>", key_json)`` for the claimed key, or None if
        the device has no key for that algorithm.
    """
    txn.execute(
        """ SELECT key_id, key_json FROM e2e_one_time_keys_json WHERE user_id = ? AND device_id = ? AND algorithm = ? LIMIT 1 """,
        (user_id, device_id, algorithm),
    )
    claimed = txn.fetchone()
    if claimed is None:
        return None

    claimed_id, claimed_json = claimed

    self.db_pool.simple_delete_one_txn(
        txn,
        table="e2e_one_time_keys_json",
        keyvalues={
            "user_id": user_id,
            "device_id": device_id,
            "algorithm": algorithm,
            "key_id": claimed_id,
        },
    )
    self._invalidate_cache_and_stream(
        txn, self.count_e2e_one_time_keys, (user_id, device_id)
    )

    return f"{algorithm}:{claimed_id}", claimed_json
def _delete_account_data_for_deactivated_users_txn(
    txn: LoggingTransaction,
) -> int:
    """Purge account data for the next batch of deactivated users.

    Reads up to ``batch_size`` deactivated users whose name sorts after
    ``last_user`` (both from the enclosing scope). It purges each user's
    account data and records the last processed name as the progress of the
    ``delete_account_data_for_deactivated_users`` background update.

    Returns:
        The number of users processed in this batch.
    """
    txn.execute(
        """ SELECT name FROM users WHERE deactivated = ? and name > ? ORDER BY name ASC LIMIT ? """,
        (1, last_user, batch_size),
    )
    batch_users = [name for (name,) in txn]

    if not batch_users:
        return 0

    for name in batch_users:
        self._purge_account_data_for_user_txn(txn, user_id=name)

    self.db_pool.updates._background_update_progress_txn(
        txn,
        "delete_account_data_for_deactivated_users",
        {"last_user": batch_users[-1]},
    )
    return len(batch_users)
def get_all_push_rule_updates_txn(
    txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
    """Fetch push-rule stream rows with stream IDs in (last_id, current_id].

    ``last_id``, ``current_id`` and ``limit`` come from the enclosing scope.

    Returns:
        ``(updates, upper_bound, limited)``. Each update is
        ``(stream_id, (user_id,))``. When exactly ``limit`` rows come back,
        the result is flagged as limited and bounded by the last stream ID.
    """
    txn.execute(
        """ SELECT stream_id, user_id FROM push_rules_stream WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC LIMIT ? """,
        (last_id, current_id, limit),
    )
    updates = cast(
        List[Tuple[int, Tuple[str]]],
        [(row_stream_id, (row_user_id,)) for row_stream_id, row_user_id in txn],
    )

    if len(updates) == limit:
        return updates, updates[-1][0], True
    return updates, current_id, False
def _create_e2e_room_keys_version_txn(txn: LoggingTransaction) -> str:
    """Create the next room-keys backup version for ``user_id``.

    The new version is one more than the user's current highest version,
    or 1 if the user has none. ``user_id`` and ``info`` come from the
    enclosing scope.

    Returns:
        The new version number, as a string.
    """
    txn.execute(
        "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?",
        (user_id,),
    )
    latest = cast(Tuple[Optional[int]], txn.fetchone())[0]
    new_version = 1 if latest is None else latest + 1

    self.db_pool.simple_insert_txn(
        txn,
        table="e2e_room_keys_versions",
        values={
            "user_id": user_id,
            "version": new_version,
            "algorithm": info["algorithm"],
            "auth_data": json_encoder.encode(info["auth_data"]),
        },
    )
    return str(new_version)
def _get_recent_references_for_event_txn(
    txn: LoggingTransaction,
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
    """Run the relations query and build one page of related events.

    ``sql``, ``where_args``, ``limit``, ``is_redacted`` and ``from_token`` come
    from the enclosing scope. One extra row beyond ``limit`` is requested so
    we can tell whether a further page exists.

    Returns:
        Up to ``limit`` related events, plus a pagination token for the next
        page (or None if there is no further page).
    """
    txn.execute(sql, where_args + [limit + 1])

    related: List[_RelatedEvent] = []
    topo_pos = None
    stream_pos = None
    for row in txn:
        # Edits of a redacted event are dropped, since they would leak the
        # event's content.
        if is_redacted and row[1] == RelationTypes.REPLACE:
            continue
        related.append(_RelatedEvent(row[0], row[2]))
        topo_pos = row[3]
        stream_pos = row[4]

    next_token = None
    # NOTE(review): the key below comes from the last *kept* row, i.e. the
    # extra (limit + 1)-th event, not the last event returned. Whether that
    # event is served again or skipped on the next page depends on the
    # inclusivity of the pagination bound built elsewhere — confirm.
    if len(related) > limit and topo_pos and stream_pos:
        next_key = RoomStreamToken(topo_pos, stream_pos)
        if from_token:
            next_token = from_token.copy_and_replace(StreamKeyType.ROOM, next_key)
        else:
            next_token = StreamToken(
                room_key=next_key,
                presence_key=0,
                typing_key=0,
                receipt_key=0,
                account_data_key=0,
                push_rules_key=0,
                to_device_key=0,
                device_list_key=0,
                groups_key=0,
            )

    return related[:limit], next_token
def _add_push_rule_highest_priority_txn(
    self,
    txn: LoggingTransaction,
    stream_id: int,
    event_stream_ordering: int,
    user_id: str,
    rule_id: str,
    priority_class: int,
    conditions_json: str,
    actions_json: str,
) -> None:
    """Upsert a push rule above every existing rule in its priority class.

    The new priority is one more than the current maximum in the class, or 0
    if the class has no rules yet.
    """
    # Lock the table, otherwise the SELECT below could race with the UPSERT.
    self.database_engine.lock_table(txn, "push_rules")

    txn.execute(
        "SELECT COUNT(*), MAX(priority) FROM push_rules"
        " WHERE user_name = ? and priority_class = ?",
        (user_id, priority_class),
    )
    rule_count, top_priority = txn.fetchall()[0]
    new_prio = top_priority + 1 if rule_count > 0 else 0

    self._upsert_push_rule_txn(
        txn,
        stream_id,
        event_stream_ordering,
        user_id,
        rule_id,
        priority_class,
        new_prio,
        conditions_json,
        actions_json,
    )
def _graph_to_linear(
    self, txn: LoggingTransaction, room_id: str, event_ids: List[str]
) -> str:
    """
    Generate a linearized event from a list of events (i.e. a list of forward
    extremities in the room).

    This should allow for calculation of the correct read receipt even if
    servers have different event ordering.

    Picks the event with the highest ``stream_ordering`` among ``event_ids``,
    provided it is in ``room_id``.

    Args:
        txn: The transaction
        room_id: The room ID the events are in.
        event_ids: The list of event IDs to linearize.

    Returns:
        The linearized event ID.

    Raises:
        RuntimeError: if none of ``event_ids`` resolves to an event in the room.
    """
    # TODO: Make this better.
    clause, args = make_in_list_sql_clause(
        self.database_engine, "event_id", event_ids
    )

    # Both the outer query and the sub-select need `FROM events`. Without it
    # the statement is not valid SQL (`event_id` / `stream_ordering` are not
    # bound to any table) and fails on execution.
    sql = """
        SELECT event_id FROM events
        WHERE room_id = ? AND stream_ordering IN (
            SELECT max(stream_ordering) FROM events WHERE %s
        )
    """ % (
        clause,
    )

    txn.execute(sql, [room_id] + list(args))
    rows = txn.fetchall()
    if rows:
        return rows[0][0]
    raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
    """Fetch ``user_id``'s membership events between ``from_key`` and ``to_key``.

    Tokens may carry a non-empty instance_map, so the SQL uses the widest
    stream bounds the tokens allow. Rows are then filtered precisely with
    ``_filter_results``.
    """
    lower_bound = from_key.stream
    upper_bound = to_key.get_max_stream_pos()

    txn.execute(
        """ SELECT m.event_id, instance_name, topological_ordering, stream_ordering FROM events AS e, room_memberships AS m WHERE e.event_id = m.event_id AND m.user_id = ? AND e.stream_ordering > ? AND e.stream_ordering <= ? ORDER BY e.stream_ordering ASC """,
        (user_id, lower_bound, upper_bound),
    )

    matches = []
    for event_id, instance_name, topological_ordering, stream_ordering in txn:
        if _filter_results(
            from_key,
            to_key,
            instance_name,
            topological_ordering,
            stream_ordering,
        ):
            matches.append(_EventDictReturn(event_id, None, stream_ordering))
    return matches
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
    """Fetch linearized receipts for ``room_ids`` up to ``to_key``.

    When ``from_key`` is truthy, only receipts with stream_id in
    (from_key, to_key] are returned. Otherwise every receipt with
    stream_id <= to_key is returned.
    """
    room_clause, room_args = make_in_list_sql_clause(
        self.database_engine, "room_id", room_ids
    )

    if from_key:
        prefix = """ SELECT * FROM receipts_linearized WHERE stream_id > ? AND stream_id <= ? AND """
        bounds = [from_key, to_key]
    else:
        prefix = """ SELECT * FROM receipts_linearized WHERE stream_id <= ? AND """
        bounds = [to_key]

    txn.execute(prefix + room_clause, bounds + list(room_args))
    return self.db_pool.cursor_to_dict(txn)
def _delete_pushers(txn: LoggingTransaction) -> int:
    """Delete one batch of email pushers whose address is no longer a threepid.

    Looks at ``m.email`` pushers with id greater than ``last_pusher`` (up to
    ``batch_size`` of them, both from the enclosing scope) that have no
    matching email entry in ``user_threepids``, and deletes them. The last
    processed pusher ID is recorded as the progress of the
    ``remove_deleted_email_pushers`` background update.

    Returns:
        The number of pushers deleted.
    """
    txn.execute(
        """ SELECT p.id, p.user_name, p.app_id, p.pushkey FROM pushers AS p LEFT JOIN user_threepids AS t ON t.user_id = p.user_name AND t.medium = 'email' AND t.address = p.pushkey WHERE t.user_id is NULL AND p.app_id = 'm.email' AND p.id > ? ORDER BY p.id ASC LIMIT ? """,
        (last_pusher, batch_size),
    )
    stale_pushers = txn.fetchall()

    for _pusher_id, user_name, app_id, pushkey in stale_pushers:
        self.db_pool.simple_delete_txn(
            txn,
            "pushers",
            {"user_name": user_name, "app_id": app_id, "pushkey": pushkey},
        )

    # `p.id > ?` never matches a NULL id, so a non-empty batch always has a
    # usable last id.
    if stale_pushers:
        self.db_pool.updates._background_update_progress_txn(
            txn, "remove_deleted_email_pushers", {"last_pusher": stale_pushers[-1][0]}
        )

    return len(stale_pushers)
def insert(txn: LoggingTransaction) -> None:
    """Copy a chunk of ``event_search`` rows into PostgreSQL.

    ``rows`` and ``headers`` (from the enclosing scope) describe the source
    rows. Rows whose ``value`` contains a NUL character are logged and
    skipped. After inserting, the ``port_from_sqlite3`` bookkeeping row for
    ``event_search`` is updated with ``forward_chunk`` / ``backward_chunk``.
    """
    sql = (
        "INSERT INTO event_search (event_id, room_id, key,"
        " sender, vector, origin_server_ts, stream_ordering)"
        " VALUES (?,?,?,?,to_tsvector('english', ?),?,?)"
    )

    kept = []
    for raw in rows:
        record = dict(zip(headers, raw))
        if "\0" in record["value"]:
            logger.warning("dropping search row %s", record)
            continue
        kept.append(record)

    columns = (
        "event_id",
        "room_id",
        "key",
        "sender",
        "value",
        "origin_server_ts",
        "stream_ordering",
    )
    txn.executemany(sql, [tuple(record[col] for col in columns) for record in kept])

    self.postgres_store.db_pool.simple_update_one_txn(
        txn,
        table="port_from_sqlite3",
        keyvalues={"table_name": "event_search"},
        updatevalues={
            "forward_rowid": forward_chunk,
            "backward_rowid": backward_chunk,
        },
    )
def _set_destination_retry_timings_native(
    self,
    txn: LoggingTransaction,
    destination: str,
    failure_ts: Optional[int],
    retry_last_ts: int,
    retry_interval: int,
) -> None:
    """Upsert retry timings for ``destination`` with a native UPSERT.

    An existing row is overwritten only when:
    - the new interval is 0 (a reset), or
    - the stored interval is NULL, or
    - the stored interval is smaller than the new one.

    The ``get_destination_retry_timings`` cache entry is then invalidated and
    streamed.

    WARNING: This is executed in autocommit, so we shouldn't add any more
    SQL calls in here (without being very careful).
    """
    assert self.database_engine.can_native_upsert

    upsert = """ INSERT INTO destinations ( destination, failure_ts, retry_last_ts, retry_interval ) VALUES (?, ?, ?, ?) ON CONFLICT (destination) DO UPDATE SET failure_ts = EXCLUDED.failure_ts, retry_last_ts = EXCLUDED.retry_last_ts, retry_interval = EXCLUDED.retry_interval WHERE EXCLUDED.retry_interval = 0 OR destinations.retry_interval IS NULL OR destinations.retry_interval < EXCLUDED.retry_interval """
    params = (destination, failure_ts, retry_last_ts, retry_interval)
    txn.execute(upsert, params)

    self._invalidate_cache_and_stream(
        txn, self.get_destination_retry_timings, (destination,)
    )
def remove_old_push_actions_that_have_rotated_txn(
    txn: LoggingTransaction,
) -> bool:
    """Delete one batch of non-highlight push actions below the rotation point.

    Only rows with ``highlight = 0`` and a stream ordering below
    ``max_stream_ordering_to_delete`` (from the enclosing scope) are deleted,
    roughly ``batch_size`` rows per call.

    Returns:
        True when there is nothing more to delete, False if the caller
        should run another batch.
    """
    # We don't want to clear out too much at a time, so we bound our
    # deletes.
    batch_size = 10000

    # Find the stream ordering `batch_size` rows into the deletable range.
    txn.execute(
        """
        SELECT stream_ordering FROM event_push_actions
        WHERE stream_ordering < ? AND highlight = 0
        ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
        """,
        (max_stream_ordering_to_delete, batch_size),
    )
    stream_row = txn.fetchone()

    if stream_row:
        (stream_ordering,) = stream_row
        # Use an inclusive bound. Several rows can share one stream ordering
        # (one row per user/room). If more than `batch_size` rows share the
        # smallest ordering, an exclusive bound would delete nothing, report
        # completion, and stall pruning forever. `stream_ordering` was read
        # from a row strictly below `max_stream_ordering_to_delete`, so this
        # still respects that limit.
        txn.execute(
            """
            DELETE FROM event_push_actions
            WHERE stream_ordering <= ? AND highlight = 0
            """,
            (stream_ordering,),
        )
    else:
        # At most `batch_size` rows remain: delete everything below the limit.
        txn.execute(
            """
            DELETE FROM event_push_actions
            WHERE stream_ordering < ? AND highlight = 0
            """,
            (max_stream_ordering_to_delete,),
        )

    logger.info("Rotating notifications, deleted %s push actions", txn.rowcount)

    return txn.rowcount < batch_size
def _store_destination_rooms_entries_txn(
    self,
    txn: LoggingTransaction,
    destinations: Iterable[str],
    room_id: str,
    stream_ordering: int,
) -> None:
    """Record ``stream_ordering`` as the latest position for each destination in a room.

    First makes sure a ``destinations`` row exists for every destination,
    because ``destination_rooms`` has a foreign key to it. Then upserts a
    ``destination_rooms`` row per (destination, room_id).

    Raises:
        RuntimeError: if the database engine is neither PostgreSQL nor SQLite.
    """
    # `destinations` is iterated twice below. Materialise it first: a one-shot
    # iterator (e.g. a generator) would be exhausted by the first pass, and no
    # destination_rooms rows would be written.
    destinations = list(destinations)

    # ensure we have a `destinations` row for this destination, as there is
    # a foreign key constraint.
    if isinstance(self.database_engine, PostgresEngine):
        q = """
            INSERT INTO destinations (destination)
                VALUES (?)
                ON CONFLICT DO NOTHING;
        """
    elif isinstance(self.database_engine, Sqlite3Engine):
        q = """
            INSERT OR IGNORE INTO destinations (destination)
                VALUES (?);
        """
    else:
        raise RuntimeError("Unknown database engine")

    txn.execute_batch(q, ((destination,) for destination in destinations))

    rows = [(destination, room_id) for destination in destinations]

    self.db_pool.simple_upsert_many_txn(
        txn,
        "destination_rooms",
        ["destination", "room_id"],
        rows,
        ["stream_ordering"],
        [(stream_ordering,)] * len(rows),
    )
def _update_remote_device_list_cache_entry_txn(
    self,
    txn: LoggingTransaction,
    user_id: str,
    device_id: str,
    content: JsonDict,
    stream_id: int,
) -> None:
    """Apply one remote device update to the local device-list cache.

    If ``content`` is marked ``deleted``, the cached device is removed.
    Otherwise its content is upserted. In both cases the related caches are
    invalidated and the user's remote extremity is set to ``stream_id``.
    """
    if not content.get("deleted"):
        self.db_pool.simple_upsert_txn(
            txn,
            table="device_lists_remote_cache",
            keyvalues={"user_id": user_id, "device_id": device_id},
            values={"content": json_encoder.encode(content)},
            # No lock needed: we assume this is the only thread updating this
            # user's devices.
            lock=False,
        )
    else:
        self.db_pool.simple_delete_txn(
            txn,
            table="device_lists_remote_cache",
            keyvalues={"user_id": user_id, "device_id": device_id},
        )
        txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))

    txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id))
    txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
    txn.call_after(
        self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
    )

    self.db_pool.simple_upsert_txn(
        txn,
        table="device_lists_remote_extremeties",
        keyvalues={"user_id": user_id},
        values={"stream_id": stream_id},
        # Same assumption: only one thread updates this user's extremity.
        lock=False,
    )
def _do_txn(txn: LoggingTransaction) -> int:
    """Return the filter ID for ``def_json``, creating the filter if needed.

    If ``user_localpart`` already has a filter with identical JSON, its ID is
    reused. Otherwise a new filter is stored with ID one greater than the
    user's current maximum (or 0 for the first filter).
    """
    encoded_filter = bytearray(def_json)

    txn.execute(
        "SELECT filter_id FROM user_filters "
        "WHERE user_id = ? AND filter_json = ?",
        (user_localpart, encoded_filter),
    )
    existing = txn.fetchone()
    if existing is not None:
        return existing[0]

    txn.execute(
        "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?",
        (user_localpart,),
    )
    max_id = txn.fetchone()[0]  # type: ignore[index]
    filter_id = 0 if max_id is None else max_id + 1

    txn.execute(
        "INSERT INTO user_filters (user_id, filter_id, filter_json)"
        "VALUES(?, ?, ?)",
        (user_localpart, filter_id, encoded_filter),
    )
    return filter_id
def _update_presence_txn(
    self,
    txn: LoggingTransaction,
    stream_orderings: List[int],
    presence_states: List[UserPresenceState],
) -> None:
    """Persist a batch of presence states, each paired with a stream ordering.

    Steps:
    1. Schedule cache invalidations for every user.
    2. Prune those users' older ``presence_stream`` rows.
    3. Insert one new row per state.
    """
    # After this loop `final_stream_id` holds the stream ordering paired with
    # the last state. It is used below as the pruning cut-off.
    for final_stream_id, state in zip(stream_orderings, presence_states):
        txn.call_after(
            self.presence_stream_cache.entity_has_changed,
            state.user_id,
            final_stream_id,
        )
        txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))

    # Delete old rows to stop database from getting really big
    prune_sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
    for chunk in batch_iter(presence_states, 50):
        user_clause, user_args = make_in_list_sql_clause(
            self.database_engine, "user_id", [s.user_id for s in chunk]
        )
        txn.execute(prune_sql + user_clause, [final_stream_id] + list(user_args))

    # Actually insert new rows
    new_rows = []
    for position, state in zip(stream_orderings, presence_states):
        new_rows.append(
            (
                position,
                state.user_id,
                state.state,
                state.last_active_ts,
                state.last_federation_update_ts,
                state.last_user_sync_ts,
                state.status_msg,
                state.currently_active,
                self._instance_name,
            )
        )

    self.db_pool.simple_insert_many_txn(
        txn,
        table="presence_stream",
        keys=(
            "stream_id",
            "user_id",
            "state",
            "last_active_ts",
            "last_federation_update_ts",
            "last_user_sync_ts",
            "status_msg",
            "currently_active",
            "instance_name",
        ),
        values=new_rows,
    )
def _insert(txn: LoggingTransaction) -> None:
    """Insert a ``foobar`` row at ``stream_id`` and advance the stream state.

    Sets ``foobar_seq`` to ``stream_id`` and upserts the ``test_stream``
    position for ``instance_name``. Both names come from the enclosing scope.
    """
    txn.execute("INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name))
    txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))

    upsert_position = """ INSERT INTO stream_positions VALUES ('test_stream', ?, ?) ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ? """
    txn.execute(upsert_position, (instance_name, stream_id, stream_id))
def _remove_old_push_actions_before_txn(
    self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
) -> None:
    """
    Purges old push actions for a user and room up to a given stream_ordering.

    Highlighted actions are kept unless they are also older than
    ``self.stream_ordering_month_ago``, so users can still list recent
    highlights. Matching ``event_push_summary`` rows are deleted as well.

    Args:
        txn: The transaction
        room_id: Room ID to delete from
        user_id: user ID to delete for
        stream_ordering: Rows with a stream ordering at or below this value
            are deleted (subject to the highlight rule above).
    """
    txn.call_after(
        self.get_unread_event_push_actions_by_room_for_user.invalidate,
        (room_id, user_id),
    )

    # Delete non-highlight actions up to `stream_ordering`. Highlight actions
    # in that range are deleted only if they also fall below the month-ago
    # stream ordering.
    txn.execute(
        "DELETE FROM event_push_actions "
        " WHERE user_id = ? AND room_id = ? AND "
        " stream_ordering <= ?"
        " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
        (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
    )

    summary_params = (room_id, user_id, stream_ordering)
    txn.execute(
        """ DELETE FROM event_push_summary WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? """,
        summary_params,
    )
def _setup_db(self, txn: LoggingTransaction) -> None:
    """Create ``foobar_seq`` plus the identical tables ``foobar1`` and ``foobar2``."""
    txn.execute("CREATE SEQUENCE foobar_seq")

    for table_name in ("foobar1", "foobar2"):
        txn.execute(
            """ CREATE TABLE %s ( stream_id BIGINT NOT NULL, instance_name TEXT NOT NULL, data TEXT ); """
            % (table_name,)
        )
def _get_unread_counts_by_pos_txn(
    self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
) -> NotifCounts:
    """Count notifications, highlights and unread events after ``stream_ordering``.

    Counts come from ``event_push_actions`` rows after the position. The
    notification and unread totals from a matching ``event_push_summary`` row
    (if any) are added on top.
    """
    txn.execute(
        "SELECT"
        " COUNT(CASE WHEN notif = 1 THEN 1 END),"
        " COUNT(CASE WHEN highlight = 1 THEN 1 END),"
        " COUNT(CASE WHEN unread = 1 THEN 1 END)"
        " FROM event_push_actions ea"
        " WHERE user_id = ?"
        " AND room_id = ?"
        " AND stream_ordering > ?",
        (user_id, room_id, stream_ordering),
    )
    action_counts = txn.fetchone()
    notif_count, highlight_count, unread_count = (
        action_counts if action_counts else (0, 0, 0)
    )

    txn.execute(
        """ SELECT notif_count, unread_count FROM event_push_summary WHERE room_id = ? AND user_id = ? AND stream_ordering > ? """,
        (room_id, user_id, stream_ordering),
    )
    summary = txn.fetchone()
    if summary:
        summary_notifs, summary_unread = summary
        notif_count += summary_notifs
        # event_push_summary.unread_count is NULLable; only add it when set.
        if summary_unread is not None:
            unread_count += summary_unread

    return NotifCounts(
        notify_count=notif_count,
        unread_count=unread_count,
        highlight_count=highlight_count,
    )
def _update_remote_device_list_cache_txn(
    self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
) -> None:
    """Replace the whole cached device list of a remote user.

    Steps:
    1. Delete all cached devices for ``user_id`` and insert ``devices``.
    2. Invalidate the related caches.
    3. Set the user's remote extremity to ``stream_id``.
    4. Clear any pending resync marker.
    """
    self.db_pool.simple_delete_txn(
        txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
    )

    cache_rows = []
    for device in devices:
        cache_rows.append(
            {
                "user_id": user_id,
                "device_id": device["device_id"],
                "content": json_encoder.encode(device),
            }
        )
    self.db_pool.simple_insert_many_txn(
        txn, table="device_lists_remote_cache", values=cache_rows
    )

    txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
    txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,))
    txn.call_after(
        self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
    )

    self.db_pool.simple_upsert_txn(
        txn,
        table="device_lists_remote_extremeties",
        keyvalues={"user_id": user_id},
        values={"stream_id": stream_id},
        # No lock needed: we assume this is the only thread updating this
        # user's extremity.
        lock=False,
    )

    # Replacing the whole cached list presumably means a full resync just
    # happened, so drop the entry saying this user needs one.
    self.db_pool.simple_delete_txn(
        txn,
        table="device_lists_remote_resync",
        keyvalues={"user_id": user_id},
    )
def _get_next_batch(txn: LoggingTransaction) -> Optional[List[str]]:
    """Take the next ``batch_size`` user IDs from the temporary users table.

    Also stores the number of users still in that table in
    ``progress["remaining"]``.

    Returns:
        The user IDs to process, or None when the table is empty.
    """
    users_table = TEMP_TABLE + "_users"

    txn.execute("SELECT user_id FROM %s LIMIT %s" % (users_table, str(batch_size)))
    batch = cast(List[Tuple[str]], txn.fetchall())
    if not batch:
        return None

    users_to_work_on = [uid for (uid,) in batch]

    # Record how many users are left, so progress can be reported.
    txn.execute("SELECT COUNT(*) FROM " + users_table)
    remaining = txn.fetchone()
    assert remaining is not None
    progress["remaining"] = remaining[0]

    return users_to_work_on
def get_url_cache_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
    """Look up the cached URL preview for ``url`` closest to ``ts``.

    Prefers the newest entry downloaded at or before ``ts``. If none exists,
    falls back to the oldest entry downloaded after ``ts``.

    Returns:
        The entry as a dict keyed by column name, or None if ``url`` has no
        cached entry at all.
    """
    columns = ("response_code", "etag", "expires_ts", "og", "media_id", "download_ts")
    select_from = (
        "SELECT " + ", ".join(columns) + " FROM local_media_repository_url_cache"
    )

    txn.execute(
        select_from
        + " WHERE url = ? AND download_ts <= ?"
        + " ORDER BY download_ts DESC LIMIT 1",
        (url, ts),
    )
    row = txn.fetchone()

    if not row:
        # The requested timestamp is older than every cached copy, so return
        # the oldest copy instead (if any).
        txn.execute(
            select_from
            + " WHERE url = ? AND download_ts > ?"
            + " ORDER BY download_ts ASC LIMIT 1",
            (url, ts),
        )
        row = txn.fetchone()

    if not row:
        return None
    return dict(zip(columns, row))