def test_successful_transaction():
    """Walk a begin/mutation/commit batch through the transaction validator,
    checking the state after every message."""
    messages = list(make_batch_messages(batch_identifier, [
        {'begin_operation': begin},
        {'mutation_operation': mutation},
        {'commit_operation': commit},
    ]))

    # The begin and mutation messages leave the validator in ``InTransaction``;
    # the commit message transitions it to ``Committed``.
    expected_state_types = (InTransaction, InTransaction, Committed)

    state = None
    for offset, (message, state_type) in enumerate(zip(messages, expected_state_types)):
        state = reserialize(validate_transaction_state(state, offset, message))
        assert get_oneof_value(state, 'state') == state_type(
            publisher=message.header.publisher,
            batch_identifier=batch_identifier,
        )
def test_publisher_failure():
    """An exception raised inside a batch publishes a rollback after the
    begin, and leaves the publisher usable afterwards."""
    messages = []
    publisher = Publisher(messages.extend)

    with pytest.raises(NotImplementedError):
        with publisher.batch(batch_identifier, begin):
            raise NotImplementedError

    # BUG FIX: ``map`` returns a lazy iterator on Python 3; the result is
    # subscripted and iterated twice below, so it must be materialized.
    published_messages = list(map(reserialize, messages))

    assert get_oneof_value(
        get_oneof_value(published_messages[0], 'operation'),
        'operation'
    ) == begin

    assert get_oneof_value(
        get_oneof_value(published_messages[1], 'operation'),
        'operation'
    ) == rollback

    # Ensure it actually generates valid data.
    state = None
    for offset, message in enumerate(published_messages):
        state = reserialize(validate_state(state, offset, message))

    for i, message in enumerate(published_messages):
        assert message.header.publisher == publisher.id
        assert message.header.sequence == i

    # Write another message to ensure that the publisher can continue to be used.
    assert len(messages) == 2
    publisher.publish()
    assert len(messages) == 3
    assert messages[2].header.sequence == 2
def __call__(self, state, offset, message):
    """Dispatch ``message`` to the receiver registered for the current state
    and incoming operation type, returning the resulting state message.

    Raises ``InvalidEventError`` when no receiver is registered for either
    the current state or the operation.
    """
    # Unwrap the oneof wrappers around the state and the (doubly nested)
    # operation before dispatching on their concrete types.
    current = get_oneof_value(state, "state") if state is not None else None
    operation = get_oneof_value(get_oneof_value(message, "operation"), "operation")

    state_key = None if current is None else type(current)
    receivers = self.receivers.get(state_key)
    if receivers is None:
        raise InvalidEventError("Cannot receive events in state: {0!r}".format(current))

    receiver = receivers.get(type(operation))
    if receiver is None:
        raise InvalidEventError(
            "Cannot receive {0!r} while in state: {1!r}".format(operation, current),
            expected=set(receivers.keys())
        )

    return self.message(**receiver(current, offset, message))
def test_publisher():
    """A successful batch publishes begin, mutation, and commit messages in
    order, with sequential headers, producing a valid stream."""
    messages = []
    publisher = Publisher(messages.extend)

    with publisher.batch(batch_identifier, begin) as publish:
        publish(mutation)

    # BUG FIX: ``map`` returns a lazy iterator on Python 3; the result is
    # subscripted three times and iterated twice below, so it must be
    # materialized into a list.
    published_messages = list(map(reserialize, messages))

    assert get_oneof_value(
        get_oneof_value(published_messages[0], 'operation'),
        'operation'
    ) == begin

    assert get_oneof_value(
        get_oneof_value(published_messages[1], 'operation'),
        'operation'
    ) == mutation

    assert get_oneof_value(
        get_oneof_value(published_messages[2], 'operation'),
        'operation'
    ) == commit

    for i, message in enumerate(published_messages):
        assert message.header.publisher == publisher.id
        assert message.header.sequence == i

    # Ensure it actually generates valid data.
    state = None
    for offset, message in enumerate(published_messages):
        state = reserialize(validate_state(state, offset, message))
def run(self, loader, stream):
    """Bootstrap this replication target if it has no saved state, then
    consume the replication stream, applying mutations and committing or
    rolling back batches as directed by the stream's operations."""
    state = self.get_state()
    if state is None:
        # No saved state: perform a full bootstrap load from the loader's
        # snapshot before consuming the stream.
        logger.info('Bootstrapping new replication target with %s...', loader)
        with loader.fetch() as (bootstrap_state, loaders):
            for table, records in loaders:
                logger.info('Loading records from %s...', table.name)
                self.load(table, records)
            # Persist the bootstrap snapshot so mutations already visible in
            # it can be skipped during streaming (see MutationOperation below).
            state = State(bootstrap_state=bootstrap_state)
            self.commit(state)
            logger.debug(
                'Successfully bootstrapped from %s using snapshot: %s',
                uuid.UUID(bytes=bootstrap_state.node).hex,
                FormattedSnapshot(bootstrap_state.snapshot),
            )
    logger.info('Starting to consume from %s...', stream)
    for state, offset, message in stream.consume(state):
        # Each message carries exactly one batch operation type.
        operation = get_oneof_value(message.batch_operation, 'operation')
        if isinstance(operation, BeginOperation):
            logger.debug(
                'Beginning %s (%s to %s)...',
                FormattedBatchIdentifier(message.batch_operation.batch_identifier),
                FormattedSnapshot(operation.start.snapshot),
                FormattedSnapshot(operation.end.snapshot),
            )
        elif isinstance(operation, MutationOperation):
            # Skip any messages that were part of the bootstrap snapshot.
            if state.HasField('bootstrap_state') and txid_visible_in_snapshot(operation.transaction, state.bootstrap_state.snapshot):
                logger.debug('Skipping operation that was visible in bootstrap snapshot.')
            else:
                self.apply(state, operation)
        elif isinstance(operation, CommitOperation):
            logger.info(
                'Committing %s.',
                FormattedBatchIdentifier(message.batch_operation.batch_identifier),
            )
            self.commit(state)
        elif isinstance(operation, RollbackOperation):
            logger.info(
                'Rolling back %s.',
                FormattedBatchIdentifier(message.batch_operation.batch_identifier),
            )
            self.rollback(state)
        else:
            # Defensive: the oneof should only ever contain the types above.
            raise AssertionError('Received unexpected operation!')
def validate_bootstrap_state(state, offset, message):
    """Return ``None`` once the replication stream has advanced past the
    bootstrap snapshot (signalling a switch to streaming mode), otherwise
    return ``state`` unchanged."""
    if state is None:
        return None

    # can be removed (considered "streaming") when xmin (first still running
    # transaction) of start tick on same node as bootstrap is greater than or
    # equal to bootstrap xmax (first unassigned txid). (this transaction
    # doesn't have to succeed, it just means that the ticker has advanced past
    # the end of the bootstrap snapshot) this could be a little bit more reliable
    # by using commit operation but then those would have to also include the
    # start/stop payloads? (not sure if that's worth it. this does also mean that
    # a batch must be started, even if failed, before mode switch can occur
    # which is a little weird)
    # TODO: this needs to also check correct node, obv
    operation = get_oneof_value(message.batch_operation, 'operation')
    caught_up = (
        isinstance(operation, BeginOperation)
        and operation.start.snapshot.min > state.snapshot.max
    )
    if caught_up:
        logger.info('Caught up with replication stream!')
        return None

    return state
def is_start_batch_operation(message):
    """
    Is the ``message`` the start operation of a batch of mutations.
    """
    # Exact type membership is intentional here (not ``isinstance``):
    # subclasses of the registered start event types do not count.
    operation_type = type(get_oneof_value(message.batch_operation, 'operation'))
    return operation_type in TRANSACTION_START_EVENT_TYPES