Example #1
def test_replication_progress_pre_calculate():
    subprocess.call("zfs destroy -r data/src", shell=True)
    subprocess.call("zfs destroy -r data/dst", shell=True)

    subprocess.check_call("zfs create data/src", shell=True)
    subprocess.check_call("zfs create data/src/alice", shell=True)
    subprocess.check_call("zfs create data/src/bob", shell=True)
    subprocess.check_call("zfs create data/src/charlie", shell=True)
    subprocess.check_call("zfs snapshot -r data/src@2018-10-01_01-00",
                          shell=True)

    subprocess.check_call("zfs create data/dst", shell=True)
    subprocess.check_call(
        "zfs send -R data/src@2018-10-01_01-00 | zfs recv -s -F data/dst",
        shell=True)

    subprocess.check_call("zfs create data/src/dave", shell=True)
    subprocess.check_call("zfs snapshot -r data/src@2018-10-01_02-00",
                          shell=True)

    definition = yaml.safe_load(
        textwrap.dedent("""\
        timezone: "UTC"

        replication-tasks:
          src:
            direction: push
            transport:
              type: local
            source-dataset: data/src
            target-dataset: data/dst
            recursive: true
            also-include-naming-schema:
            - "%Y-%m-%d_%H-%M"
            auto: false
            retention-policy: none
            retries: 1
    """))

    definition = Definition.from_data(definition)
    zettarepl = create_zettarepl(definition)
    zettarepl._spawn_replication_tasks(
        select_by_class(ReplicationTask, definition.tasks))
    wait_replication_tasks_to_complete(zettarepl)

    # Drop low-level data-progress messages; the assertions below only cover
    # per-snapshot lifecycle events
    calls = [
        call for call in zettarepl.observer.call_args_list
        if call[0][0].__class__ != ReplicationTaskDataProgress
    ]

    result = [
        ReplicationTaskStart("src"),
        ReplicationTaskSnapshotStart("src", "data/src", "2018-10-01_02-00", 0,
                                     5),
        ReplicationTaskSnapshotSuccess("src", "data/src", "2018-10-01_02-00",
                                       1, 5),
        ReplicationTaskSnapshotStart("src", "data/src/alice",
                                     "2018-10-01_02-00", 1, 5),
        ReplicationTaskSnapshotSuccess("src", "data/src/alice",
                                       "2018-10-01_02-00", 2, 5),
        ReplicationTaskSnapshotStart("src", "data/src/bob", "2018-10-01_02-00",
                                     2, 5),
        ReplicationTaskSnapshotSuccess("src", "data/src/bob",
                                       "2018-10-01_02-00", 3, 5),
        ReplicationTaskSnapshotStart("src", "data/src/charlie",
                                     "2018-10-01_02-00", 3, 5),
        ReplicationTaskSnapshotSuccess("src", "data/src/charlie",
                                       "2018-10-01_02-00", 4, 5),
        ReplicationTaskSnapshotStart("src", "data/src/dave",
                                     "2018-10-01_02-00", 4, 5),
        ReplicationTaskSnapshotSuccess("src", "data/src/dave",
                                       "2018-10-01_02-00", 5, 5),
        ReplicationTaskSuccess("src"),
    ]

    for i, message in enumerate(result):
        call = calls[i]

        assert call[0][0].__class__ == message.__class__

        d1 = call[0][0].__dict__
        d2 = message.__dict__

        assert d1 == d2
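
For reference, the (current, total) bookkeeping asserted above follows directly from the dataset list: the total is pre-calculated as the number of datasets, and the counter advances by one per replicated snapshot. A minimal standalone sketch of that arithmetic (not zettarepl code; the tuples are illustrative stand-ins for the observer messages):

datasets = ["data/src", "data/src/alice", "data/src/bob",
            "data/src/charlie", "data/src/dave"]
total = len(datasets)  # pre-calculated before any snapshot is sent

expected = []
for current, dataset in enumerate(datasets):
    expected.append(("start", dataset, current, total))
    expected.append(("success", dataset, current + 1, total))

assert expected[0] == ("start", "data/src", 0, 5)
assert expected[-1] == ("success", "data/src/dave", 5, 5)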
Example #2
def test_replication_progress(transport):
    subprocess.call("zfs destroy -r data/src", shell=True)
    subprocess.call("zfs destroy -r data/dst", shell=True)

    subprocess.check_call("zfs create data/src", shell=True)

    subprocess.check_call("zfs create data/src/src1", shell=True)
    subprocess.check_call("zfs snapshot data/src/src1@2018-10-01_01-00",
                          shell=True)
    subprocess.check_call(
        "dd if=/dev/urandom of=/mnt/data/src/src1/blob bs=1M count=1",
        shell=True)
    subprocess.check_call("zfs snapshot data/src/src1@2018-10-01_02-00",
                          shell=True)
    subprocess.check_call("rm /mnt/data/src/src1/blob", shell=True)
    subprocess.check_call("zfs snapshot data/src/src1@2018-10-01_03-00",
                          shell=True)

    subprocess.check_call("zfs create data/src/src2", shell=True)
    subprocess.check_call("zfs snapshot data/src/src2@2018-10-01_01-00",
                          shell=True)
    subprocess.check_call("zfs snapshot data/src/src2@2018-10-01_02-00",
                          shell=True)
    subprocess.check_call("zfs snapshot data/src/src2@2018-10-01_03-00",
                          shell=True)
    subprocess.check_call("zfs snapshot data/src/src2@2018-10-01_04-00",
                          shell=True)

    definition = yaml.safe_load(
        textwrap.dedent("""\
        timezone: "UTC"

        replication-tasks:
          src:
            direction: push
            source-dataset:
            - data/src/src1
            - data/src/src2
            target-dataset: data/dst
            recursive: true
            also-include-naming-schema:
            - "%Y-%m-%d_%H-%M"
            auto: false
            retention-policy: none
            retries: 1
    """))
    definition["replication-tasks"]["src"]["transport"] = transport
    if transport["type"] == "ssh":
        definition["replication-tasks"]["src"]["speed-limit"] = 10240 * 9

    definition = Definition.from_data(definition)
    zettarepl = create_zettarepl(definition)
    zettarepl._spawn_replication_tasks(
        select_by_class(ReplicationTask, definition.tasks))
    wait_replication_tasks_to_complete(zettarepl)

    calls = [
        call for call in zettarepl.observer.call_args_list
        if call[0][0].__class__ != ReplicationTaskDataProgress
    ]

    result = [
        ReplicationTaskStart("src"),
        ReplicationTaskSnapshotStart("src", "data/src/src1",
                                     "2018-10-01_01-00", 0, 3),
        ReplicationTaskSnapshotSuccess("src", "data/src/src1",
                                       "2018-10-01_01-00", 1, 3),
        ReplicationTaskSnapshotStart("src", "data/src/src1",
                                     "2018-10-01_02-00", 1, 3),
        ReplicationTaskSnapshotSuccess("src", "data/src/src1",
                                       "2018-10-01_02-00", 2, 3),
        ReplicationTaskSnapshotStart("src", "data/src/src1",
                                     "2018-10-01_03-00", 2, 3),
        ReplicationTaskSnapshotSuccess("src", "data/src/src1",
                                       "2018-10-01_03-00", 3, 3),
        ReplicationTaskSnapshotStart("src", "data/src/src2",
                                     "2018-10-01_01-00", 3, 7),
        ReplicationTaskSnapshotSuccess("src", "data/src/src2",
                                       "2018-10-01_01-00", 4, 7),
        ReplicationTaskSnapshotStart("src", "data/src/src2",
                                     "2018-10-01_02-00", 4, 7),
        ReplicationTaskSnapshotSuccess("src", "data/src/src2",
                                       "2018-10-01_02-00", 5, 7),
        ReplicationTaskSnapshotStart("src", "data/src/src2",
                                     "2018-10-01_03-00", 5, 7),
        ReplicationTaskSnapshotSuccess("src", "data/src/src2",
                                       "2018-10-01_03-00", 6, 7),
        ReplicationTaskSnapshotStart("src", "data/src/src2",
                                     "2018-10-01_04-00", 6, 7),
        ReplicationTaskSnapshotSuccess("src", "data/src/src2",
                                       "2018-10-01_04-00", 7, 7),
        ReplicationTaskSuccess("src"),
    ]

    if transport["type"] == "ssh":
        result.insert(
            4,
            ReplicationTaskSnapshotProgress(
                "src",
                "data/src/src1",
                "2018-10-01_02-00",
                1,
                3,
                10240 * 9 * 10,  # We poll for progress every 10 seconds,
                                 # so about 10x the speed limit is transferred
                2162784  # Empirical value
            ))

    for i, message in enumerate(result):
        call = calls[i]

        assert call[0][0].__class__ == message.__class__, calls

        d1 = call[0][0].__dict__
        d2 = message.__dict__

        if isinstance(message, ReplicationTaskSnapshotProgress):
            bytes_sent_1 = d1.pop("bytes_sent")
            bytes_total_1 = d1.pop("bytes_total")
            bytes_sent_2 = d2.pop("bytes_sent")
            bytes_total_2 = d2.pop("bytes_total")

            assert 0.8 <= bytes_sent_1 / bytes_sent_2 <= 1.2
            assert 0.8 <= bytes_total_1 / bytes_total_2 <= 1.2

        assert d1 == d2
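
The ±20% tolerance check above exists because the reported byte counts depend on timing: with a speed limit of 10240 * 9 bytes per second and (per the inline comment) a 10-second progress poll, roughly 921600 bytes should have been sent by the first progress event. A minimal sketch of the expected arithmetic and the tolerance test (standalone, assuming only what the test itself states):

speed_limit = 10240 * 9                      # bytes per second (92160)
poll_interval = 10                           # seconds between progress polls
expected_sent = speed_limit * poll_interval  # 921600 bytes

def within_tolerance(actual, expected, tolerance=0.2):
    return (1 - tolerance) <= actual / expected <= (1 + tolerance)

assert within_tolerance(900000, expected_sent)      # ~2% off: accepted
assert not within_tolerance(500000, expected_sent)  # ~46% off: rejected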
Example #3
def run_replication_tasks(local_shell: LocalShell, transport: Transport, remote_shell: Shell,
                          replication_tasks: [ReplicationTask], observer=None):
    replication_tasks_parts = calculate_replication_tasks_parts(replication_tasks)

    started_replication_tasks_ids = set()
    failed_replication_tasks_ids = set()
    # Parts (per-source-dataset units of work) remaining for each task;
    # ReplicationTaskSuccess is only emitted once the last part completes
    replication_tasks_parts_left = {
        replication_task.id: len([1
                                  for another_replication_task, source_dataset in replication_tasks_parts
                                  if another_replication_task == replication_task])
        for replication_task in replication_tasks
    }
    for replication_task, source_dataset in replication_tasks_parts:
        if replication_task.id in failed_replication_tasks_ids:
            continue

        local_context = ReplicationContext(None, local_shell)
        remote_context = ReplicationContext(transport, remote_shell)

        if replication_task.direction == ReplicationDirection.PUSH:
            src_context = local_context
            dst_context = remote_context
        elif replication_task.direction == ReplicationDirection.PULL:
            src_context = remote_context
            dst_context = local_context
        else:
            raise ValueError(f"Invalid replication direction: {replication_task.direction!r}")

        if replication_task.id not in started_replication_tasks_ids:
            notify(observer, ReplicationTaskStart(replication_task.id))
            started_replication_tasks_ids.add(replication_task.id)
        recoverable_error = None
        recoverable_sleep = 1
        for i in range(replication_task.retries):
            if recoverable_error is not None:
                logger.info("After recoverable error sleeping for %d seconds", recoverable_sleep)
                time.sleep(recoverable_sleep)
                recoverable_sleep = min(recoverable_sleep * 2, 60)
            else:
                recoverable_sleep = 1

            try:
                try:
                    run_replication_task_part(replication_task, source_dataset, src_context, dst_context, observer)
                except socket.timeout:
                    raise RecoverableReplicationError("Network connection timeout") from None
                except paramiko.ssh_exception.NoValidConnectionsError as e:
                    raise RecoverableReplicationError(str(e).replace("[Errno None] ", "")) from None
                except (IOError, OSError) as e:
                    raise RecoverableReplicationError(str(e)) from None
                replication_tasks_parts_left[replication_task.id] -= 1
                if replication_tasks_parts_left[replication_task.id] == 0:
                    notify(observer, ReplicationTaskSuccess(replication_task.id))
                break
            except RecoverableReplicationError as e:
                logger.warning("For task %r at attempt %d recoverable replication error %r", replication_task.id,
                               i + 1, e)
                recoverable_error = e
            except ReplicationError as e:
                logger.error("For task %r non-recoverable replication error %r", replication_task.id, e)
                notify(observer, ReplicationTaskError(replication_task.id, str(e)))
                failed_replication_tasks_ids.add(replication_task.id)
                break
            except Exception as e:
                logger.error("For task %r unhandled replication error %r", replication_task.id, e, exc_info=True)
                notify(observer, ReplicationTaskError(replication_task.id, str(e)))
                failed_replication_tasks_ids.add(replication_task.id)
                break
        else:
            logger.error("Failed replication task %r after %d retries", replication_task.id,
                         replication_task.retries)
            notify(observer, ReplicationTaskError(replication_task.id, str(recoverable_error)))
            failed_replication_tasks_ids.add(replication_task.id)
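
The retry loop above relies on Python's for...else: the else clause runs only when the loop exhausts all attempts without hitting break, i.e. every retry failed with a recoverable error. A condensed standalone sketch of the same pattern with exponential backoff capped at 60 seconds (run_with_retries and on_give_up are illustrative names, not zettarepl APIs):

import time

def run_with_retries(attempt, retries, on_give_up):
    last_error = None
    sleep = 1
    for _ in range(retries):
        if last_error is not None:
            # Back off after a failed attempt, doubling up to a 60s cap
            time.sleep(sleep)
            sleep = min(sleep * 2, 60)
        try:
            attempt()
            break  # success: the for-else below is skipped
        except Exception as e:  # the real code retries only recoverable errors
            last_error = e
    else:
        # Reached only when no attempt succeeded
        on_give_up(last_error)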
Example #4
def test_replication_progress_resume():
    subprocess.call("zfs destroy -r data/src", shell=True)
    subprocess.call("zfs destroy -r data/dst", shell=True)

    subprocess.check_call("zfs create data/src", shell=True)
    subprocess.check_call("zfs snapshot data/src@2018-10-01_01-00", shell=True)
    subprocess.check_call(
        "dd if=/dev/urandom of=/mnt/data/src/blob bs=1M count=1", shell=True)
    subprocess.check_call("zfs snapshot data/src@2018-10-01_02-00", shell=True)
    subprocess.check_call(
        "dd if=/dev/urandom of=/mnt/data/src/blob bs=1M count=1", shell=True)
    subprocess.check_call("zfs snapshot data/src@2018-10-01_03-00", shell=True)
    subprocess.check_call(
        "dd if=/dev/urandom of=/mnt/data/src/blob bs=1M count=1", shell=True)
    subprocess.check_call("zfs snapshot data/src@2018-10-01_04-00", shell=True)

    subprocess.check_call("zfs create data/dst", shell=True)
    subprocess.check_call(
        "zfs send data/src@2018-10-01_01-00 | zfs recv -s -F data/dst",
        shell=True)
    # Start a throttled incremental send and kill it mid-stream so that the
    # destination is left holding a partially received, resumable state
    subprocess.check_call(
        "(zfs send -i data/src@2018-10-01_01-00 data/src@2018-10-01_02-00 | "
        " throttle -b 102400 | zfs recv -s -F data/dst) & "
        "sleep 1; killall zfs",
        shell=True)

    assert "receive_resume_token\t1-" in subprocess.check_output(
        "zfs get -H receive_resume_token data/dst",
        shell=True,
        encoding="utf-8")

    definition = yaml.safe_load(
        textwrap.dedent("""\
        timezone: "UTC"

        replication-tasks:
          src:
            direction: push
            transport:
              type: local
            source-dataset: data/src
            target-dataset: data/dst
            recursive: true
            also-include-naming-schema:
            - "%Y-%m-%d_%H-%M"
            auto: false
            retention-policy: none
            retries: 1
    """))

    definition = Definition.from_data(definition)
    zettarepl = create_zettarepl(definition)
    zettarepl._spawn_replication_tasks(
        select_by_class(ReplicationTask, definition.tasks))
    wait_replication_tasks_to_complete(zettarepl)

    calls = [
        call for call in zettarepl.observer.call_args_list
        if call[0][0].__class__ != ReplicationTaskDataProgress
    ]

    result = [
        ReplicationTaskStart("src"),
        ReplicationTaskSnapshotStart("src", "data/src", "2018-10-01_02-00", 0,
                                     3),
        ReplicationTaskSnapshotSuccess("src", "data/src", "2018-10-01_02-00",
                                       1, 3),
        ReplicationTaskSnapshotStart("src", "data/src", "2018-10-01_03-00", 1,
                                     3),
        ReplicationTaskSnapshotSuccess("src", "data/src", "2018-10-01_03-00",
                                       2, 3),
        ReplicationTaskSnapshotStart("src", "data/src", "2018-10-01_04-00", 2,
                                     3),
        ReplicationTaskSnapshotSuccess("src", "data/src", "2018-10-01_04-00",
                                       3, 3),
        ReplicationTaskSuccess("src"),
    ]

    for i, message in enumerate(result):
        call = calls[i]

        assert call[0][0].__class__ == message.__class__

        d1 = call[0][0].__dict__
        d2 = message.__dict__

        assert d1 == d2
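
The receive_resume_token assertion above works because zfs get -H prints one tab-separated line (name, property, value, source) and an interrupted resumable receive leaves a token beginning with "1-". A hypothetical helper (not part of the test suite) that makes the same check explicit:

import subprocess

def receive_resume_token(dataset):
    # Returns the dataset's receive_resume_token value, or None if unset
    # ("-" is what zfs prints for properties that have no value).
    output = subprocess.check_output(
        ["zfs", "get", "-H", "receive_resume_token", dataset],
        encoding="utf-8")
    value = output.split("\t")[2]
    return None if value == "-" else value

# token = receive_resume_token("data/dst")
# assert token is not None and token.startswith("1-")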
Example #5
def run_replication_tasks(local_shell: LocalShell,
                          transport: Transport,
                          remote_shell: Shell,
                          replication_tasks: [ReplicationTask],
                          observer=None):
    contexts = defaultdict(GlobalReplicationContext)

    replication_tasks_parts = calculate_replication_tasks_parts(
        replication_tasks)

    started_replication_tasks_ids = set()
    failed_replication_tasks_ids = set()
    replication_tasks_parts_left = {
        replication_task.id: len([
            1 for another_replication_task, source_dataset in
            replication_tasks_parts
            if another_replication_task == replication_task
        ])
        for replication_task in replication_tasks
    }
    for replication_task, source_dataset in replication_tasks_parts:
        if replication_task.id in failed_replication_tasks_ids:
            continue

        local_context = ReplicationContext(contexts[replication_task], None,
                                           local_shell)
        remote_context = ReplicationContext(contexts[replication_task],
                                            transport, remote_shell)

        if replication_task.direction == ReplicationDirection.PUSH:
            src_context = local_context
            dst_context = remote_context
        elif replication_task.direction == ReplicationDirection.PULL:
            src_context = remote_context
            dst_context = local_context
        else:
            raise ValueError(
                f"Invalid replication direction: {replication_task.direction!r}"
            )

        if replication_task.id not in started_replication_tasks_ids:
            notify(observer, ReplicationTaskStart(replication_task.id))
            started_replication_tasks_ids.add(replication_task.id)
        recoverable_error = None
        recoverable_sleep = 1
        for i in range(replication_task.retries):
            if recoverable_error is not None:
                logger.info("After recoverable error sleeping for %d seconds",
                            recoverable_sleep)
                time.sleep(recoverable_sleep)
                recoverable_sleep = min(recoverable_sleep * 2, 60)
            else:
                recoverable_sleep = 1

            try:
                try:
                    run_replication_task_part(replication_task, source_dataset,
                                              src_context, dst_context,
                                              observer)
                except socket.timeout:
                    raise RecoverableReplicationError(
                        "Network connection timeout") from None
                except paramiko.ssh_exception.NoValidConnectionsError as e:
                    raise RecoverableReplicationError(
                        str(e).replace("[Errno None] ", "")) from None
                except paramiko.ssh_exception.SSHException as e:
                    if isinstance(
                            e, (paramiko.ssh_exception.AuthenticationException,
                                paramiko.ssh_exception.BadHostKeyException,
                                paramiko.ssh_exception.ProxyCommandFailure,
                                paramiko.ssh_exception.ConfigParseError)):
                        raise ReplicationError(
                            str(e).replace("[Errno None] ", "")) from None
                    else:
                        # An SSH error might leave the paramiko connection in an invalid state,
                        # so reset the remote shell just in case
                        remote_shell.close()
                        raise RecoverableReplicationError(
                            str(e).replace("[Errno None] ", "")) from None
                except ExecException as e:
                    if e.returncode == 128 + signal.SIGPIPE:
                        for warning in warnings_from_zfs_success(e.stdout):
                            contexts[replication_task].add_warning(warning)
                        raise RecoverableReplicationError(
                            broken_pipe_error(e.stdout))
                    else:
                        raise
                except (IOError, OSError) as e:
                    raise RecoverableReplicationError(str(e)) from None
                replication_tasks_parts_left[replication_task.id] -= 1
                if replication_tasks_parts_left[replication_task.id] == 0:
                    notify(
                        observer,
                        ReplicationTaskSuccess(
                            replication_task.id,
                            contexts[replication_task].warnings))
                break
            except RecoverableReplicationError as e:
                logger.warning(
                    "For task %r at attempt %d recoverable replication error %r",
                    replication_task.id, i + 1, e)
                recoverable_error = e
            except ReplicationError as e:
                logger.error(
                    "For task %r non-recoverable replication error %r",
                    replication_task.id, e)
                notify(observer,
                       ReplicationTaskError(replication_task.id, str(e)))
                failed_replication_tasks_ids.add(replication_task.id)
                break
            except Exception as e:
                logger.error("For task %r unhandled replication error %r",
                             replication_task.id,
                             e,
                             exc_info=True)
                notify(observer,
                       ReplicationTaskError(replication_task.id, str(e)))
                failed_replication_tasks_ids.add(replication_task.id)
                break
        else:
            logger.error("Failed replication task %r after %d retries",
                         replication_task.id, replication_task.retries)
            notify(
                observer,
                ReplicationTaskError(replication_task.id,
                                     str(recoverable_error)))
            failed_replication_tasks_ids.add(replication_task.id)
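
The paramiko branch in the function above separates errors that cannot succeed on retry (bad credentials, bad host key, broken ProxyCommand, unparseable config) from transient transport errors, which reset the shell and are retried. A condensed sketch of that classification (is_fatal_ssh_error is an illustrative name; the exception classes are the ones the handler above already uses):

import paramiko.ssh_exception as ssh_exc

# Errors that will not go away on retry
FATAL_SSH_ERRORS = (
    ssh_exc.AuthenticationException,
    ssh_exc.BadHostKeyException,
    ssh_exc.ProxyCommandFailure,
    ssh_exc.ConfigParseError,
)

def is_fatal_ssh_error(e: ssh_exc.SSHException) -> bool:
    return isinstance(e, FATAL_SSH_ERRORS)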