Example #1
    def _join(self, process, startup_error):
        process.join()

        if startup_error.value:
            return

        # Restart only if the terminated process is still the current one (it was not deliberately replaced)
        restart = False
        with self.lock:
            if process == self.process:
                restart = True

        if restart:
            self.logger.error("Abnormal zettarepl process termination with code %r, restarting", process.exitcode)
            for k, v in self.middleware.call_sync("zettarepl.get_state").get("tasks", {}).items():
                if k.startswith("replication_") and v.get("state") in ("WAITING", "RUNNING"):
                    error = f"Abnormal zettarepl process termination with code {process.exitcode}."
                    self.middleware.call_sync("zettarepl.set_state", k, {
                        "state": "ERROR",
                        "datetime": datetime.utcnow(),
                        "error": error,
                    })
                    task_id = k[len("replication_"):]
                    for channel in self.replication_jobs_channels[task_id]:
                        channel.put(ReplicationTaskError(task_id, error))
            self.middleware.call_sync("zettarepl.start")
Example #2
    def _process_command_queue(self):
        logger = logging.getLogger("middlewared.plugins.zettarepl")

        while self.zettarepl is not None:
            command, args = self.command_queue.get()
            if command == "config":
                if "max_parallel_replication_tasks" in args:
                    self.zettarepl.max_parallel_replication_tasks = args[
                        "max_parallel_replication_tasks"]
                if "timezone" in args:
                    self.zettarepl.scheduler.tz_clock.timezone = pytz.timezone(
                        args["timezone"])
            if command == "tasks":
                definition = Definition.from_data(args, raise_on_error=False)
                self.observer_queue.put(DefinitionErrors(definition.errors))
                self.zettarepl.set_tasks(definition.tasks)
            if command == "run_task":
                class_name, task_id = args
                for task in self.zettarepl.tasks:
                    if task.__class__.__name__ == class_name and task.id == task_id:
                        logger.debug("Running task %r", task)
                        self.zettarepl.scheduler.interrupt([task])
                        break
                else:
                    logger.warning("Task %s(%r) not found", class_name,
                                   task_id)
                    self.observer_queue.put(
                        ReplicationTaskError(task_id, "Task not found"))
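
The loop above drains (command, args) tuples from a queue until the zettarepl instance is cleared. A hedged sketch of the producer side, assuming a plain queue.Queue and only the command names handled above (the surrounding setup is hypothetical):

import queue

command_queue = queue.Queue()

# Reconfigure the running instance; both keys are optional
command_queue.put(("config", {"max_parallel_replication_tasks": 2, "timezone": "UTC"}))

# Ask the scheduler to run a single task immediately; args is (class name, task id)
command_queue.put(("run_task", ("ReplicationTask", "task_1")))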
Example #3
def run_replication_tasks(local_shell: LocalShell, transport: Transport, remote_shell: Shell,
                          replication_tasks: [ReplicationTask], observer=None):
    replication_tasks_parts = calculate_replication_tasks_parts(replication_tasks)

    started_replication_tasks_ids = set()
    failed_replication_tasks_ids = set()
    replication_tasks_parts_left = {
        replication_task.id: len([1
                                  for another_replication_task, source_dataset in replication_tasks_parts
                                  if another_replication_task == replication_task])
        for replication_task in replication_tasks
    }
    for replication_task, source_dataset in replication_tasks_parts:
        if replication_task.id in failed_replication_tasks_ids:
            continue

        local_context = ReplicationContext(None, local_shell)
        remote_context = ReplicationContext(transport, remote_shell)

        if replication_task.direction == ReplicationDirection.PUSH:
            src_context = local_context
            dst_context = remote_context
        elif replication_task.direction == ReplicationDirection.PULL:
            src_context = remote_context
            dst_context = local_context
        else:
            raise ValueError(f"Invalid replication direction: {replication_task.direction!r}")

        if replication_task.id not in started_replication_tasks_ids:
            notify(observer, ReplicationTaskStart(replication_task.id))
            started_replication_tasks_ids.add(replication_task.id)
        recoverable_error = None
        recoverable_sleep = 1
        for i in range(replication_task.retries):
            if recoverable_error is not None:
                logger.info("After recoverable error sleeping for %d seconds", recoverable_sleep)
                time.sleep(recoverable_sleep)
                recoverable_sleep = min(recoverable_sleep * 2, 60)
            else:
                recoverable_sleep = 1

            try:
                try:
                    run_replication_task_part(replication_task, source_dataset, src_context, dst_context, observer)
                except socket.timeout:
                    raise RecoverableReplicationError("Network connection timeout") from None
                except paramiko.ssh_exception.NoValidConnectionsError as e:
                    raise RecoverableReplicationError(str(e).replace("[Errno None] ", "")) from None
                except (IOError, OSError) as e:
                    raise RecoverableReplicationError(str(e)) from None
                replication_tasks_parts_left[replication_task.id] -= 1
                if replication_tasks_parts_left[replication_task.id] == 0:
                    notify(observer, ReplicationTaskSuccess(replication_task.id))
                break
            except RecoverableReplicationError as e:
                logger.warning("For task %r at attempt %d recoverable replication error %r", replication_task.id,
                               i + 1, e)
                recoverable_error = e
            except ReplicationError as e:
                logger.error("For task %r non-recoverable replication error %r", replication_task.id, e)
                notify(observer, ReplicationTaskError(replication_task.id, str(e)))
                failed_replication_tasks_ids.add(replication_task.id)
                break
            except Exception as e:
                logger.error("For task %r unhandled replication error %r", replication_task.id, e, exc_info=True)
                notify(observer, ReplicationTaskError(replication_task.id, str(e)))
                failed_replication_tasks_ids.add(replication_task.id)
                break
        else:
            # for/else: reached only when every attempt failed with a recoverable error
            logger.error("Failed replication task %r after %d retries", replication_task.id,
                         replication_task.retries)
            notify(observer, ReplicationTaskError(replication_task.id, str(recoverable_error)))
            failed_replication_tasks_ids.add(replication_task.id)
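
Each part of a task is attempted up to replication_task.retries times with an exponential backoff capped at 60 seconds, and the for/else fires only when every attempt raised a RecoverableReplicationError. A self-contained sketch of the same retry pattern, using hypothetical names (run_with_retries, RecoverableError, attempt):

import time


class RecoverableError(Exception):
    """Hypothetical stand-in for RecoverableReplicationError."""


def run_with_retries(attempt, retries, logger):
    # Call attempt() up to `retries` times (assumed >= 1), doubling the pause
    # after each recoverable failure: 1s, 2s, 4s, ... capped at 60s.
    sleep = 1
    last_error = None
    for i in range(retries):
        if last_error is not None:
            logger.info("After recoverable error sleeping for %d seconds", sleep)
            time.sleep(sleep)
            sleep = min(sleep * 2, 60)
        try:
            return attempt()
        except RecoverableError as e:
            logger.warning("Attempt %d failed with recoverable error %r", i + 1, e)
            last_error = e
    raise last_error  # every attempt failed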
Example #4
def run_replication_tasks(local_shell: LocalShell,
                          transport: Transport,
                          remote_shell: Shell,
                          replication_tasks: [ReplicationTask],
                          observer=None):
    contexts = defaultdict(GlobalReplicationContext)

    replication_tasks_parts = calculate_replication_tasks_parts(
        replication_tasks)

    started_replication_tasks_ids = set()
    failed_replication_tasks_ids = set()
    replication_tasks_parts_left = {
        replication_task.id: len([
            1 for another_replication_task, source_dataset in
            replication_tasks_parts
            if another_replication_task == replication_task
        ])
        for replication_task in replication_tasks
    }
    for replication_task, source_dataset in replication_tasks_parts:
        if replication_task.id in failed_replication_tasks_ids:
            continue

        local_context = ReplicationContext(contexts[replication_task], None,
                                           local_shell)
        remote_context = ReplicationContext(contexts[replication_task],
                                            transport, remote_shell)

        if replication_task.direction == ReplicationDirection.PUSH:
            src_context = local_context
            dst_context = remote_context
        elif replication_task.direction == ReplicationDirection.PULL:
            src_context = remote_context
            dst_context = local_context
        else:
            raise ValueError(
                f"Invalid replication direction: {replication_task.direction!r}"
            )

        if replication_task.id not in started_replication_tasks_ids:
            notify(observer, ReplicationTaskStart(replication_task.id))
            started_replication_tasks_ids.add(replication_task.id)
        recoverable_error = None
        recoverable_sleep = 1
        for i in range(replication_task.retries):
            if recoverable_error is not None:
                logger.info("After recoverable error sleeping for %d seconds",
                            recoverable_sleep)
                time.sleep(recoverable_sleep)
                recoverable_sleep = min(recoverable_sleep * 2, 60)
            else:
                recoverable_sleep = 1

            try:
                try:
                    run_replication_task_part(replication_task, source_dataset,
                                              src_context, dst_context,
                                              observer)
                except socket.timeout:
                    raise RecoverableReplicationError(
                        "Network connection timeout") from None
                except paramiko.ssh_exception.NoValidConnectionsError as e:
                    raise RecoverableReplicationError(
                        str(e).replace("[Errno None] ", "")) from None
                except paramiko.ssh_exception.SSHException as e:
                    if isinstance(
                            e, (paramiko.ssh_exception.AuthenticationException,
                                paramiko.ssh_exception.BadHostKeyException,
                                paramiko.ssh_exception.ProxyCommandFailure,
                                paramiko.ssh_exception.ConfigParseError)):
                        # Authentication and configuration failures will not go away on retry
                        raise ReplicationError(
                            str(e).replace("[Errno None] ", "")) from None
                    else:
                        # It might be an SSH error that leaves paramiko connection in an invalid state
                        # Let's reset remote shell just in case
                        remote_shell.close()
                        raise RecoverableReplicationError(
                            str(e).replace("[Errno None] ", "")) from None
                except ExecException as e:
                    if e.returncode == 128 + signal.SIGPIPE:
                        for warning in warnings_from_zfs_success(e.stdout):
                            contexts[replication_task].add_warning(warning)
                        raise RecoverableReplicationError(
                            broken_pipe_error(e.stdout))
                    else:
                        raise
                except (IOError, OSError) as e:
                    raise RecoverableReplicationError(str(e)) from None
                replication_tasks_parts_left[replication_task.id] -= 1
                if replication_tasks_parts_left[replication_task.id] == 0:
                    notify(
                        observer,
                        ReplicationTaskSuccess(
                            replication_task.id,
                            contexts[replication_task].warnings))
                break
            except RecoverableReplicationError as e:
                logger.warning(
                    "For task %r at attempt %d recoverable replication error %r",
                    replication_task.id, i + 1, e)
                recoverable_error = e
            except ReplicationError as e:
                logger.error(
                    "For task %r non-recoverable replication error %r",
                    replication_task.id, e)
                notify(observer,
                       ReplicationTaskError(replication_task.id, str(e)))
                failed_replication_tasks_ids.add(replication_task.id)
                break
            except Exception as e:
                logger.error("For task %r unhandled replication error %r",
                             replication_task.id,
                             e,
                             exc_info=True)
                notify(observer,
                       ReplicationTaskError(replication_task.id, str(e)))
                failed_replication_tasks_ids.add(replication_task.id)
                break
        else:
            logger.error("Failed replication task %r after %d retries",
                         replication_task.id, replication_task.retries)
            notify(
                observer,
                ReplicationTaskError(replication_task.id,
                                     str(recoverable_error)))
            failed_replication_tasks_ids.add(replication_task.id)
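
Compared to Example #3, this version also collects ZFS warnings into a per-task GlobalReplicationContext and splits SSH failures into recoverable and non-recoverable ones. The notify(observer, message) helper that all of these listings call is assumed here to be a thin guard around an optional callable; a hedged sketch (the real signature in zettarepl may differ):

import logging

logger = logging.getLogger(__name__)


def notify(observer, message):
    # Deliver an observer message such as ReplicationTaskStart, ReplicationTaskSuccess
    # or ReplicationTaskError, never letting an observer failure break the replication loop.
    if observer is None:
        return
    try:
        observer(message)
    except Exception:
        logger.error("Unhandled exception in observer %r", observer, exc_info=True)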