def start_replication(self, hs):
    """Helper method to start a replication connection to the remote server
    using TCP.

    Args:
        hs: The HomeServer, used to read the redis/worker replication config
            and to obtain the reactor, streamer and outbound redis connection.
    """
    if hs.config.redis.redis_enabled:
        from synapse.replication.tcp.redis import (
            RedisDirectTcpReplicationClientFactory,
        )

        # First let's ensure that we have a ReplicationStreamer started.
        hs.get_replication_streamer()

        # We need two connections to redis, one for the subscription stream and
        # one to send commands to (as you can't send further redis commands to a
        # connection after SUBSCRIBE is called).

        # First create the connection for sending commands.
        outbound_redis_connection = hs.get_outbound_redis_connection()

        # Now create the factory/connection for the subscription stream.
        self._factory = RedisDirectTcpReplicationClientFactory(
            hs, outbound_redis_connection
        )
        hs.get_reactor().connectTCP(
            # Fix: connectTCP expects the hostname as a str, not bytes;
            # encoding it breaks hostname handling (the other code paths in
            # this file all pass the plain string).
            hs.config.redis.redis_host,
            hs.config.redis.redis_port,
            self._factory,
        )
    else:
        client_name = hs.get_instance_name()
        self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
        host = hs.config.worker_replication_host
        port = hs.config.worker_replication_port
        # Same fix as above: pass the host as str rather than bytes.
        hs.get_reactor().connectTCP(host, port, self._factory)
def prepare(self, reactor, clock, hs):
    """Set up a master store, a slaved store, and an in-memory replication
    connection between them for testing.

    Builds a replication server from the homeserver, a client whose data
    handler is backed by the slaved store, and wires the two together with
    ``FakeTransport`` so commands flow without real networking.

    Args:
        reactor: the (fake) twisted reactor used for the transports.
        clock: test clock (unused here directly).
        hs: the homeserver under test.
    """
    db_config = hs.config.database.get_single_database()
    self.master_store = self.hs.get_datastore()
    self.storage = hs.get_storage()
    database = hs.get_datastores().databases[0]
    # The slaved store gets its own DB connection to the same database.
    self.slaved_store = self.STORE_TYPE(
        database, make_conn(db_config, database.engine), self.hs
    )
    self.event_id = 0

    server_factory = ReplicationStreamProtocolFactory(self.hs)
    self.streamer = hs.get_replication_streamer()

    # We now do some gut wrenching so that we have a client that is based
    # off of the slave store rather than the main store.
    self.replication_handler = ReplicationCommandHandler(self.hs)
    self.replication_handler._instance_name = "worker"
    self.replication_handler._replication_data_handler = ReplicationDataHandler(
        self.slaved_store
    )

    client_factory = DirectTcpReplicationClientFactory(
        self.hs, "client_name", self.replication_handler
    )
    client_factory.handler = self.replication_handler

    server = server_factory.buildProtocol(None)
    client = client_factory.buildProtocol(None)

    # Wire client -> server first, then server -> client; keep a handle on
    # the server-to-client transport so tests can disconnect/flush it.
    client.makeConnection(FakeTransport(server, reactor))

    self.server_to_client_transport = FakeTransport(client, reactor)
    server.makeConnection(self.server_to_client_transport)
def start_replication(self, hs):
    """Helper method to start a replication connection to the remote server
    using TCP.

    Args:
        hs: The HomeServer, used to read the redis/worker replication config
            and to obtain the reactor.
    """
    if hs.config.redis.redis_enabled:
        from synapse.replication.tcp.redis import (
            RedisDirectTcpReplicationClientFactory,
        )
        import txredisapi

        # Consistency fix: access the redis settings through the redis config
        # section (`hs.config.redis.*`) everywhere, matching the accesses for
        # `redis_enabled`, `redis_password` and the connectTCP call below,
        # rather than the legacy root-level forwarded attributes.
        logger.info(
            "Connecting to redis (host=%r port=%r DBID=%r)",
            hs.config.redis.redis_host,
            hs.config.redis.redis_port,
            hs.config.redis.redis_dbid,
        )

        # We need two connections to redis, one for the subscription stream and
        # one to send commands to (as you can't send further redis commands to a
        # connection after SUBSCRIBE is called).

        # First create the connection for sending commands.
        outbound_redis_connection = txredisapi.lazyConnection(
            host=hs.config.redis.redis_host,
            port=hs.config.redis.redis_port,
            dbid=hs.config.redis.redis_dbid,
            password=hs.config.redis.redis_password,
            reconnect=True,
        )

        # Now create the factory/connection for the subscription stream.
        self._factory = RedisDirectTcpReplicationClientFactory(
            hs, outbound_redis_connection
        )
        hs.get_reactor().connectTCP(
            hs.config.redis.redis_host,
            hs.config.redis.redis_port,
            self._factory,
        )
    else:
        client_name = hs.get_instance_name()
        self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
        host = hs.config.worker_replication_host
        port = hs.config.worker_replication_port
        hs.get_reactor().connectTCP(host, port, self._factory)
def start_replication(self, hs: "HomeServer") -> None:
    """Helper method to start replication."""
    if not hs.config.redis.redis_enabled:
        # Plain direct-TCP replication: connect straight to the main process.
        self._factory = DirectTcpReplicationClientFactory(
            hs, hs.get_instance_name(), self
        )
        hs.get_reactor().connectTCP(
            hs.config.worker.worker_replication_host,
            hs.config.worker.worker_replication_port,
            self._factory,
            timeout=30,
            bindAddress=None,
        )
        return

    from synapse.replication.tcp.redis import (
        RedisDirectTcpReplicationClientFactory,
    )

    # Make sure a ReplicationStreamer exists before we connect.
    hs.get_replication_streamer()

    # Redis needs two connections: once SUBSCRIBE has been issued on a
    # connection no further ordinary commands may be sent on it, so
    # commands go over a separate outbound connection.
    command_connection = hs.get_outbound_redis_connection()

    # The factory/connection used for the subscription stream.
    self._factory = RedisDirectTcpReplicationClientFactory(
        hs,
        command_connection,
        channel_names=self._channels_to_subscribe_to,
    )
    hs.get_reactor().connectTCP(
        hs.config.redis.redis_host,
        hs.config.redis.redis_port,
        self._factory,
        timeout=30,
        bindAddress=None,
    )
class ReplicationCommandHandler:
    """Handles incoming commands from replication as well as sending commands
    back out to connections.
    """

    def __init__(self, hs):
        self._replication_data_handler = hs.get_replication_data_handler()
        self._presence_handler = hs.get_presence_handler()
        self._store = hs.get_datastore()
        self._notifier = hs.get_notifier()
        self._clock = hs.get_clock()
        self._instance_id = hs.get_instance_id()
        self._instance_name = hs.get_instance_name()

        # Set of streams that we've caught up with.
        self._streams_connected = set()  # type: Set[str]

        self._streams = {
            stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
        }  # type: Dict[str, Stream]

        # Serializes handling of RDATA/POSITION per stream; see on_RDATA /
        # on_POSITION for why this matters.
        self._position_linearizer = Linearizer(
            "replication_position", clock=self._clock
        )

        # Map of stream to batched updates. See RdataCommand for info on how
        # batching works.
        self._pending_batches = {}  # type: Dict[str, List[Any]]

        # The factory used to create connections.
        self._factory = None  # type: Optional[ReconnectingClientFactory]

        # The currently connected connections. (The list of places we need to send
        # outgoing replication commands to.)
        self._connections = []  # type: List[AbstractConnection]

        LaterGauge(
            "synapse_replication_tcp_resource_total_connections",
            "",
            [],
            lambda: len(self._connections),
        )

        # Master has no worker_app configured.
        self._is_master = hs.config.worker_app is None

        self._federation_sender = None
        if self._is_master and not hs.config.send_federation:
            self._federation_sender = hs.get_federation_sender()

        self._server_notices_sender = None
        if self._is_master:
            self._server_notices_sender = hs.get_server_notices_sender()

    def start_replication(self, hs):
        """Helper method to start a replication connection to the remote server
        using TCP.
        """
        if hs.config.redis.redis_enabled:
            from synapse.replication.tcp.redis import (
                RedisDirectTcpReplicationClientFactory,
            )
            import txredisapi

            logger.info(
                "Connecting to redis (host=%r port=%r)",
                hs.config.redis_host,
                hs.config.redis_port,
            )

            # First let's ensure that we have a ReplicationStreamer started.
            hs.get_replication_streamer()

            # We need two connections to redis, one for the subscription stream and
            # one to send commands to (as you can't send further redis commands to a
            # connection after SUBSCRIBE is called).

            # First create the connection for sending commands.
            outbound_redis_connection = txredisapi.lazyConnection(
                host=hs.config.redis_host,
                port=hs.config.redis_port,
                password=hs.config.redis.redis_password,
                reconnect=True,
            )

            # Now create the factory/connection for the subscription stream.
            self._factory = RedisDirectTcpReplicationClientFactory(
                hs, outbound_redis_connection
            )
            hs.get_reactor().connectTCP(
                hs.config.redis.redis_host,
                hs.config.redis.redis_port,
                self._factory,
            )
        else:
            client_name = hs.get_instance_name()
            self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
            host = hs.config.worker_replication_host
            port = hs.config.worker_replication_port
            hs.get_reactor().connectTCP(host, port, self._factory)

    async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
        """Respond to a REPLICATE command with the current positions of all
        streams (only when we are the writer, i.e. the master process)."""
        # We only want to announce positions by the writer of the streams.
        # Currently this is just the master process.
        if not self._is_master:
            return

        for stream_name, stream in self._streams.items():
            current_token = stream.current_token()
            self.send_command(
                PositionCommand(stream_name, self._instance_name, current_token)
            )

    async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
        """A worker told us a user started/stopped syncing; update presence
        (master only)."""
        user_sync_counter.inc()

        if self._is_master:
            await self._presence_handler.update_external_syncs_row(
                cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
            )

    async def on_CLEAR_USER_SYNC(
        self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
    ):
        """Clear all external syncs for a disconnected instance (master only)."""
        if self._is_master:
            await self._presence_handler.update_external_syncs_clear(cmd.instance_id)

    async def on_FEDERATION_ACK(
        self, conn: AbstractConnection, cmd: FederationAckCommand
    ):
        """A federation sender acknowledged data up to the given token."""
        federation_ack_counter.inc()

        if self._federation_sender:
            self._federation_sender.federation_ack(cmd.token)

    async def on_REMOVE_PUSHER(
        self, conn: AbstractConnection, cmd: RemovePusherCommand
    ):
        """Delete a pusher on behalf of a worker (master only)."""
        remove_pusher_counter.inc()

        if self._is_master:
            await self._store.delete_pusher_by_app_id_pushkey_user_id(
                app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
            )

            self._notifier.on_new_replication_data()

    async def on_INVALIDATE_CACHE(
        self, conn: AbstractConnection, cmd: InvalidateCacheCommand
    ):
        """Invalidate a named cache on behalf of a worker (master only)."""
        invalidate_cache_counter.inc()

        if self._is_master:
            # We invalidate the cache locally, but then also stream that to other
            # workers.
            await self._store.invalidate_cache_and_stream(
                cmd.cache_func, tuple(cmd.keys)
            )

    async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
        """Record a client IP seen by a worker (master stores it; server
        notices sender is told regardless, if configured)."""
        user_ip_cache_counter.inc()

        if self._is_master:
            await self._store.insert_client_ip(
                cmd.user_id,
                cmd.access_token,
                cmd.ip,
                cmd.user_agent,
                cmd.device_id,
                cmd.last_seen,
            )

        if self._server_notices_sender:
            await self._server_notices_sender.on_user_ip(cmd.user_id)

    async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
        """Handle a single row of replication data, batching rows with a
        `None` token until the final row of the batch arrives."""
        if cmd.instance_name == self._instance_name:
            # Ignore RDATA that are just our own echoes
            return

        stream_name = cmd.stream_name
        inbound_rdata_count.labels(stream_name).inc()

        try:
            row = STREAMS_MAP[stream_name].parse_row(cmd.row)
        except Exception:
            logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
            raise

        # We linearize here for two reasons:
        #   1. so we don't try and concurrently handle multiple rows for the
        #      same stream, and
        #   2. so we don't race with getting a POSITION command and fetching
        #      missing RDATA.
        with await self._position_linearizer.queue(cmd.stream_name):
            if stream_name not in self._streams_connected:
                # If the stream isn't marked as connected then we haven't seen a
                # `POSITION` command yet, and so we may have missed some rows.
                # Let's drop the row for now, on the assumption we'll receive a
                # `POSITION` soon and we'll catch up correctly then.
                logger.debug(
                    "Discarding RDATA for unconnected stream %s -> %s",
                    stream_name,
                    cmd.token,
                )
                return

            if cmd.token is None:
                # I.e. this is part of a batch of updates for this stream (in
                # which case batch until we get an update for the stream with a non
                # None token).
                self._pending_batches.setdefault(stream_name, []).append(row)
            else:
                # Check if this is the last of a batch of updates
                rows = self._pending_batches.pop(stream_name, [])
                rows.append(row)
                await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)

    async def on_rdata(
        self, stream_name: str, instance_name: str, token: int, rows: list
    ):
        """Called to handle a batch of replication data with a given stream token.

        Args:
            stream_name: name of the replication stream for this batch of rows
            instance_name: the instance that wrote the rows.
            token: stream token for this batch of rows
            rows: a list of Stream.ROW_TYPE objects as returned by
                Stream.parse_row.
        """
        logger.debug("Received rdata %s -> %s", stream_name, token)
        await self._replication_data_handler.on_rdata(
            stream_name, instance_name, token, rows
        )

    async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
        """Handle a POSITION command: catch up with the stream by fetching any
        updates between our current token and the advertised one, then mark
        the stream as connected."""
        if cmd.instance_name == self._instance_name:
            # Ignore POSITION that are just our own echoes
            return

        stream = self._streams.get(cmd.stream_name)
        if not stream:
            logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
            return

        # We protect catching up with a linearizer in case the replication
        # connection reconnects under us.
        with await self._position_linearizer.queue(cmd.stream_name):
            # We're about to go and catch up with the stream, so remove from set
            # of connected streams.
            self._streams_connected.discard(cmd.stream_name)

            # We clear the pending batches for the stream as the fetching of the
            # missing updates below will fetch all rows in the batch.
            self._pending_batches.pop(cmd.stream_name, [])

            # Find where we previously streamed up to.
            current_token = stream.current_token()

            # If the position token matches our current token then we're up to
            # date and there's nothing to do. Otherwise, fetch all updates
            # between then and now.
            missing_updates = cmd.token != current_token
            while missing_updates:
                (
                    updates,
                    current_token,
                    missing_updates,
                ) = await stream.get_updates_since(
                    cmd.instance_name, current_token, cmd.token
                )

                # TODO: add some tests for this

                # Some streams return multiple rows with the same stream IDs,
                # which need to be processed in batches.
                for token, rows in _batch_updates(updates):
                    await self.on_rdata(
                        cmd.stream_name,
                        cmd.instance_name,
                        token,
                        [stream.parse_row(row) for row in rows],
                    )

            # We've now caught up to position sent to us, notify handler.
            await self._replication_data_handler.on_position(
                cmd.stream_name, cmd.token
            )

            self._streams_connected.add(cmd.stream_name)

    async def on_REMOTE_SERVER_UP(
        self, conn: AbstractConnection, cmd: RemoteServerUpCommand
    ):
        """Called when we get a new REMOTE_SERVER_UP command."""
        self._replication_data_handler.on_remote_server_up(cmd.data)

        self._notifier.notify_remote_server_up(cmd.data)

        # We relay to all other connections to ensure every instance gets the
        # notification.
        #
        # When configured to use redis we'll always only have one connection and
        # so this is a no-op (all instances will have already received the same
        # REMOTE_SERVER_UP command).
        #
        # For direct TCP connections this will relay to all other connections
        # connected to us. When on master this will correctly fan out to all
        # other direct TCP clients and on workers there'll only be the one
        # connection to master.
        #
        # (The logic here should also be sound if we have a mix of Redis and
        # direct TCP connections so long as there is only one traffic route
        # between two instances, but that is not currently supported).
        self.send_command(cmd, ignore_conn=conn)

    def new_connection(self, connection: AbstractConnection):
        """Called when we have a new connection.
        """
        self._connections.append(connection)

        # If we are connected to replication as a client (rather than a server)
        # we need to reset the reconnection delay on the client factory (which
        # is used to do exponential back off when the connection drops).
        #
        # Ideally we would reset the delay when we've "fully established" the
        # connection (for some definition thereof) to stop us from tightlooping
        # on reconnection if something fails after this point and we drop the
        # connection. Unfortunately, we don't really have a better definition of
        # "fully established" than the connection being established.
        if self._factory:
            self._factory.resetDelay()

        # Tell the other end if we have any users currently syncing.
        currently_syncing = (
            self._presence_handler.get_currently_syncing_users_for_replication()
        )
        now = self._clock.time_msec()
        for user_id in currently_syncing:
            connection.send_command(
                UserSyncCommand(self._instance_id, user_id, True, now)
            )

    def lost_connection(self, connection: AbstractConnection):
        """Called when a connection is closed/lost.
        """
        try:
            self._connections.remove(connection)
        except ValueError:
            # Connection was never registered (or already removed); nothing to do.
            pass

    def connected(self) -> bool:
        """Do we have any replication connections open?

        Is used by e.g. `ReplicationStreamer` to no-op if nothing is connected.
        """
        return bool(self._connections)

    def send_command(
        self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
    ):
        """Send a command to all connected connections.

        Args:
            cmd: the command to send.
            ignore_conn: If set don't send command to the given connection.
                Used when relaying commands from one connection to all others.
        """
        if self._connections:
            for connection in self._connections:
                if connection == ignore_conn:
                    continue

                try:
                    connection.send_command(cmd)
                except Exception:
                    # We probably want to catch some types of exceptions here
                    # and log them as warnings (e.g. connection gone), but I
                    # can't find what those exception types they would be.
                    logger.exception(
                        "Failed to write command %s to connection %s",
                        cmd.NAME,
                        connection,
                    )
        else:
            logger.warning("Dropping command as not connected: %r", cmd.NAME)

    def send_federation_ack(self, token: int):
        """Ack data for the federation stream. This allows the master to drop
        data stored purely in memory.
        """
        self.send_command(FederationAckCommand(token))

    def send_user_sync(
        self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
    ):
        """Poke the master that a user has started/stopped syncing.
        """
        self.send_command(
            UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
        )

    def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
        """Poke the master to remove a pusher for a user
        """
        cmd = RemovePusherCommand(app_id, push_key, user_id)
        self.send_command(cmd)

    def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
        """Poke the master to invalidate a cache.
        """
        cmd = InvalidateCacheCommand(cache_func.__name__, keys)
        self.send_command(cmd)

    def send_user_ip(
        self,
        user_id: str,
        access_token: str,
        ip: str,
        user_agent: str,
        device_id: str,
        last_seen: int,
    ):
        """Tell the master that the user made a request.
        """
        cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
        self.send_command(cmd)

    def send_remote_server_up(self, server: str):
        """Notify all connections that a remote server came back up."""
        self.send_command(RemoteServerUpCommand(server))

    def stream_update(self, stream_name: str, token: str, data: Any):
        """Called when a new update is available to stream to clients.

        We need to check if the client is interested in the stream or not
        """
        self.send_command(RdataCommand(stream_name, self._instance_name, token, data))
class ReplicationCommandHandler: """Handles incoming commands from replication as well as sending commands back out to connections. """ def __init__(self, hs: "HomeServer"): self._replication_data_handler = hs.get_replication_data_handler() self._presence_handler = hs.get_presence_handler() self._store = hs.get_datastores().main self._notifier = hs.get_notifier() self._clock = hs.get_clock() self._instance_id = hs.get_instance_id() self._instance_name = hs.get_instance_name() # Additional Redis channel suffixes to subscribe to. self._channels_to_subscribe_to: List[str] = [] self._is_presence_writer = ( hs.get_instance_name() in hs.config.worker.writers.presence ) self._streams: Dict[str, Stream] = { stream.NAME: stream(hs) for stream in STREAMS_MAP.values() } # List of streams that this instance is the source of self._streams_to_replicate: List[Stream] = [] for stream in self._streams.values(): if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME: # All workers can write to the cache invalidation stream when # using redis. self._streams_to_replicate.append(stream) continue if isinstance(stream, (EventsStream, BackfillStream)): # Only add EventStream and BackfillStream as a source on the # instance in charge of event persistence. if hs.get_instance_name() in hs.config.worker.writers.events: self._streams_to_replicate.append(stream) continue if isinstance(stream, ToDeviceStream): # Only add ToDeviceStream as a source on instances in charge of # sending to device messages. if hs.get_instance_name() in hs.config.worker.writers.to_device: self._streams_to_replicate.append(stream) continue if isinstance(stream, TypingStream): # Only add TypingStream as a source on the instance in charge of # typing. 
if hs.get_instance_name() in hs.config.worker.writers.typing: self._streams_to_replicate.append(stream) continue if isinstance(stream, (AccountDataStream, TagAccountDataStream)): # Only add AccountDataStream and TagAccountDataStream as a source on the # instance in charge of account_data persistence. if hs.get_instance_name() in hs.config.worker.writers.account_data: self._streams_to_replicate.append(stream) continue if isinstance(stream, ReceiptsStream): # Only add ReceiptsStream as a source on the instance in charge of # receipts. if hs.get_instance_name() in hs.config.worker.writers.receipts: self._streams_to_replicate.append(stream) continue if isinstance(stream, (PresenceStream, PresenceFederationStream)): # Only add PresenceStream as a source on the instance in charge # of presence. if self._is_presence_writer: self._streams_to_replicate.append(stream) continue # Only add any other streams if we're on master. if hs.config.worker.worker_app is not None: continue if ( stream.NAME == FederationStream.NAME and hs.config.worker.send_federation ): # We only support federation stream if federation sending # has been disabled on the master. continue self._streams_to_replicate.append(stream) # Map of stream name to batched updates. See RdataCommand for info on # how batching works. self._pending_batches: Dict[str, List[Any]] = {} # The factory used to create connections. self._factory: Optional[ReconnectingClientFactory] = None # The currently connected connections. (The list of places we need to send # outgoing replication commands to.) self._connections: List[IReplicationConnection] = [] LaterGauge( "synapse_replication_tcp_resource_total_connections", "", [], lambda: len(self._connections), ) # When POSITION or RDATA commands arrive, we stick them in a queue and process # them in order in a separate background process. 
# the streams which are currently being processed by _unsafe_process_queue self._processing_streams: Set[str] = set() # for each stream, a queue of commands that are awaiting processing, and the # connection that they arrived on. self._command_queues_by_stream = { stream_name: _StreamCommandQueue() for stream_name in self._streams } # For each connection, the incoming stream names that have received a POSITION # from that connection. self._streams_by_connection: Dict[IReplicationConnection, Set[str]] = {} LaterGauge( "synapse_replication_tcp_command_queue", "Number of inbound RDATA/POSITION commands queued for processing", ["stream_name"], lambda: { (stream_name,): len(queue) for stream_name, queue in self._command_queues_by_stream.items() }, ) self._is_master = hs.config.worker.worker_app is None self._federation_sender = None if self._is_master and not hs.config.worker.send_federation: self._federation_sender = hs.get_federation_sender() self._server_notices_sender = None if self._is_master: self._server_notices_sender = hs.get_server_notices_sender() if hs.config.redis.redis_enabled: # If we're using Redis, it's the background worker that should # receive USER_IP commands and store the relevant client IPs. self._should_insert_client_ips = hs.config.worker.run_background_tasks else: # If we're NOT using Redis, this must be handled by the master self._should_insert_client_ips = hs.get_instance_name() == "master" if self._is_master or self._should_insert_client_ips: self.subscribe_to_channel("USER_IP") def subscribe_to_channel(self, channel_name: str) -> None: """ Indicates that we wish to subscribe to a Redis channel by name. (The name will later be prefixed with the server name; i.e. subscribing to the 'ABC' channel actually subscribes to 'example.com/ABC' Redis-side.) Raises: - If replication has already started, then it's too late to subscribe to new channels. 
""" if self._factory is not None: # We don't allow subscribing after the fact to avoid the chance # of missing an important message because we didn't subscribe in time. raise RuntimeError( "Cannot subscribe to more channels after replication started." ) if channel_name not in self._channels_to_subscribe_to: self._channels_to_subscribe_to.append(channel_name) def _add_command_to_stream_queue( self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand] ) -> None: """Queue the given received command for processing Adds the given command to the per-stream queue, and processes the queue if necessary """ stream_name = cmd.stream_name queue = self._command_queues_by_stream.get(stream_name) if queue is None: logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name) return queue.append((cmd, conn)) # if we're already processing this stream, there's nothing more to do: # the new entry on the queue will get picked up in due course if stream_name in self._processing_streams: return # fire off a background process to start processing the queue. 
run_as_background_process( "process-replication-data", self._unsafe_process_queue, stream_name ) async def _unsafe_process_queue(self, stream_name: str) -> None: """Processes the command queue for the given stream, until it is empty Does not check if there is already a thread processing the queue, hence "unsafe" """ assert stream_name not in self._processing_streams self._processing_streams.add(stream_name) try: queue = self._command_queues_by_stream.get(stream_name) while queue: cmd, conn = queue.popleft() try: await self._process_command(cmd, conn, stream_name) except Exception: logger.exception("Failed to handle command %s", cmd) finally: self._processing_streams.discard(stream_name) async def _process_command( self, cmd: Union[PositionCommand, RdataCommand], conn: IReplicationConnection, stream_name: str, ) -> None: if isinstance(cmd, PositionCommand): await self._process_position(stream_name, conn, cmd) elif isinstance(cmd, RdataCommand): await self._process_rdata(stream_name, conn, cmd) else: # This shouldn't be possible raise Exception("Unrecognised command %s in stream queue", cmd.NAME) def start_replication(self, hs: "HomeServer") -> None: """Helper method to start replication.""" if hs.config.redis.redis_enabled: from synapse.replication.tcp.redis import ( RedisDirectTcpReplicationClientFactory, ) # First let's ensure that we have a ReplicationStreamer started. hs.get_replication_streamer() # We need two connections to redis, one for the subscription stream and # one to send commands to (as you can't send further redis commands to a # connection after SUBSCRIBE is called). # First create the connection for sending commands. outbound_redis_connection = hs.get_outbound_redis_connection() # Now create the factory/connection for the subscription stream. 
self._factory = RedisDirectTcpReplicationClientFactory( hs, outbound_redis_connection, channel_names=self._channels_to_subscribe_to, ) hs.get_reactor().connectTCP( hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory, timeout=30, bindAddress=None, ) else: client_name = hs.get_instance_name() self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) host = hs.config.worker.worker_replication_host port = hs.config.worker.worker_replication_port hs.get_reactor().connectTCP( host, port, self._factory, timeout=30, bindAddress=None, ) def get_streams(self) -> Dict[str, Stream]: """Get a map from stream name to all streams.""" return self._streams def get_streams_to_replicate(self) -> List[Stream]: """Get a list of streams that this instances replicates.""" return self._streams_to_replicate def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand) -> None: self.send_positions_to_connection(conn) def send_positions_to_connection(self, conn: IReplicationConnection) -> None: """Send current position of all streams this process is source of to the connection. """ # We respond with current position of all streams this instance # replicates. for stream in self.get_streams_to_replicate(): # Note that we use the current token as the prev token here (rather # than stream.last_token), as we can't be sure that there have been # no rows written between last token and the current token (since we # might be racing with the replication sending bg process). 
current_token = stream.current_token(self._instance_name) self.send_command( PositionCommand( stream.NAME, self._instance_name, current_token, current_token, ) ) def on_USER_SYNC( self, conn: IReplicationConnection, cmd: UserSyncCommand ) -> Optional[Awaitable[None]]: user_sync_counter.inc() if self._is_presence_writer: return self._presence_handler.update_external_syncs_row( cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) else: return None def on_CLEAR_USER_SYNC( self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand ) -> Optional[Awaitable[None]]: if self._is_presence_writer: return self._presence_handler.update_external_syncs_clear(cmd.instance_id) else: return None def on_FEDERATION_ACK( self, conn: IReplicationConnection, cmd: FederationAckCommand ) -> None: federation_ack_counter.inc() if self._federation_sender: self._federation_sender.federation_ack(cmd.instance_name, cmd.token) def on_USER_IP( self, conn: IReplicationConnection, cmd: UserIpCommand ) -> Optional[Awaitable[None]]: user_ip_cache_counter.inc() if self._is_master or self._should_insert_client_ips: # We make a point of only returning an awaitable if there's actually # something to do; on_USER_IP is not an async function, but # _handle_user_ip is. # If on_USER_IP returns an awaitable, it gets scheduled as a # background process (see `BaseReplicationStreamProtocol.handle_command`). return self._handle_user_ip(cmd) else: # Returning None when this process definitely has nothing to do # reduces the overhead of handling the USER_IP command, which is # currently broadcast to all workers regardless of utility. return None async def _handle_user_ip(self, cmd: UserIpCommand) -> None: """ Handles a User IP, branching depending on whether we are the main process and/or the background worker. 
""" if self._is_master: assert self._server_notices_sender is not None await self._server_notices_sender.on_user_ip(cmd.user_id) if self._should_insert_client_ips: await self._store.insert_client_ip( cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id, cmd.last_seen, ) def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand) -> None: if cmd.instance_name == self._instance_name: # Ignore RDATA that are just our own echoes return stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() # We put the received command into a queue here for two reasons: # 1. so we don't try and concurrently handle multiple rows for the # same stream, and # 2. so we don't race with getting a POSITION command and fetching # missing RDATA. self._add_command_to_stream_queue(conn, cmd) async def _process_rdata( self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand ) -> None: """Process an RDATA command Called after the command has been popped off the queue of inbound commands """ try: row = STREAMS_MAP[stream_name].parse_row(cmd.row) except Exception as e: raise Exception( "Failed to parse RDATA: %r %r" % (stream_name, cmd.row) ) from e # make sure that we've processed a POSITION for this stream *on this # connection*. (A POSITION on another connection is no good, as there # is no guarantee that we have seen all the intermediate updates.) sbc = self._streams_by_connection.get(conn) if not sbc or stream_name not in sbc: # Let's drop the row for now, on the assumption we'll receive a # `POSITION` soon and we'll catch up correctly then. logger.debug( "Discarding RDATA for unconnected stream %s -> %s", stream_name, cmd.token, ) return if cmd.token is None: # I.e. this is part of a batch of updates for this stream (in # which case batch until we get an update for the stream with a non # None token). 
self._pending_batches.setdefault(stream_name, []).append(row) return # Check if this is the last of a batch of updates rows = self._pending_batches.pop(stream_name, []) rows.append(row) stream = self._streams[stream_name] # Find where we previously streamed up to. current_token = stream.current_token(cmd.instance_name) # Discard this data if this token is earlier than the current # position. Note that streams can be reset (in which case you # expect an earlier token), but that must be preceded by a # POSITION command. if cmd.token <= current_token: logger.debug( "Discarding RDATA from stream %s at position %s before previous position %s", stream_name, cmd.token, current_token, ) else: await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list ) -> None: """Called to handle a batch of replication data with a given stream token. Args: stream_name: name of the replication stream for this batch of rows instance_name: the instance that wrote the rows. token: stream token for this batch of rows rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token) await self._replication_data_handler.on_rdata( stream_name, instance_name, token, rows ) def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand) -> None: if cmd.instance_name == self._instance_name: # Ignore POSITION that are just our own echoes return logger.debug("Handling '%s %s'", cmd.NAME, cmd.to_line()) self._add_command_to_stream_queue(conn, cmd) async def _process_position( self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand ) -> None: """Process a POSITION command Called after the command has been popped off the queue of inbound commands """ stream = self._streams[stream_name] # We're about to go and catch up with the stream, so remove from set # of connected streams. 
for streams in self._streams_by_connection.values(): streams.discard(stream_name) # We clear the pending batches for the stream as the fetching of the # missing updates below will fetch all rows in the batch. self._pending_batches.pop(stream_name, []) # Find where we previously streamed up to. current_token = stream.current_token(cmd.instance_name) # If the position token matches our current token then we're up to # date and there's nothing to do. Otherwise, fetch all updates # between then and now. missing_updates = cmd.prev_token != current_token while missing_updates: # Note: There may very well not be any new updates, but we check to # make sure. This can particularly happen for the event stream where # event persisters continuously send `POSITION`. See `resource.py` # for why this can happen. logger.info( "Fetching replication rows for '%s' between %i and %i", stream_name, current_token, cmd.new_token, ) (updates, current_token, missing_updates) = await stream.get_updates_since( cmd.instance_name, current_token, cmd.new_token ) # TODO: add some tests for this # Some streams return multiple rows with the same stream IDs, # which need to be processed in batches. for token, rows in _batch_updates(updates): await self.on_rdata( stream_name, cmd.instance_name, token, [stream.parse_row(row) for row in rows], ) logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token) # We've now caught up to position sent to us, notify handler. await self._replication_data_handler.on_position( cmd.stream_name, cmd.instance_name, cmd.new_token ) self._streams_by_connection.setdefault(conn, set()).add(stream_name) def on_REMOTE_SERVER_UP( self, conn: IReplicationConnection, cmd: RemoteServerUpCommand ) -> None: """Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) self._notifier.notify_remote_server_up(cmd.data) # We relay to all other connections to ensure every instance gets the # notification. 
        #
        # When configured to use redis we'll always only have one connection and
        # so this is a no-op (all instances will have already received the same
        # REMOTE_SERVER_UP command).
        #
        # For direct TCP connections this will relay to all other connections
        # connected to us. When on master this will correctly fan out to all
        # other direct TCP clients and on workers there'll only be the one
        # connection to master.
        #
        # (The logic here should also be sound if we have a mix of Redis and
        # direct TCP connections so long as there is only one traffic route
        # between two instances, but that is not currently supported).
        self.send_command(cmd, ignore_conn=conn)

    def new_connection(self, connection: IReplicationConnection) -> None:
        """Called when we have a new connection."""
        self._connections.append(connection)

        # If we are connected to replication as a client (rather than a server)
        # we need to reset the reconnection delay on the client factory (which
        # is used to do exponential back off when the connection drops).
        #
        # Ideally we would reset the delay when we've "fully established" the
        # connection (for some definition thereof) to stop us from tightlooping
        # on reconnection if something fails after this point and we drop the
        # connection. Unfortunately, we don't really have a better definition of
        # "fully established" than the connection being established.
        if self._factory:
            self._factory.resetDelay()

        # Tell the other end if we have any users currently syncing.
        currently_syncing = (
            self._presence_handler.get_currently_syncing_users_for_replication()
        )

        now = self._clock.time_msec()
        for user_id in currently_syncing:
            connection.send_command(
                UserSyncCommand(self._instance_id, user_id, True, now)
            )

    def lost_connection(self, connection: IReplicationConnection) -> None:
        """Called when a connection is closed/lost."""
        # we no longer need _streams_by_connection for this connection.
        streams = self._streams_by_connection.pop(connection, None)
        if streams:
            logger.info(
                "Lost replication connection; streams now disconnected: %s", streams
            )

        # The connection may never have been registered (or may already have
        # been removed), so ignore a missing entry.
        try:
            self._connections.remove(connection)
        except ValueError:
            pass

    def connected(self) -> bool:
        """Do we have any replication connections open?

        Is used by e.g. `ReplicationStreamer` to no-op if nothing is connected.
        """
        return bool(self._connections)

    def send_command(
        self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
    ) -> None:
        """Send a command to all connected connections.

        Args:
            cmd: the command to send.
            ignore_conn: If set don't send command to the given connection.
                Used when relaying commands from one connection to all others.
        """
        if self._connections:
            for connection in self._connections:
                if connection == ignore_conn:
                    continue

                try:
                    connection.send_command(cmd)
                except Exception:
                    # We probably want to catch some types of exceptions here
                    # and log them as warnings (e.g. connection gone), but I
                    # can't find what those exception types they would be.
                    logger.exception(
                        "Failed to write command %s to connection %s",
                        cmd.NAME,
                        connection,
                    )
        else:
            logger.warning("Dropping command as not connected: %r", cmd.NAME)

    def send_federation_ack(self, token: int) -> None:
        """Ack data for the federation stream. This allows the master to drop
        data stored purely in memory.
        """
        self.send_command(FederationAckCommand(self._instance_name, token))

    def send_user_sync(
        self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
    ) -> None:
        """Poke the master that a user has started/stopped syncing."""
        self.send_command(
            UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
        )

    def send_user_ip(
        self,
        user_id: str,
        access_token: str,
        ip: str,
        user_agent: str,
        device_id: Optional[str],
        last_seen: int,
    ) -> None:
        """Tell the master that the user made a request."""
        cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
        self.send_command(cmd)

    def send_remote_server_up(self, server: str) -> None:
        """Broadcast that a previously-unreachable remote server is up again."""
        self.send_command(RemoteServerUpCommand(server))

    def stream_update(self, stream_name: str, token: Optional[int], data: Any) -> None:
        """Called when a new update is available to stream to Redis subscribers.

        We need to check if the client is interested in the stream or not
        """
        self.send_command(RdataCommand(stream_name, self._instance_name, token, data))