def test_replication_progress_pre_calculate():
    subprocess.call("zfs destroy -r data/src", shell=True)
    subprocess.call("zfs destroy -r data/dst", shell=True)

    subprocess.check_call("zfs create data/src", shell=True)
    subprocess.check_call("zfs create data/src/alice", shell=True)
    subprocess.check_call("zfs create data/src/bob", shell=True)
    subprocess.check_call("zfs create data/src/charlie", shell=True)
    subprocess.check_call("zfs snapshot -r data/src@2018-10-01_01-00", shell=True)

    subprocess.check_call("zfs create data/dst", shell=True)
    subprocess.check_call("zfs send -R data/src@2018-10-01_01-00 | zfs recv -s -F data/dst", shell=True)

    subprocess.check_call("zfs create data/src/dave", shell=True)
    subprocess.check_call("zfs snapshot -r data/src@2018-10-01_02-00", shell=True)

    definition = yaml.safe_load(textwrap.dedent("""\
        timezone: "UTC"

        replication-tasks:
          src:
            direction: push
            transport:
              type: local
            source-dataset: data/src
            target-dataset: data/dst
            recursive: true
            also-include-naming-schema:
              - "%Y-%m-%d_%H-%M"
            auto: false
            retention-policy: none
            retries: 1
    """))
    definition = Definition.from_data(definition)

    zettarepl = create_zettarepl(definition)
    zettarepl._spawn_replication_tasks(select_by_class(ReplicationTask, definition.tasks))
    wait_replication_tasks_to_complete(zettarepl)

    calls = [call for call in zettarepl.observer.call_args_list
             if call[0][0].__class__ != ReplicationTaskDataProgress]

    result = [
        ReplicationTaskStart("src"),
        ReplicationTaskSnapshotStart("src", "data/src", "2018-10-01_02-00", 0, 5),
        ReplicationTaskSnapshotSuccess("src", "data/src", "2018-10-01_02-00", 1, 5),
        ReplicationTaskSnapshotStart("src", "data/src/alice", "2018-10-01_02-00", 1, 5),
        ReplicationTaskSnapshotSuccess("src", "data/src/alice", "2018-10-01_02-00", 2, 5),
        ReplicationTaskSnapshotStart("src", "data/src/bob", "2018-10-01_02-00", 2, 5),
        ReplicationTaskSnapshotSuccess("src", "data/src/bob", "2018-10-01_02-00", 3, 5),
        ReplicationTaskSnapshotStart("src", "data/src/charlie", "2018-10-01_02-00", 3, 5),
        ReplicationTaskSnapshotSuccess("src", "data/src/charlie", "2018-10-01_02-00", 4, 5),
        ReplicationTaskSnapshotStart("src", "data/src/dave", "2018-10-01_02-00", 4, 5),
        ReplicationTaskSnapshotSuccess("src", "data/src/dave", "2018-10-01_02-00", 5, 5),
        ReplicationTaskSuccess("src"),
    ]
    for i, message in enumerate(result):
        call = calls[i]
        assert call[0][0].__class__ == message.__class__
        d1 = call[0][0].__dict__
        d2 = message.__dict__
        assert d1 == d2
def test_replication_progress(transport):
    subprocess.call("zfs destroy -r data/src", shell=True)
    subprocess.call("zfs destroy -r data/dst", shell=True)

    subprocess.check_call("zfs create data/src", shell=True)

    subprocess.check_call("zfs create data/src/src1", shell=True)
    subprocess.check_call("zfs snapshot data/src/src1@2018-10-01_01-00", shell=True)
    subprocess.check_call("dd if=/dev/urandom of=/mnt/data/src/src1/blob bs=1M count=1", shell=True)
    subprocess.check_call("zfs snapshot data/src/src1@2018-10-01_02-00", shell=True)
    subprocess.check_call("rm /mnt/data/src/src1/blob", shell=True)
    subprocess.check_call("zfs snapshot data/src/src1@2018-10-01_03-00", shell=True)

    subprocess.check_call("zfs create data/src/src2", shell=True)
    subprocess.check_call("zfs snapshot data/src/src2@2018-10-01_01-00", shell=True)
    subprocess.check_call("zfs snapshot data/src/src2@2018-10-01_02-00", shell=True)
    subprocess.check_call("zfs snapshot data/src/src2@2018-10-01_03-00", shell=True)
    subprocess.check_call("zfs snapshot data/src/src2@2018-10-01_04-00", shell=True)

    definition = yaml.safe_load(textwrap.dedent("""\
        timezone: "UTC"

        replication-tasks:
          src:
            direction: push
            source-dataset:
              - data/src/src1
              - data/src/src2
            target-dataset: data/dst
            recursive: true
            also-include-naming-schema:
              - "%Y-%m-%d_%H-%M"
            auto: false
            retention-policy: none
            retries: 1
    """))
    definition["replication-tasks"]["src"]["transport"] = transport
    if transport["type"] == "ssh":
        definition["replication-tasks"]["src"]["speed-limit"] = 10240 * 9
    definition = Definition.from_data(definition)

    zettarepl = create_zettarepl(definition)
    zettarepl._spawn_replication_tasks(select_by_class(ReplicationTask, definition.tasks))
    wait_replication_tasks_to_complete(zettarepl)

    calls = [call for call in zettarepl.observer.call_args_list
             if call[0][0].__class__ != ReplicationTaskDataProgress]

    result = [
        ReplicationTaskStart("src"),
        ReplicationTaskSnapshotStart("src", "data/src/src1", "2018-10-01_01-00", 0, 3),
        ReplicationTaskSnapshotSuccess("src", "data/src/src1", "2018-10-01_01-00", 1, 3),
        ReplicationTaskSnapshotStart("src", "data/src/src1", "2018-10-01_02-00", 1, 3),
        ReplicationTaskSnapshotSuccess("src", "data/src/src1", "2018-10-01_02-00", 2, 3),
        ReplicationTaskSnapshotStart("src", "data/src/src1", "2018-10-01_03-00", 2, 3),
        ReplicationTaskSnapshotSuccess("src", "data/src/src1", "2018-10-01_03-00", 3, 3),
        ReplicationTaskSnapshotStart("src", "data/src/src2", "2018-10-01_01-00", 3, 7),
        ReplicationTaskSnapshotSuccess("src", "data/src/src2", "2018-10-01_01-00", 4, 7),
        ReplicationTaskSnapshotStart("src", "data/src/src2", "2018-10-01_02-00", 4, 7),
        ReplicationTaskSnapshotSuccess("src", "data/src/src2", "2018-10-01_02-00", 5, 7),
        ReplicationTaskSnapshotStart("src", "data/src/src2", "2018-10-01_03-00", 5, 7),
        ReplicationTaskSnapshotSuccess("src", "data/src/src2", "2018-10-01_03-00", 6, 7),
        ReplicationTaskSnapshotStart("src", "data/src/src2", "2018-10-01_04-00", 6, 7),
        ReplicationTaskSnapshotSuccess("src", "data/src/src2", "2018-10-01_04-00", 7, 7),
        ReplicationTaskSuccess("src"),
    ]
    if transport["type"] == "ssh":
        result.insert(
            4,
            ReplicationTaskSnapshotProgress(
                "src", "data/src/src1", "2018-10-01_02-00", 1, 3,
                10240 * 9 * 10,  # We poll for progress every 10 seconds so we would have transferred 10x speed limit
                2162784,  # Empirical value
            )
        )
    for i, message in enumerate(result):
        call = calls[i]
        assert call[0][0].__class__ == message.__class__, calls
        d1 = call[0][0].__dict__
        d2 = message.__dict__
        if isinstance(message, ReplicationTaskSnapshotProgress):
            bytes_sent_1 = d1.pop("bytes_sent")
            bytes_total_1 = d1.pop("bytes_total")
            bytes_sent_2 = d2.pop("bytes_sent")
            bytes_total_2 = d2.pop("bytes_total")
            # Measured byte counters are timing-dependent, so allow a 20% tolerance
            # instead of comparing them exactly
            assert 0.8 <= bytes_sent_1 / bytes_sent_2 <= 1.2
            assert 0.8 <= bytes_total_1 / bytes_total_2 <= 1.2
        assert d1 == d2
def run_replication_tasks(local_shell: LocalShell, transport: Transport, remote_shell: Shell,
                          replication_tasks: [ReplicationTask], observer=None):
    replication_tasks_parts = calculate_replication_tasks_parts(replication_tasks)

    started_replication_tasks_ids = set()
    failed_replication_tasks_ids = set()
    replication_tasks_parts_left = {
        replication_task.id: len([1 for another_replication_task, source_dataset in replication_tasks_parts
                                  if another_replication_task == replication_task])
        for replication_task in replication_tasks
    }

    for replication_task, source_dataset in replication_tasks_parts:
        if replication_task.id in failed_replication_tasks_ids:
            continue

        local_context = ReplicationContext(None, local_shell)
        remote_context = ReplicationContext(transport, remote_shell)

        if replication_task.direction == ReplicationDirection.PUSH:
            src_context = local_context
            dst_context = remote_context
        elif replication_task.direction == ReplicationDirection.PULL:
            src_context = remote_context
            dst_context = local_context
        else:
            raise ValueError(f"Invalid replication direction: {replication_task.direction!r}")

        if replication_task.id not in started_replication_tasks_ids:
            notify(observer, ReplicationTaskStart(replication_task.id))
            started_replication_tasks_ids.add(replication_task.id)

        recoverable_error = None
        recoverable_sleep = 1
        for i in range(replication_task.retries):
            if recoverable_error is not None:
                logger.info("After recoverable error sleeping for %d seconds", recoverable_sleep)
                time.sleep(recoverable_sleep)
                recoverable_sleep = min(recoverable_sleep * 2, 60)
            else:
                recoverable_sleep = 1

            try:
                try:
                    run_replication_task_part(replication_task, source_dataset, src_context, dst_context, observer)
                except socket.timeout:
                    raise RecoverableReplicationError("Network connection timeout") from None
                except paramiko.ssh_exception.NoValidConnectionsError as e:
                    raise RecoverableReplicationError(str(e).replace("[Errno None] ", "")) from None
                except (IOError, OSError) as e:
                    raise RecoverableReplicationError(str(e)) from None

                replication_tasks_parts_left[replication_task.id] -= 1
                if replication_tasks_parts_left[replication_task.id] == 0:
                    notify(observer, ReplicationTaskSuccess(replication_task.id))
                break
            except RecoverableReplicationError as e:
                logger.warning("For task %r at attempt %d recoverable replication error %r",
                               replication_task.id, i + 1, e)
                recoverable_error = e
            except ReplicationError as e:
                logger.error("For task %r non-recoverable replication error %r", replication_task.id, e)
                notify(observer, ReplicationTaskError(replication_task.id, str(e)))
                failed_replication_tasks_ids.add(replication_task.id)
                break
            except Exception as e:
                logger.error("For task %r unhandled replication error %r", replication_task.id, e, exc_info=True)
                notify(observer, ReplicationTaskError(replication_task.id, str(e)))
                failed_replication_tasks_ids.add(replication_task.id)
                break
        else:
            logger.error("Failed replication task %r after %d retries", replication_task.id,
                         replication_task.retries)
            notify(observer, ReplicationTaskError(replication_task.id, str(recoverable_error)))
            failed_replication_tasks_ids.add(replication_task.id)
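The retry loop above backs off exponentially between recoverable failures: the first retry waits 1 second, each subsequent one doubles the delay, and the delay is capped at 60 seconds, while non-recoverable errors abort the task immediately. The following is a minimal sketch of that pattern in isolation, under stated assumptions: retry_with_backoff, attempt_part and MAX_SLEEP are hypothetical names introduced only for illustration, and RecoverableReplicationError stands in for the zettarepl exception of the same name.

import time

MAX_SLEEP = 60  # hypothetical cap, mirrors min(recoverable_sleep * 2, 60) above


def retry_with_backoff(attempt_part, retries):
    """Call attempt_part() up to `retries` times, sleeping 1, 2, 4, ... seconds
    (capped at MAX_SLEEP) after each recoverable failure; any other exception
    propagates immediately, just like ReplicationError in run_replication_tasks."""
    sleep = 1
    last_error = None
    for attempt in range(retries):
        if last_error is not None:
            time.sleep(sleep)
            sleep = min(sleep * 2, MAX_SLEEP)
        try:
            return attempt_part()
        except RecoverableReplicationError as e:
            last_error = e
    raise last_error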
def test_replication_progress_resume():
    subprocess.call("zfs destroy -r data/src", shell=True)
    subprocess.call("zfs destroy -r data/dst", shell=True)

    subprocess.check_call("zfs create data/src", shell=True)
    subprocess.check_call("zfs snapshot data/src@2018-10-01_01-00", shell=True)
    subprocess.check_call("dd if=/dev/urandom of=/mnt/data/src/blob bs=1M count=1", shell=True)
    subprocess.check_call("zfs snapshot data/src@2018-10-01_02-00", shell=True)
    subprocess.check_call("dd if=/dev/urandom of=/mnt/data/src/blob bs=1M count=1", shell=True)
    subprocess.check_call("zfs snapshot data/src@2018-10-01_03-00", shell=True)
    subprocess.check_call("dd if=/dev/urandom of=/mnt/data/src/blob bs=1M count=1", shell=True)
    subprocess.check_call("zfs snapshot data/src@2018-10-01_04-00", shell=True)

    subprocess.check_call("zfs create data/dst", shell=True)
    subprocess.check_call("zfs send data/src@2018-10-01_01-00 | zfs recv -s -F data/dst", shell=True)

    # Interrupt a throttled incremental send so the destination is left with a
    # partially received state and a receive_resume_token
    subprocess.check_call(
        "(zfs send -i data/src@2018-10-01_01-00 data/src@2018-10-01_02-00 | "
        " throttle -b 102400 | zfs recv -s -F data/dst) & "
        "sleep 1; killall zfs",
        shell=True,
    )
    assert "receive_resume_token\t1-" in subprocess.check_output(
        "zfs get -H receive_resume_token data/dst", shell=True, encoding="utf-8")

    definition = yaml.safe_load(textwrap.dedent("""\
        timezone: "UTC"

        replication-tasks:
          src:
            direction: push
            transport:
              type: local
            source-dataset: data/src
            target-dataset: data/dst
            recursive: true
            also-include-naming-schema:
              - "%Y-%m-%d_%H-%M"
            auto: false
            retention-policy: none
            retries: 1
    """))
    definition = Definition.from_data(definition)

    zettarepl = create_zettarepl(definition)
    zettarepl._spawn_replication_tasks(select_by_class(ReplicationTask, definition.tasks))
    wait_replication_tasks_to_complete(zettarepl)

    calls = [call for call in zettarepl.observer.call_args_list
             if call[0][0].__class__ != ReplicationTaskDataProgress]

    result = [
        ReplicationTaskStart("src"),
        ReplicationTaskSnapshotStart("src", "data/src", "2018-10-01_02-00", 0, 3),
        ReplicationTaskSnapshotSuccess("src", "data/src", "2018-10-01_02-00", 1, 3),
        ReplicationTaskSnapshotStart("src", "data/src", "2018-10-01_03-00", 1, 3),
        ReplicationTaskSnapshotSuccess("src", "data/src", "2018-10-01_03-00", 2, 3),
        ReplicationTaskSnapshotStart("src", "data/src", "2018-10-01_04-00", 2, 3),
        ReplicationTaskSnapshotSuccess("src", "data/src", "2018-10-01_04-00", 3, 3),
        ReplicationTaskSuccess("src"),
    ]
    for i, message in enumerate(result):
        call = calls[i]
        assert call[0][0].__class__ == message.__class__
        d1 = call[0][0].__dict__
        d2 = message.__dict__
        assert d1 == d2
def run_replication_tasks(local_shell: LocalShell, transport: Transport, remote_shell: Shell,
                          replication_tasks: [ReplicationTask], observer=None):
    contexts = defaultdict(GlobalReplicationContext)

    replication_tasks_parts = calculate_replication_tasks_parts(replication_tasks)

    started_replication_tasks_ids = set()
    failed_replication_tasks_ids = set()
    replication_tasks_parts_left = {
        replication_task.id: len([1 for another_replication_task, source_dataset in replication_tasks_parts
                                  if another_replication_task == replication_task])
        for replication_task in replication_tasks
    }

    for replication_task, source_dataset in replication_tasks_parts:
        if replication_task.id in failed_replication_tasks_ids:
            continue

        local_context = ReplicationContext(contexts[replication_task], None, local_shell)
        remote_context = ReplicationContext(contexts[replication_task], transport, remote_shell)

        if replication_task.direction == ReplicationDirection.PUSH:
            src_context = local_context
            dst_context = remote_context
        elif replication_task.direction == ReplicationDirection.PULL:
            src_context = remote_context
            dst_context = local_context
        else:
            raise ValueError(f"Invalid replication direction: {replication_task.direction!r}")

        if replication_task.id not in started_replication_tasks_ids:
            notify(observer, ReplicationTaskStart(replication_task.id))
            started_replication_tasks_ids.add(replication_task.id)

        recoverable_error = None
        recoverable_sleep = 1
        for i in range(replication_task.retries):
            if recoverable_error is not None:
                logger.info("After recoverable error sleeping for %d seconds", recoverable_sleep)
                time.sleep(recoverable_sleep)
                recoverable_sleep = min(recoverable_sleep * 2, 60)
            else:
                recoverable_sleep = 1

            try:
                try:
                    run_replication_task_part(replication_task, source_dataset, src_context, dst_context, observer)
                except socket.timeout:
                    raise RecoverableReplicationError("Network connection timeout") from None
                except paramiko.ssh_exception.NoValidConnectionsError as e:
                    raise RecoverableReplicationError(str(e).replace("[Errno None] ", "")) from None
                except paramiko.ssh_exception.SSHException as e:
                    if isinstance(e, (paramiko.ssh_exception.AuthenticationException,
                                      paramiko.ssh_exception.BadHostKeyException,
                                      paramiko.ssh_exception.ProxyCommandFailure,
                                      paramiko.ssh_exception.ConfigParseError)):
                        raise ReplicationError(str(e).replace("[Errno None] ", "")) from None
                    else:
                        # It might be an SSH error that leaves paramiko connection in an invalid state.
                        # Let's reset remote shell just in case.
                        remote_shell.close()
                        raise RecoverableReplicationError(str(e).replace("[Errno None] ", "")) from None
                except ExecException as e:
                    if e.returncode == 128 + signal.SIGPIPE:
                        for warning in warnings_from_zfs_success(e.stdout):
                            contexts[replication_task].add_warning(warning)
                        raise RecoverableReplicationError(broken_pipe_error(e.stdout))
                    else:
                        raise
                except (IOError, OSError) as e:
                    raise RecoverableReplicationError(str(e)) from None

                replication_tasks_parts_left[replication_task.id] -= 1
                if replication_tasks_parts_left[replication_task.id] == 0:
                    notify(observer, ReplicationTaskSuccess(replication_task.id, contexts[replication_task].warnings))
                break
            except RecoverableReplicationError as e:
                logger.warning("For task %r at attempt %d recoverable replication error %r",
                               replication_task.id, i + 1, e)
                recoverable_error = e
            except ReplicationError as e:
                logger.error("For task %r non-recoverable replication error %r", replication_task.id, e)
                notify(observer, ReplicationTaskError(replication_task.id, str(e)))
                failed_replication_tasks_ids.add(replication_task.id)
                break
            except Exception as e:
                logger.error("For task %r unhandled replication error %r", replication_task.id, e, exc_info=True)
                notify(observer, ReplicationTaskError(replication_task.id, str(e)))
                failed_replication_tasks_ids.add(replication_task.id)
                break
        else:
            logger.error("Failed replication task %r after %d retries", replication_task.id,
                         replication_task.retries)
            notify(observer, ReplicationTaskError(replication_task.id, str(recoverable_error)))
            failed_replication_tasks_ids.add(replication_task.id)
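This version threads a per-task GlobalReplicationContext through both ReplicationContext objects so that warnings salvaged from a partially successful zfs run survive retries and are reported alongside the final ReplicationTaskSuccess message. The `128 + signal.SIGPIPE` check relies on the shell convention that a process killed by a signal exits with status 128 plus the signal number, so a broken send/receive pipe surfaces as exit status 141. The implementation of GlobalReplicationContext is not shown above; the following is a minimal sketch of what such a warnings accumulator could look like, limited to the two members the code actually uses (add_warning and warnings) and not presented as the real class.

class GlobalReplicationContext:
    """Sketch of per-task state that outlives individual retry attempts.

    The real class may carry additional fields; only the warnings list is
    assumed here because run_replication_tasks reads and appends to it.
    """

    def __init__(self):
        self.warnings = []

    def add_warning(self, warning):
        # Preserve the order in which warnings were produced
        self.warnings.append(warning)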