Example #1
class PGHoard:
    def __init__(self, config_path):
        self.metrics = None
        self.log = logging.getLogger("pghoard")
        self.log_level = None
        self.running = True
        self.config_path = config_path
        self.compression_queue = Queue()
        self.transfer_queue = Queue()
        self.syslog_handler = None
        self.basebackups = {}
        self.basebackups_callbacks = {}
        self.receivexlogs = {}
        self.compressors = []
        self.walreceivers = {}
        self.transfer_agents = []
        self.config = {}
        self.mp_manager = None
        self.site_transfers = {}
        self.state = {
            "backup_sites": {},
            "startup_time": datetime.datetime.utcnow().isoformat(),
        }
        self.transfer_agent_state = {}  # shared among transfer agents
        self.load_config()
        if self.config["transfer"]["thread_count"] > 1:
            self.mp_manager = multiprocessing.Manager()

        if not os.path.exists(self.config["backup_location"]):
            os.makedirs(self.config["backup_location"])

        # Read transfer_agent_state from state file if available so that there's no disruption
        # in the metrics we send out as a result of process restart
        state_file_path = self.config["json_state_file_path"]
        if os.path.exists(state_file_path):
            with open(state_file_path, "r") as fp:
                state = json.load(fp)
                self.transfer_agent_state = state.get(
                    "transfer_agent_state") or {}

        signal.signal(signal.SIGHUP, self.load_config)
        signal.signal(signal.SIGINT, self.quit)
        signal.signal(signal.SIGTERM, self.quit)
        self.time_of_last_backup_check = {}
        self.requested_basebackup_sites = set()

        self.inotify = InotifyWatcher(self.compression_queue)
        self.webserver = WebServer(self.config,
                                   self.requested_basebackup_sites,
                                   self.compression_queue, self.transfer_queue,
                                   self.metrics)

        for _ in range(self.config["compression"]["thread_count"]):
            compressor = CompressorThread(
                config_dict=self.config,
                compression_queue=self.compression_queue,
                transfer_queue=self.transfer_queue,
                metrics=self.metrics)
            self.compressors.append(compressor)

        for _ in range(self.config["transfer"]["thread_count"]):
            ta = TransferAgent(config=self.config,
                               compression_queue=self.compression_queue,
                               mp_manager=self.mp_manager,
                               transfer_queue=self.transfer_queue,
                               metrics=self.metrics,
                               shared_state_dict=self.transfer_agent_state)
            self.transfer_agents.append(ta)

        logutil.notify_systemd("READY=1")
        self.log.info("pghoard initialized, own_hostname: %r, cwd: %r",
                      socket.gethostname(), os.getcwd())

    def check_pg_versions_ok(self, site, pg_version_server, command):
        if pg_version_server is None:
            # remote pg version not available, don't create version alert in this case
            return False
        if not pg_version_server:
            self.log.error(
                "pghoard does not support versions earlier than 9.3, found: %r",
                pg_version_server)
            create_alert_file(self.config, "version_unsupported_error")
            return False
        pg_version_client = self.config["backup_sites"][site][command +
                                                              "_version"]
        if pg_version_server // 100 != pg_version_client // 100:
            self.log.error("Server version: %r does not match %s version: %r",
                           pg_version_server, self.config[command + "_path"],
                           pg_version_client)
            create_alert_file(self.config, "version_mismatch_error")
            return False
        return True

    def create_basebackup(self,
                          site,
                          connection_info,
                          basebackup_path,
                          callback_queue=None,
                          metadata=None):
        connection_string, _ = replication_connection_string_and_slot_using_pgpass(
            connection_info)
        pg_version_server = self.check_pg_server_version(
            connection_string, site)
        if not self.check_pg_versions_ok(site, pg_version_server,
                                         "pg_basebackup"):
            if callback_queue:
                callback_queue.put({"success": False})
            return

        thread = PGBaseBackup(
            config=self.config,
            site=site,
            connection_info=connection_info,
            basebackup_path=basebackup_path,
            compression_queue=self.compression_queue,
            transfer_queue=self.transfer_queue,
            callback_queue=callback_queue,
            pg_version_server=pg_version_server,
            metrics=self.metrics,
            metadata=metadata,
        )
        thread.start()
        self.basebackups[site] = thread

    def check_pg_server_version(self, connection_string, site):
        if "pg_version" in self.config["backup_sites"][site]:
            return self.config["backup_sites"][site]["pg_version"]

        pg_version = None
        try:
            with closing(psycopg2.connect(connection_string)) as c:
                pg_version = c.server_version  # pylint: disable=no-member
                # Cache pg_version so we don't have to query it again, note that this means that for major
                # version upgrades you want to restart pghoard.
                self.config["backup_sites"][site]["pg_version"] = pg_version
        except psycopg2.OperationalError as ex:
            self.log.warning("%s (%s) connecting to DB at: %r",
                             ex.__class__.__name__, ex, connection_string)
            if "password authentication" in str(
                    ex) or "authentication failed" in str(ex):
                create_alert_file(self.config, "authentication_error")
            else:
                create_alert_file(self.config, "configuration_error")
        except Exception as ex:  # log all errors and return None; pylint: disable=broad-except
            self.log.exception("Problem in getting PG server version")
            self.metrics.unexpected_exception(ex,
                                              where="check_pg_server_version")
        return pg_version

    def receivexlog_listener(self, site, connection_info, wal_directory):
        connection_string, slot = replication_connection_string_and_slot_using_pgpass(
            connection_info)
        pg_version_server = self.check_pg_server_version(
            connection_string, site)
        if not self.check_pg_versions_ok(site, pg_version_server,
                                         "pg_receivexlog"):
            return

        self.inotify.add_watch(wal_directory)
        thread = PGReceiveXLog(config=self.config,
                               connection_string=connection_string,
                               wal_location=wal_directory,
                               site=site,
                               slot=slot,
                               pg_version_server=pg_version_server)
        thread.start()
        self.receivexlogs[site] = thread

    def start_walreceiver(self, site, chosen_backup_node, last_flushed_lsn):
        connection_string, slot = replication_connection_string_and_slot_using_pgpass(
            chosen_backup_node)
        pg_version_server = self.check_pg_server_version(
            connection_string, site)
        if not WALReceiver:
            self.log.error(
                "Could not import WALReceiver, incorrect psycopg2 version?")
            return

        thread = WALReceiver(config=self.config,
                             connection_string=connection_string,
                             compression_queue=self.compression_queue,
                             replication_slot=slot,
                             pg_version_server=pg_version_server,
                             site=site,
                             last_flushed_lsn=last_flushed_lsn,
                             metrics=self.metrics)
        thread.start()
        self.walreceivers[site] = thread

    def create_backup_site_paths(self, site):
        site_path = os.path.join(self.config["backup_location"],
                                 self.config["backup_sites"][site]["prefix"])
        xlog_path = os.path.join(site_path, "xlog")
        basebackup_path = os.path.join(site_path, "basebackup")

        paths_to_create = [
            site_path,
            xlog_path,
            xlog_path + "_incoming",
            basebackup_path,
            basebackup_path + "_incoming",
        ]
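        # Resulting layout under backup_location for a site with prefix <prefix>
        # (illustrative): <prefix>/xlog, <prefix>/xlog_incoming,
        # <prefix>/basebackup and <prefix>/basebackup_incoming.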

        for path in paths_to_create:
            if not os.path.exists(path):
                os.makedirs(path)

        return xlog_path, basebackup_path

    def delete_remote_wal_before(self, wal_segment, site, pg_version):
        self.log.info(
            "Starting WAL deletion from: %r before: %r, pg_version: %r", site,
            wal_segment, pg_version)
        storage = self.site_transfers.get(site)
        valid_timeline = True
        tli, log, seg = wal.name_to_tli_log_seg(wal_segment)
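        # WAL segment names are 24 hex digits: timeline, log and segment number,
        # 8 digits each. E.g. (illustrative) "0000000200000001000000AB" parses to
        # tli=2, log=1, seg=0xAB; the loop below walks these values backwards.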
        while True:
            if valid_timeline:
                # Decrement one segment if we're on a valid timeline
                if seg == 0 and log == 0:
                    break
                seg, log = wal.get_previous_wal_on_same_timeline(
                    seg, log, pg_version)
            wal_path = os.path.join(
                self.config["backup_sites"][site]["prefix"], "xlog",
                wal.name_for_tli_log_seg(tli, log, seg))
            self.log.debug("Deleting wal_file: %r", wal_path)
            try:
                storage.delete_key(wal_path)
                valid_timeline = True
            except FileNotFoundFromStorageError:
                if not valid_timeline or tli <= 1:
                    # if we didn't find any WALs to delete on this timeline or we're already at
                    # timeline 1 there's no need or possibility to try older timelines, break.
                    self.log.info("Could not delete wal_file: %r, returning",
                                  wal_path)
                    break
                # let's try the same segment number on a previous timeline, but flag that timeline
                # as "invalid" until we're able to delete at least one segment on it.
                valid_timeline = False
                tli -= 1
                self.log.info(
                    "Could not delete wal_file: %r, trying the same segment on a previous "
                    "timeline (%s)", wal_path,
                    wal.name_for_tli_log_seg(tli, log, seg))
            except Exception as ex:  # FIXME: don't catch all exceptions; pylint: disable=broad-except
                self.log.exception("Problem deleting: %r", wal_path)
                self.metrics.unexpected_exception(
                    ex, where="delete_remote_wal_before")

    def delete_remote_basebackup(self, site, basebackup, metadata):
        start_time = time.monotonic()
        storage = self.site_transfers.get(site)
        main_backup_key = os.path.join(
            self.config["backup_sites"][site]["prefix"], "basebackup",
            basebackup)
        basebackup_data_files = [main_backup_key]

        if metadata.get("format") == "pghoard-bb-v2":
            bmeta_compressed = storage.get_contents_to_string(
                main_backup_key)[0]
            with rohmufile.file_reader(fileobj=io.BytesIO(bmeta_compressed),
                                       metadata=metadata,
                                       key_lookup=config.key_lookup_for_site(
                                           self.config, site)) as input_obj:
                bmeta = extract_pghoard_bb_v2_metadata(input_obj)
                self.log.debug("PGHoard chunk metadata: %r", bmeta)
                for chunk in bmeta["chunks"]:
                    basebackup_data_files.append(
                        os.path.join(
                            self.config["backup_sites"][site]["prefix"],
                            "basebackup_chunk",
                            chunk["chunk_filename"],
                        ))

        self.log.debug("Deleting basebackup datafiles: %r",
                       ', '.join(basebackup_data_files))
        for obj_key in basebackup_data_files:
            try:
                storage.delete_key(obj_key)
            except FileNotFoundFromStorageError:
                self.log.info("Tried to delete non-existent basebackup %r",
                              obj_key)
            except Exception as ex:  # FIXME: don't catch all exceptions; pylint: disable=broad-except
                self.log.exception("Problem deleting: %r", obj_key)
                self.metrics.unexpected_exception(
                    ex, where="delete_remote_basebackup")
        self.log.info("Deleted basebackup datafiles: %r, took: %.2fs",
                      ', '.join(basebackup_data_files),
                      time.monotonic() - start_time)

    def get_remote_basebackups_info(self, site):
        storage = self.site_transfers.get(site)
        if not storage:
            storage_config = get_object_storage_config(self.config, site)
            storage = get_transfer(storage_config)
            self.site_transfers[site] = storage

        site_config = self.config["backup_sites"][site]
        results = storage.list_path(
            os.path.join(site_config["prefix"], "basebackup"))
        for entry in results:
            self.patch_basebackup_info(entry=entry, site_config=site_config)

        results.sort(key=lambda entry: entry["metadata"]["start-time"])
        return results

    def patch_basebackup_info(self, *, entry, site_config):
        # drop path from resulting list and convert timestamps
        entry["name"] = os.path.basename(entry["name"])
        metadata = entry["metadata"]
        metadata["start-time"] = dates.parse_timestamp(metadata["start-time"])
        # If backup was created by old PGHoard version some fields related to backup scheduling might be missing.
        # Set "best guess" values for those fields here to simplify logic elsewhere.
        if "backup-decision-time" in metadata:
            metadata["backup-decision-time"] = dates.parse_timestamp(
                metadata["backup-decision-time"])
        else:
            metadata["backup-decision-time"] = metadata["start-time"]
        # Backups are usually scheduled
        if "backup-reason" not in metadata:
            metadata["backup-reason"] = "scheduled"
        # Calculate normalized backup time based on start time if missing
        if "normalized-backup-time" not in metadata:
            metadata["normalized-backup-time"] = self.get_normalized_backup_time(
                site_config, now=metadata["start-time"])

    def determine_backups_to_delete(self, *, basebackups, site_config):
        """Returns the basebackups in the given list that need to be deleted based on the given site configuration.
        Note that `basebackups` is edited in place: any basebackups that need to be deleted are removed from it."""
        allowed_basebackup_count = site_config["basebackup_count"]
        if allowed_basebackup_count is None:
            allowed_basebackup_count = len(basebackups)

        basebackups_to_delete = []
        while len(basebackups) > allowed_basebackup_count:
            self.log.warning(
                "Too many basebackups: %d > %d, %r, starting to get rid of %r",
                len(basebackups), allowed_basebackup_count, basebackups,
                basebackups[0]["name"])
            basebackups_to_delete.append(basebackups.pop(0))

        backup_interval = datetime.timedelta(
            hours=site_config["basebackup_interval_hours"])
        min_backups = site_config["basebackup_count_min"]
        max_age_days = site_config.get("basebackup_age_days_max")
        current_time = datetime.datetime.now(datetime.timezone.utc)
        if max_age_days and min_backups > 0:
            while basebackups and len(basebackups) > min_backups:
                # For age checks we treat the age as current_time - (backup_start_time + backup_interval). So when
                # backup interval is set to 24 hours a backup started 2.5 days ago would be considered to be 1.5 days old.
                completed_at = basebackups[0]["metadata"][
                    "start-time"] + backup_interval
                backup_age = current_time - completed_at
                # timedelta would have direct `days` attribute but that's an integer rounded down. We want a float
                # so that we can react immediately when age is too old
                backup_age_days = backup_age.total_seconds() / 60.0 / 60.0 / 24.0
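                # e.g. (illustrative) a 36 hour backup_age yields 1.5 days here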
                if backup_age_days > max_age_days:
                    self.log.warning(
                        "Basebackup %r too old: %.3f > %.3f, %r, starting to get rid of it",
                        basebackups[0]["name"], backup_age_days, max_age_days,
                        basebackups)
                    basebackups_to_delete.append(basebackups.pop(0))
                else:
                    break

        return basebackups_to_delete

    def refresh_backup_list_and_delete_old(self, site):
        """Look up basebackups from the object store, prune any extra
        backups and return the datetime of the latest backup."""
        basebackups = self.get_remote_basebackups_info(site)
        self.log.debug("Found %r basebackups", basebackups)

        site_config = self.config["backup_sites"][site]
        # Never delete backups from a recovery site. This check is already elsewhere as well
        # but still check explicitly here to ensure we certainly won't delete anything unexpectedly
        if site_config["active"]:
            basebackups_to_delete = self.determine_backups_to_delete(
                basebackups=basebackups, site_config=site_config)

            for basebackup_to_be_deleted in basebackups_to_delete:
                pg_version = basebackup_to_be_deleted["metadata"].get(
                    "pg-version")
                last_wal_segment_still_needed = 0
                if basebackups:
                    last_wal_segment_still_needed = basebackups[0]["metadata"][
                        "start-wal-segment"]

                if last_wal_segment_still_needed:
                    self.delete_remote_wal_before(
                        last_wal_segment_still_needed, site, pg_version)
                self.delete_remote_basebackup(
                    site, basebackup_to_be_deleted["name"],
                    basebackup_to_be_deleted["metadata"])
        self.state["backup_sites"][site]["basebackups"] = basebackups

    def get_normalized_backup_time(self, site_config, *, now=None):
        """Returns the closest historical backup time that current time matches to (or current time if it matches).
        E.g. if backup hour is 13, backup minute is 50, current time is 15:40 and backup interval is 60 minutes,
        the return value is 14:50 today. If backup hour and minute are as before, backup interval is 1440 minutes
        (24 hours) and current time is 13:45, the return value is 13:50 yesterday."""
        backup_hour = site_config.get("basebackup_hour")
        backup_minute = site_config.get("basebackup_minute")
        backup_interval_hours = site_config.get("basebackup_interval_hours")
        if backup_hour is None or backup_minute is None or backup_interval_hours is None:
            return None

        if not now:
            now = datetime.datetime.now(datetime.timezone.utc)
        normalized = now
        if normalized.hour < backup_hour or (
                normalized.hour == backup_hour
                and normalized.minute < backup_minute):
            normalized = normalized - datetime.timedelta(days=1)
        normalized = normalized.replace(hour=backup_hour,
                                        minute=backup_minute,
                                        second=0,
                                        microsecond=0)
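        # Trace of the docstring's first example: now=15:40, hour=13, minute=50
        # gives normalized=13:50 today; the loop below then advances by whole
        # intervals while staying in the past, ending at 14:50 for a 60 minute
        # interval.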
        while normalized + datetime.timedelta(
                hours=backup_interval_hours) < now:
            normalized = normalized + datetime.timedelta(
                hours=backup_interval_hours)
        return normalized.isoformat()

    def set_state_defaults(self, site):
        if site not in self.state["backup_sites"]:
            self.state["backup_sites"][site] = {"basebackups": []}

    def startup_walk_for_missed_files(self):
        """Check xlog and xlog_incoming directories for files that receivexlog has received but not yet
        compressed as well as the files we have compressed but not yet uploaded and process them."""
        for site in self.config["backup_sites"]:
            compressed_xlog_path, _ = self.create_backup_site_paths(site)
            uncompressed_xlog_path = compressed_xlog_path + "_incoming"

            # Process uncompressed files (ie WAL pg_receivexlog received)
            for filename in os.listdir(uncompressed_xlog_path):
                full_path = os.path.join(uncompressed_xlog_path, filename)
                if wal.PARTIAL_WAL_RE.match(filename):
                    # pg_receivewal may have been in the middle of storing WAL file when PGHoard was stopped.
                    # If the file is 0 or 16 MiB in size it will continue normally but in some cases the file can be
                    # incomplete causing pg_receivewal to halt processing. Truncating the file to zero bytes correctly
                    # makes it continue streaming from the beginning of that segment.
                    file_size = os.stat(full_path).st_size
                    if file_size in {0, wal.WAL_SEG_SIZE}:
                        self.log.info("Found partial file %r, size %d bytes",
                                      full_path, file_size)
                    else:
                        self.log.warning(
                            "Found partial file %r with unexpected size %d, truncating to zero bytes",
                            full_path, file_size)
                        # Make a copy of the file for safekeeping. The data should still be available on PG
                        # side but just in case it isn't the incomplete segment could still be relevant for
                        # manual processing later
                        shutil.copyfile(full_path, full_path + "_incomplete")
                        self.metrics.increase(
                            "pghoard.incomplete_partial_wal_segment")
                        os.truncate(full_path, 0)
                    continue

                if not wal.WAL_RE.match(filename) and not wal.TIMELINE_RE.match(filename):
                    self.log.warning(
                        "Found invalid file %r from incoming xlog directory",
                        full_path)
                    continue

                compression_event = {
                    "delete_file_after_compression": True,
                    "full_path": full_path,
                    "site": site,
                    "src_path": "{}.partial",
                    "type": "MOVE",
                }
                self.log.debug(
                    "Found: %r when starting up, adding to compression queue",
                    compression_event)
                self.compression_queue.put(compression_event)

            # Process compressed files (ie things we've processed but not yet uploaded)
            for filename in os.listdir(compressed_xlog_path):
                if filename.endswith(".metadata"):
                    continue  # silently ignore .metadata files, they're expected and processed below
                full_path = os.path.join(compressed_xlog_path, filename)
                metadata_path = full_path + ".metadata"
                is_xlog = wal.WAL_RE.match(filename)
                is_timeline = wal.TIMELINE_RE.match(filename)
                if not ((is_xlog or is_timeline)
                        and os.path.exists(metadata_path)):
                    self.log.warning(
                        "Found invalid file %r from compressed xlog directory",
                        full_path)
                    continue
                with open(metadata_path, "r") as fp:
                    metadata = json.load(fp)

                transfer_event = {
                    "file_size": os.path.getsize(full_path),
                    "filetype": "xlog" if is_xlog else "timeline",
                    "local_path": full_path,
                    "metadata": metadata,
                    "site": site,
                    "type": "UPLOAD",
                }
                self.log.debug(
                    "Found: %r when starting up, adding to transfer queue",
                    transfer_event)
                self.transfer_queue.put(transfer_event)

    def start_threads_on_startup(self):
        # Startup threads
        self.inotify.start()
        self.webserver.start()
        for compressor in self.compressors:
            compressor.start()
        for ta in self.transfer_agents:
            ta.start()

    def _cleanup_inactive_receivexlogs(self, site):
        if site in self.receivexlogs:
            if not self.receivexlogs[site].running:
                if self.receivexlogs[site].is_alive():
                    self.receivexlogs[site].join()
                del self.receivexlogs[site]

    def handle_site(self, site, site_config):
        self.set_state_defaults(site)
        xlog_path, basebackup_path = self.create_backup_site_paths(site)

        if not site_config["active"]:
            return  # If a site has been marked inactive, don't bother checking anything

        self._cleanup_inactive_receivexlogs(site)

        chosen_backup_node = random.choice(site_config["nodes"])

        if site not in self.receivexlogs and site not in self.walreceivers:
            if site_config["active_backup_mode"] == "pg_receivexlog":
                self.receivexlog_listener(site, chosen_backup_node,
                                          xlog_path + "_incoming")
            elif site_config["active_backup_mode"] == "walreceiver":
                state_file_path = self.config["json_state_file_path"]
                walreceiver_state = {}
                with suppress(FileNotFoundError):
                    with open(state_file_path, "r") as fp:
                        old_state_file = json.load(fp)
                        walreceiver_state = old_state_file.get(
                            "walreceivers", {}).get(site, {})
                self.start_walreceiver(
                    site=site,
                    chosen_backup_node=chosen_backup_node,
                    last_flushed_lsn=walreceiver_state.get("last_flushed_lsn"))

        last_check_time = self.time_of_last_backup_check.get(site)
        if not last_check_time or (time.monotonic() -
                                   self.time_of_last_backup_check[site]) > 300:
            self.refresh_backup_list_and_delete_old(site)
            self.time_of_last_backup_check[site] = time.monotonic()

        # check if a basebackup is running, or if a basebackup has just completed
        if site in self.basebackups:
            try:
                result = self.basebackups_callbacks[site].get(block=False)
            except Empty:
                # previous basebackup (or its compression and upload) still in progress
                return
            if self.basebackups[site].is_alive():
                self.basebackups[site].join()
            del self.basebackups[site]
            del self.basebackups_callbacks[site]
            self.log.debug("Basebackup has finished for %r: %r", site, result)
            self.refresh_backup_list_and_delete_old(site)
            self.time_of_last_backup_check[site] = time.monotonic()

        metadata = self.get_new_backup_details(site=site,
                                               site_config=site_config)
        if metadata and not os.path.exists(self.config["maintenance_mode_file"]):
            self.basebackups_callbacks[site] = Queue()
            self.create_basebackup(site, chosen_backup_node, basebackup_path,
                                   self.basebackups_callbacks[site], metadata)

    def get_new_backup_details(self, *, now=None, site, site_config):
        """Returns metadata to associate with new backup that needs to be created or None in case no backup should
        be created at this time"""
        if not now:
            now = datetime.datetime.now(datetime.timezone.utc)
        basebackups = self.state["backup_sites"][site]["basebackups"]
        backup_hour = site_config.get("basebackup_hour")
        backup_minute = site_config.get("basebackup_minute")
        backup_reason = None
        normalized_backup_time = self.get_normalized_backup_time(site_config,
                                                                 now=now)

        if site in self.requested_basebackup_sites:
            self.log.info("Creating a new basebackup for %r due to request",
                          site)
            self.requested_basebackup_sites.discard(site)
            backup_reason = "requested"
        elif site_config["basebackup_interval_hours"] is None:
            # Basebackups are disabled for this site (but they can still be requested over the API.)
            pass
        elif not basebackups:
            self.log.info(
                "Creating a new basebackup for %r because there are currently none",
                site)
            backup_reason = "scheduled"
        elif backup_hour is not None and backup_minute is not None:
            most_recent_scheduled = None
            last_normalized_backup_time = basebackups[-1]["metadata"][
                "normalized-backup-time"]
            scheduled_backups = [
                backup for backup in basebackups
                if backup["metadata"]["backup-reason"] == "scheduled"
            ]
            if scheduled_backups:
                most_recent_scheduled = scheduled_backups[-1]["metadata"][
                    "backup-decision-time"]

            # Don't create new backup unless at least half of interval has elapsed since scheduled last backup. Otherwise
            # we would end up creating a new backup each time when backup hour/minute changes, which is typically undesired.
            # With the "half of interval" check the backup time will quickly drift towards the selected time without backup
            # spamming in case of repeated setting changes.
            delta = datetime.timedelta(
                hours=site_config["basebackup_interval_hours"] / 2)
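            # Illustrative: with basebackup_interval_hours = 24 this delta is 12
            # hours, so a changed backup hour/minute only triggers a new backup
            # once the last scheduled backup is at least 12 hours old.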
            normalized_time_changed = (last_normalized_backup_time !=
                                       normalized_backup_time)
            last_scheduled_isnt_too_recent = (
                not most_recent_scheduled
                or most_recent_scheduled + delta <= now)
            if normalized_time_changed and last_scheduled_isnt_too_recent:
                self.log.info(
                    "Normalized backup time %r differs from previous %r, creating new basebackup",
                    normalized_backup_time, last_normalized_backup_time)
                backup_reason = "scheduled"
        else:
            # No backup schedule defined, create new backup if backup interval hours has passed since last backup
            time_of_last_backup = basebackups[-1]["metadata"]["start-time"]
            delta_since_last_backup = now - time_of_last_backup
            if delta_since_last_backup >= datetime.timedelta(
                    hours=site_config["basebackup_interval_hours"]):
                self.log.info(
                    "Creating a new basebackup for %r by schedule (%s from previous)",
                    site, delta_since_last_backup)
                backup_reason = "scheduled"

        if not backup_reason:
            return None

        return {
            # The time when it was decided that a new backup should be taken. This is usually almost the same as
            # start-time but if taking the backup gets delayed for any reason this is more accurate for deciding
            # when next backup should be taken
            "backup-decision-time": now.isoformat(),
            # Whether this backup was taken due to schedule or explicit request. Affects scheduling of next backup
            # (explicitly requested backups don't affect the schedule)
            "backup-reason": backup_reason,
            # The closest backup schedule time this backup matches to (if schedule has been defined)
            "normalized-backup-time": normalized_backup_time,
        }

    def run(self):
        self.start_threads_on_startup()
        self.startup_walk_for_missed_files()
        while self.running:
            try:
                for site, site_config in self.config["backup_sites"].items():
                    self.handle_site(site, site_config)
                self.write_backup_state_to_json_file()
            except subprocess.CalledProcessError as ex:
                self.log.error("main loop: %s: %s, retrying...",
                               ex.__class__.__name__, ex)
            except Exception as ex:  # pylint: disable=broad-except
                self.log.exception("Unexpected exception in PGHoard main loop")
                self.metrics.unexpected_exception(ex, where="pghoard_run")
            time.sleep(5.0)

    def write_backup_state_to_json_file(self):
        """Periodically write a JSON state file to disk"""
        start_time = time.time()
        state_file_path = self.config["json_state_file_path"]
        self.state["walreceivers"] = {
            key: {
                "latest_activity": value.latest_activity,
                "running": value.running,
                "last_flushed_lsn": value.last_flushed_lsn
            }
            for key, value in self.walreceivers.items()
        }
        self.state["pg_receivexlogs"] = {
            key: {
                "latest_activity": value.latest_activity,
                "running": value.running
            }
            for key, value in self.receivexlogs.items()
        }
        self.state["pg_basebackups"] = {
            key: {
                "latest_activity": value.latest_activity,
                "running": value.running
            }
            for key, value in self.basebackups.items()
        }
        self.state["compressors"] = [
            compressor.state for compressor in self.compressors
        ]
        # All transfer agents share the same state, no point in writing it multiple times
        self.state["transfer_agent_state"] = self.transfer_agent_state
        self.state["queues"] = {
            "compression_queue": self.compression_queue.qsize(),
            "transfer_queue": self.transfer_queue.qsize(),
        }
        self.state["served_files"] = (
            self.webserver.get_most_recently_served_files() if self.webserver else {})
        self.log.debug("Writing JSON state file to %r", state_file_path)
        write_json_file(state_file_path, self.state)
        self.log.debug("Wrote JSON state file to disk, took %.4fs",
                       time.time() - start_time)

    def load_config(self, _signal=None, _frame=None):  # pylint: disable=unused-argument
        self.log.debug("Loading JSON config from: %r, signal: %r",
                       self.config_path, _signal)
        try:
            new_config = config.read_json_config_file(self.config_path)
        except (InvalidConfigurationError, subprocess.CalledProcessError,
                UnicodeDecodeError) as ex:
            self.log.exception("Invalid config file %r: %s: %s",
                               self.config_path, ex.__class__.__name__, ex)
            # if we were called by a signal handler we'll ignore (and log)
            # the error and hope the user fixes the configuration before
            # restarting pghoard.
            if _signal is not None:
                return
            if isinstance(ex, InvalidConfigurationError):
                raise
            raise InvalidConfigurationError(self.config_path)

        self.config = new_config
        if self.config.get("syslog") and not self.syslog_handler:
            self.syslog_handler = logutil.set_syslog_handler(
                address=self.config.get("syslog_address", "/dev/log"),
                facility=self.config.get("syslog_facility", "local2"),
                logger=logging.getLogger(),
            )
        # NOTE: getLevelName() also converts level names to numbers
        self.log_level = logging.getLevelName(self.config["log_level"])
        try:
            logging.getLogger().setLevel(self.log_level)
        except ValueError:
            self.log.exception("Problem with log_level: %r", self.log_level)

        # Setup monitoring clients
        self.metrics = metrics.Metrics(
            statsd=self.config.get("statsd", None),
            pushgateway=self.config.get("pushgateway", None),
            prometheus=self.config.get("prometheus", None))

        for thread in self._get_all_threads():
            thread.config = new_config
            thread.site_transfers = {}

        self.log.debug("Loaded config: %r from: %r", self.config,
                       self.config_path)

    def _get_all_threads(self):
        all_threads = []
        if hasattr(self, "webserver"
                   ):  # on first config load webserver isn't initialized yet
            all_threads.append(self.webserver)
        all_threads.extend(self.basebackups.values())
        all_threads.extend(self.receivexlogs.values())
        all_threads.extend(self.walreceivers.values())
        all_threads.extend(self.compressors)
        all_threads.extend(self.transfer_agents)
        return all_threads

    def quit(self, _signal=None, _frame=None):  # pylint: disable=unused-argument
        self.log.warning("Quitting, signal: %r", _signal)
        self.running = False
        self.inotify.running = False
        all_threads = self._get_all_threads()
        for t in all_threads:
            t.running = False
        # Write state file in the end so we get the last known state
        self.write_backup_state_to_json_file()
        for t in all_threads:
            if t.is_alive():
                t.join()
        if self.mp_manager:
            self.mp_manager.shutdown()
            self.mp_manager = None
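
A minimal driver sketch (hypothetical; it assumes only the constructor and run()
shown above, while the real pghoard command-line entry point also handles argument
parsing and logging setup):

if __name__ == "__main__":
    # Hypothetical usage: point PGHoard at a JSON configuration file and block in
    # the main loop; SIGINT/SIGTERM invoke quit() via the handlers set in __init__.
    pghoard = PGHoard("/var/lib/pghoard/pghoard.json")
    pghoard.run()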
Example #2
class PGHoard:
    def __init__(self, config_path):
        self.stats = None
        self.log = logging.getLogger("pghoard")
        self.log_level = None
        self.running = True
        self.config_path = config_path
        self.compression_queue = Queue()
        self.transfer_queue = Queue()
        self.syslog_handler = None
        self.config = {}
        self.site_transfers = {}
        self.state = {
            "backup_sites": {},
            "startup_time": datetime.datetime.utcnow().isoformat(),
        }
        self.load_config()

        if not os.path.exists(self.config["backup_location"]):
            os.makedirs(self.config["backup_location"])

        signal.signal(signal.SIGHUP, self.load_config)
        signal.signal(signal.SIGINT, self.quit)
        signal.signal(signal.SIGTERM, self.quit)
        self.time_of_last_backup = {}
        self.time_of_last_backup_check = {}
        self.basebackups = {}
        self.basebackups_callbacks = {}
        self.receivexlogs = {}
        self.compressors = []
        self.walreceivers = {}
        self.transfer_agents = []
        self.requested_basebackup_sites = set()

        self.inotify = InotifyWatcher(self.compression_queue)
        self.webserver = WebServer(self.config,
                                   self.requested_basebackup_sites,
                                   self.compression_queue, self.transfer_queue)

        for _ in range(self.config["compression"]["thread_count"]):
            compressor = CompressorThread(
                config_dict=self.config,
                compression_queue=self.compression_queue,
                transfer_queue=self.transfer_queue,
                stats=self.stats)
            self.compressors.append(compressor)

        compressor_state = {}  # shared among transfer agents
        for _ in range(self.config["transfer"]["thread_count"]):
            ta = TransferAgent(config=self.config,
                               compression_queue=self.compression_queue,
                               transfer_queue=self.transfer_queue,
                               stats=self.stats,
                               shared_state_dict=compressor_state)
            self.transfer_agents.append(ta)

        logutil.notify_systemd("READY=1")
        self.log.info("pghoard initialized, own_hostname: %r, cwd: %r",
                      socket.gethostname(), os.getcwd())

    def check_pg_versions_ok(self, site, pg_version_server, command):
        if pg_version_server is None:
            # remote pg version not available, don't create version alert in this case
            return False
        if not pg_version_server or pg_version_server <= 90200:
            self.log.error(
                "pghoard does not support versions earlier than 9.2, found: %r",
                pg_version_server)
            create_alert_file(self.config, "version_unsupported_error")
            return False
        pg_version_client = self.config["backup_sites"][site][command +
                                                              "_version"]
        if pg_version_server // 100 != pg_version_client // 100:
            self.log.error("Server version: %r does not match %s version: %r",
                           pg_version_server, self.config[command + "_path"],
                           pg_version_client)
            create_alert_file(self.config, "version_mismatch_error")
            return False
        return True

    def create_basebackup(self,
                          site,
                          connection_info,
                          basebackup_path,
                          callback_queue=None):
        connection_string, _ = replication_connection_string_and_slot_using_pgpass(
            connection_info)
        pg_version_server = self.check_pg_server_version(connection_string)
        if pg_version_server:
            self.config["backup_sites"][site]["pg_version"] = pg_version_server
        if not self.check_pg_versions_ok(site, pg_version_server,
                                         "pg_basebackup"):
            if callback_queue:
                callback_queue.put({"success": False})
            return None

        thread = PGBaseBackup(config=self.config,
                              site=site,
                              connection_info=connection_info,
                              basebackup_path=basebackup_path,
                              compression_queue=self.compression_queue,
                              transfer_queue=self.transfer_queue,
                              callback_queue=callback_queue,
                              pg_version_server=pg_version_server,
                              stats=self.stats)
        thread.start()
        self.basebackups[site] = thread

    def check_pg_server_version(self, connection_string):
        pg_version = None
        try:
            with closing(psycopg2.connect(connection_string)) as c:
                pg_version = c.server_version
        except psycopg2.OperationalError as ex:
            self.log.warning("%s (%s) connecting to DB at: %r",
                             ex.__class__.__name__, ex, connection_string)
            if "password authentication" in str(
                    ex) or "authentication failed" in str(ex):
                create_alert_file(self.config, "authentication_error")
            else:
                create_alert_file(self.config, "configuration_error")
        except Exception as ex:  # log all errors and return None; pylint: disable=broad-except
            self.log.exception("Problem in getting PG server version")
            self.stats.unexpected_exception(ex,
                                            where="check_pg_server_version")
        return pg_version

    def receivexlog_listener(self, site, connection_info, xlog_directory):
        connection_string, slot = replication_connection_string_and_slot_using_pgpass(
            connection_info)
        pg_version_server = self.check_pg_server_version(connection_string)
        if pg_version_server:
            self.config["backup_sites"][site]["pg_version"] = pg_version_server
        if not self.check_pg_versions_ok(site, pg_version_server,
                                         "pg_receivexlog"):
            return

        self.inotify.add_watch(xlog_directory)
        thread = PGReceiveXLog(config=self.config,
                               connection_string=connection_string,
                               xlog_location=xlog_directory,
                               site=site,
                               slot=slot,
                               pg_version_server=pg_version_server)
        thread.start()
        self.receivexlogs[site] = thread

    def start_walreceiver(self, site, chosen_backup_node, last_flushed_lsn):
        connection_string, slot = replication_connection_string_and_slot_using_pgpass(
            chosen_backup_node)
        pg_version_server = self.check_pg_server_version(connection_string)
        if pg_version_server:
            self.config["backup_sites"][site]["pg_version"] = pg_version_server
        if not WALReceiver:
            self.log.error(
                "Could not import WALReceiver, incorrect psycopg2 version?")
            return

        thread = WALReceiver(config=self.config,
                             connection_string=connection_string,
                             compression_queue=self.compression_queue,
                             replication_slot=slot,
                             pg_version_server=pg_version_server,
                             site=site,
                             last_flushed_lsn=last_flushed_lsn,
                             stats=self.stats)
        thread.start()
        self.walreceivers[site] = thread

    def create_backup_site_paths(self, site):
        site_path = os.path.join(self.config["backup_location"],
                                 self.config["path_prefix"], site)
        xlog_path = os.path.join(site_path, "xlog")
        basebackup_path = os.path.join(site_path, "basebackup")

        paths_to_create = [
            site_path,
            xlog_path,
            xlog_path + "_incoming",
            basebackup_path,
            basebackup_path + "_incoming",
        ]

        for path in paths_to_create:
            if not os.path.exists(path):
                os.makedirs(path)

        return xlog_path, basebackup_path

    def delete_remote_wal_before(self, wal_segment, site, pg_version):
        self.log.info(
            "Starting WAL deletion from: %r before: %r, pg_version: %r", site,
            wal_segment, pg_version)
        storage = self.site_transfers.get(site)
        valid_timeline = True
        tli, log, seg = wal.name_to_tli_log_seg(wal_segment)
        while True:
            if valid_timeline:
                # Decrement one segment if we're on a valid timeline
                if seg == 0 and log == 0:
                    break
                seg, log = wal.get_previous_wal_on_same_timeline(
                    seg, log, pg_version)
            wal_path = os.path.join(self.config["path_prefix"], site, "xlog",
                                    wal.name_for_tli_log_seg(tli, log, seg))
            self.log.debug("Deleting wal_file: %r", wal_path)
            try:
                storage.delete_key(wal_path)
                valid_timeline = True
            except FileNotFoundFromStorageError:
                if not valid_timeline or tli <= 1:
                    # if we didn't find any WALs to delete on this timeline or we're already at
                    # timeline 1 there's no need or possibility to try older timelines, break.
                    self.log.info("Could not delete wal_file: %r, returning",
                                  wal_path)
                    break
                # let's try the same segment number on a previous timeline, but flag that timeline
                # as "invalid" until we're able to delete at least one segment on it.
                valid_timeline = False
                tli -= 1
                self.log.info(
                    "Could not delete wal_file: %r, trying the same segment on a previous "
                    "timeline (%s)", wal_path,
                    wal.name_for_tli_log_seg(tli, log, seg))
            except Exception as ex:  # FIXME: don't catch all exceptions; pylint: disable=broad-except
                self.log.exception("Problem deleting: %r", wal_path)
                self.stats.unexpected_exception(
                    ex, where="delete_remote_wal_before")

    def delete_remote_basebackup(self, site, basebackup):
        storage = self.site_transfers.get(site)
        obj_key = os.path.join(self.config["path_prefix"], site, "basebackup",
                               basebackup)
        try:
            storage.delete_key(obj_key)
        except FileNotFoundFromStorageError:
            self.log.info("Tried to delete non-existent basebackup %r",
                          obj_key)
        except Exception as ex:  # FIXME: don't catch all exceptions; pylint: disable=broad-except
            self.log.exception("Problem deleting: %r", obj_key)
            self.stats.unexpected_exception(ex,
                                            where="delete_remote_basebackup")

    def get_remote_basebackups_info(self, site):
        storage = self.site_transfers.get(site)
        if not storage:
            storage_config = get_object_storage_config(self.config, site)
            storage = get_transfer(storage_config)
            self.site_transfers[site] = storage

        results = storage.list_path(
            os.path.join(self.config["path_prefix"], site, "basebackup"))
        for entry in results:
            # drop path from resulting list and convert timestamps
            entry["name"] = os.path.basename(entry["name"])
            entry["metadata"]["start-time"] = dateutil.parser.parse(
                entry["metadata"]["start-time"])

        results.sort(key=lambda entry: entry["metadata"]["start-time"])
        return results

    def check_backup_count_and_state(self, site):
        """Look up basebackups from the object store, prune any extra
        backups and return the datetime of the latest backup."""
        basebackups = self.get_remote_basebackups_info(site)
        self.log.debug("Found %r basebackups", basebackups)

        if basebackups:
            last_backup_time = basebackups[-1]["metadata"]["start-time"]
        else:
            last_backup_time = None

        allowed_basebackup_count = self.config["backup_sites"][site][
            "basebackup_count"]
        if allowed_basebackup_count is None:
            allowed_basebackup_count = len(basebackups)

        while len(basebackups) > allowed_basebackup_count:
            self.log.warning(
                "Too many basebackups: %d > %d, %r, starting to get rid of %r",
                len(basebackups), allowed_basebackup_count, basebackups,
                basebackups[0]["name"])
            basebackup_to_be_deleted = basebackups.pop(0)
            pg_version = basebackup_to_be_deleted["metadata"].get("pg-version")
            last_wal_segment_still_needed = 0
            if basebackups:
                last_wal_segment_still_needed = basebackups[0]["metadata"][
                    "start-wal-segment"]

            if last_wal_segment_still_needed:
                self.delete_remote_wal_before(last_wal_segment_still_needed,
                                              site, pg_version)
            self.delete_remote_basebackup(site,
                                          basebackup_to_be_deleted["name"])
        self.state["backup_sites"][site]["basebackups"] = basebackups

        return last_backup_time

    def set_state_defaults(self, site):
        if site not in self.state["backup_sites"]:
            self.state["backup_sites"][site] = {"basebackups": []}

    def startup_walk_for_missed_files(self):
        """Check xlog and xlog_incoming directories for files that receivexlog has received but not yet
        compressed as well as the files we have compressed but not yet uploaded and process them."""
        for site in self.config["backup_sites"]:
            compressed_xlog_path, _ = self.create_backup_site_paths(site)
            uncompressed_xlog_path = compressed_xlog_path + "_incoming"

            # Process uncompressed files (ie WAL pg_receivexlog received)
            for filename in os.listdir(uncompressed_xlog_path):
                full_path = os.path.join(uncompressed_xlog_path, filename)
                if not wal.XLOG_RE.match(filename) and not wal.TIMELINE_RE.match(filename):
                    self.log.warning(
                        "Found invalid file %r from incoming xlog directory",
                        full_path)
                    continue
                compression_event = {
                    "delete_file_after_compression": True,
                    "full_path": full_path,
                    "site": site,
                    "type": "CLOSE_WRITE",
                }
                self.log.debug(
                    "Found: %r when starting up, adding to compression queue",
                    compression_event)
                self.compression_queue.put(compression_event)

            # Process compressed files (ie things we've processed but not yet uploaded)
            for filename in os.listdir(compressed_xlog_path):
                if filename.endswith(".metadata"):
                    continue  # silently ignore .metadata files, they're expected and processed below
                full_path = os.path.join(compressed_xlog_path, filename)
                metadata_path = full_path + ".metadata"
                is_xlog = wal.XLOG_RE.match(filename)
                is_timeline = wal.TIMELINE_RE.match(filename)
                if not ((is_xlog or is_timeline)
                        and os.path.exists(metadata_path)):
                    self.log.warning(
                        "Found invalid file %r from compressed xlog directory",
                        full_path)
                    continue
                with open(metadata_path, "r") as fp:
                    metadata = json.load(fp)

                transfer_event = {
                    "file_size": os.path.getsize(full_path),
                    "filetype": "xlog" if is_xlog else "timeline",
                    "local_path": full_path,
                    "metadata": metadata,
                    "site": site,
                    "type": "UPLOAD",
                }
                self.log.debug(
                    "Found: %r when starting up, adding to transfer queue",
                    transfer_event)
                self.transfer_queue.put(transfer_event)

    def start_threads_on_startup(self):
        # Startup threads
        self.inotify.start()
        self.webserver.start()
        for compressor in self.compressors:
            compressor.start()
        for ta in self.transfer_agents:
            ta.start()

    def _cleanup_inactive_receivexlogs(self, site):
        if site in self.receivexlogs:
            if not self.receivexlogs[site].running:
                if self.receivexlogs[site].is_alive():
                    self.receivexlogs[site].join()
                del self.receivexlogs[site]

    def handle_site(self, site, site_config):
        self.set_state_defaults(site)
        xlog_path, basebackup_path = self.create_backup_site_paths(site)

        if not site_config["active"]:
            return  # If a site has been marked inactive, don't bother checking anything

        self._cleanup_inactive_receivexlogs(site)

        chosen_backup_node = random.choice(site_config["nodes"])

        if site not in self.receivexlogs and site not in self.walreceivers:
            if site_config["active_backup_mode"] == "pg_receivexlog":
                self.receivexlog_listener(site, chosen_backup_node,
                                          xlog_path + "_incoming")
            elif site_config["active_backup_mode"] == "walreceiver":
                state_file_path = self.config["json_state_file_path"]
                walreceiver_state = {}
                with suppress(FileNotFoundError):
                    with open(state_file_path, "r") as fp:
                        old_state_file = json.load(fp)
                        walreceiver_state = old_state_file.get(
                            "walreceivers", {}).get(site, {})
                self.start_walreceiver(
                    site=site,
                    chosen_backup_node=chosen_backup_node,
                    last_flushed_lsn=walreceiver_state.get("last_flushed_lsn"))

        if site not in self.time_of_last_backup_check or \
                time.monotonic() - self.time_of_last_backup_check[site] > 300:
            self.time_of_last_backup[site] = self.check_backup_count_and_state(
                site)
            self.time_of_last_backup_check[site] = time.monotonic()

        # check if a basebackup is running, or if a basebackup has just completed
        if site in self.basebackups:
            try:
                result = self.basebackups_callbacks[site].get(block=False)
            except Empty:
                # previous basebackup (or its compression and upload) still in progress
                return
            if self.basebackups[site].is_alive():
                self.basebackups[site].join()
            del self.basebackups[site]
            del self.basebackups_callbacks[site]
            self.log.debug("Basebackup has finished for %r: %r", site, result)
            self.time_of_last_backup[site] = self.check_backup_count_and_state(
                site)
            self.time_of_last_backup_check[site] = time.monotonic()

        new_backup_needed = False
        if site in self.requested_basebackup_sites:
            self.log.info("Creating a new basebackup for %r due to request",
                          site)
            self.requested_basebackup_sites.discard(site)
            new_backup_needed = True
        elif site_config["basebackup_interval_hours"] is None:
                # Basebackups are disabled for this site (but they can still be requested over the API).
            pass
        elif self.time_of_last_backup.get(site) is None:
            self.log.info(
                "Creating a new basebackup for %r because there are currently none",
                site)
            new_backup_needed = True
        else:
            delta_since_last_backup = datetime.datetime.now(
                datetime.timezone.utc) - self.time_of_last_backup[site]
            if delta_since_last_backup >= datetime.timedelta(
                    hours=site_config["basebackup_interval_hours"]):
                self.log.info(
                    "Creating a new basebackup for %r by schedule (%s from previous)",
                    site, delta_since_last_backup)
                new_backup_needed = True

        if new_backup_needed:
            self.basebackups_callbacks[site] = Queue()
            self.create_basebackup(site, chosen_backup_node, basebackup_path,
                                   self.basebackups_callbacks[site])

    def run(self):
        self.start_threads_on_startup()
        self.startup_walk_for_missed_files()
        while self.running:
            try:
                for site, site_config in self.config["backup_sites"].items():
                    self.handle_site(site, site_config)
                self.write_backup_state_to_json_file()
            except subprocess.CalledProcessError as ex:
                self.log.error("%s: %s", ex.__class__.__name__, ex)
            except Exception as ex:  # pylint: disable=broad-except
                self.log.exception("Unexpected exception in PGHoard main loop")
                self.stats.unexpected_exception(ex, where="pghoard_run")
            time.sleep(5.0)

    def write_backup_state_to_json_file(self):
        """Periodically write a JSON state file to disk"""
        start_time = time.time()
        state_file_path = self.config["json_state_file_path"]
        self.state["walreceivers"] = {
            key: {
                "latest_activity": value.latest_activity,
                "running": value.running,
                "last_flushed_lsn": value.last_flushed_lsn
            }
            for key, value in self.walreceivers.items()
        }
        self.state["pg_receivexlogs"] = {
            key: {
                "latest_activity": value.latest_activity,
                "running": value.running
            }
            for key, value in self.receivexlogs.items()
        }
        self.state["pg_basebackups"] = {
            key: {
                "latest_activity": value.latest_activity,
                "running": value.running
            }
            for key, value in self.basebackups.items()
        }
        self.state["compressors"] = [
            compressor.state for compressor in self.compressors
        ]
        self.state["transfer_agents"] = [
            ta.state for ta in self.transfer_agents
        ]
        self.state["queues"] = {
            "compression_queue": self.compression_queue.qsize(),
            "transfer_queue": self.transfer_queue.qsize(),
        }
        self.log.debug("Writing JSON state file to %r", state_file_path)
        write_json_file(state_file_path, self.state)
        self.log.debug("Wrote JSON state file to disk, took %.4fs",
                       time.time() - start_time)

    def load_config(self, _signal=None, _frame=None):  # pylint: disable=unused-argument
        self.log.debug("Loading JSON config from: %r, signal: %r",
                       self.config_path, _signal)
        try:
            new_config = config.read_json_config_file(self.config_path)
        except (InvalidConfigurationError, subprocess.CalledProcessError,
                UnicodeDecodeError) as ex:
            self.log.exception("Invalid config file %r: %s: %s",
                               self.config_path, ex.__class__.__name__, ex)
            # if we were called by a signal handler we'll ignore (and log)
            # the error and hope the user fixes the configuration before
            # restarting pghoard.
            if _signal is not None:
                return
            if isinstance(ex, InvalidConfigurationError):
                raise
            raise InvalidConfigurationError(self.config_path)

        self.config = new_config
        if self.config.get("syslog") and not self.syslog_handler:
            self.syslog_handler = logutil.set_syslog_handler(
                address=self.config.get("syslog_address", "/dev/log"),
                facility=self.config.get("syslog_facility", "local2"),
                logger=logging.getLogger(),
            )
        # NOTE: getLevelName() also converts level names to numbers
        self.log_level = logging.getLevelName(self.config["log_level"])
        try:
            logging.getLogger().setLevel(self.log_level)
        except ValueError:
            self.log.exception("Problem with log_level: %r", self.log_level)

        # statsd settings may have changed
        stats = self.config.get("statsd", {})
        self.stats = statsd.StatsClient(host=stats.get("host"),
                                        port=stats.get("port"),
                                        tags=stats.get("tags"),
                                        message_format=stats.get("format"))

        self.log.debug("Loaded config: %r from: %r", self.config,
                       self.config_path)

    def quit(self, _signal=None, _frame=None):  # pylint: disable=unused-argument
        self.log.warning("Quitting, signal: %r", _signal)
        self.running = False
        self.inotify.running = False
        all_threads = [self.webserver]
        all_threads.extend(self.basebackups.values())
        all_threads.extend(self.receivexlogs.values())
        all_threads.extend(self.walreceivers.values())
        all_threads.extend(self.compressors)
        all_threads.extend(self.transfer_agents)
        for t in all_threads:
            t.running = False
        # Write state file in the end so we get the last known state
        self.write_backup_state_to_json_file()
        for t in all_threads:
            if t.is_alive():
                t.join()
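
This version of the class delegates state persistence to a write_json_file helper that is not shown in the snippet. Example #4 below inlines the same logic, so as a rough sketch (an assumption based on that inline version, not the project's actual helper) it presumably writes to a temporary file and renames it into place:

import json
import os

def write_json_file(file_path, state):
    # Serialize the state, write it to a side file, then atomically swap it in;
    # os.rename() is atomic on POSIX, so readers never observe a half-written file.
    json_data = json.dumps(state, indent=4, sort_keys=True)
    tmp_path = file_path + ".tmp"
    with open(tmp_path, "w") as fp:
        fp.write(json_data)
    os.rename(tmp_path, file_path)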
Example #3
class TestWebServer(TestCase):
    def setUp(self):
        self.log = logging.getLogger("TestWebServer")
        self.temp_dir = tempfile.mkdtemp()
        self.compressed_xlog_path = os.path.join(self.temp_dir, "default", "compressed_xlog")
        self.basebackup_path = os.path.join(self.temp_dir, "default", "basebackup")
        self.compressed_timeline_path = os.path.join(self.temp_dir, "default", "compressed_timeline")
        self.pgdata_path = os.path.join(self.temp_dir, "pgdata")
        self.pg_xlog_dir = os.path.join(self.pgdata_path, "pg_xlog")

        self.config = {
            "backup_sites": {
                "default": {
                    "pg_xlog_directory": self.pg_xlog_dir,
                    "object_storage": {},
                },
            },
            "http_address": "127.0.0.1",
            "http_port": random.randint(1024, 32000),
            "backup_location": self.temp_dir,
        }
        self.compression_queue = Queue()
        self.transfer_queue = Queue()

        os.makedirs(self.compressed_xlog_path)
        os.makedirs(self.basebackup_path)
        os.makedirs(self.compressed_timeline_path)
        os.makedirs(self.pgdata_path)
        os.makedirs(self.pg_xlog_dir)

        self.uncompressed_foo_path = os.path.join(self.pg_xlog_dir, "00000001000000000000000C")
        with open(self.uncompressed_foo_path, "wb") as out_file:
            out_file.write(b"foo")
        self.foo_path = os.path.join(self.compressed_xlog_path, "00000001000000000000000C")
        with open(self.foo_path, "wb") as out_file:
            out_file.write(b"foo")
        with open(self.foo_path, "rb") as fp:
            with lzma_open(self.foo_path + ".xz", mode="wb", preset=0) as xz_out:
                xz_out.write(fp.read())  # close the .xz file so the data is flushed to disk

        self.webserver = WebServer(config=self.config,
                                   compression_queue=self.compression_queue,
                                   transfer_queue=self.transfer_queue)
        self.webserver.start()
        self.http_restore = HTTPRestore("localhost", self.config['http_port'], site="default", pgdata=self.pgdata_path)
        time.sleep(0.05)  # Hack to give the server time to start up

    def test_list_empty_basebackups(self):
        self.assertEqual(self.http_restore.list_basebackups(), [])  # pylint: disable=protected-access

    def test_archiving(self):
        compressor = Compressor(config=self.config,
                                compression_queue=self.compression_queue,
                                transfer_queue=self.transfer_queue)
        compressor.start()
        xlog_file = "00000001000000000000000C"
        self.assertTrue(archive(port=self.config['http_port'], site="default", xlog_file=xlog_file))
        self.assertTrue(os.path.exists(os.path.join(self.compressed_xlog_path, xlog_file)))
        self.log.error(os.path.join(self.compressed_xlog_path, xlog_file))
        compressor.running = False

    def test_archiving_backup_label_from_archive_command(self):
        compressor = Compressor(config=self.config,
                                compression_queue=self.compression_queue,
                                transfer_queue=self.transfer_queue)
        compressor.start()

        xlog_file = "000000010000000000000002.00000028.backup"
        xlog_path = os.path.join(self.pg_xlog_dir, xlog_file)
        with open(xlog_path, "w") as fp:
            fp.write("jee")
        self.assertTrue(archive(port=self.config['http_port'], site="default", xlog_file=xlog_file))
        self.assertFalse(os.path.exists(os.path.join(self.compressed_xlog_path, xlog_file)))
        compressor.running = False

#    def test_get_basebackup_file(self):
#        self.http_restore.get_basebackup_file()

    def test_get_archived_file(self):
        xlog_file = "00000001000000000000000F"
        filepath = os.path.join(self.compressed_xlog_path, xlog_file)
        with lzma_open(filepath + ".xz", mode="wb", preset=0) as xz_out:
            xz_out.write(b"jee")  # close the .xz file so the data is flushed to disk
        self.http_restore.get_archive_file(xlog_file, "pg_xlog/" + xlog_file, path_prefix=self.pgdata_path)
        self.assertTrue(os.path.exists(os.path.join(self.pg_xlog_dir, xlog_file)))

    def tearDown(self):
        self.webserver.close()
        shutil.rmtree(self.temp_dir)
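
One fragile spot in the setUp above is the fixed time.sleep(0.05) used to wait for the web server; on a loaded machine the server may not be accepting connections yet. A more robust alternative (a hypothetical helper, not part of the original test) is to poll the port until it accepts a TCP connection:

import socket
import time

def wait_for_port(host, port, timeout=5.0):
    # Poll until something accepts connections on (host, port) or the timeout expires.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with socket.create_connection((host, port), timeout=0.1):
                return True
        except OSError:
            time.sleep(0.01)
    return False

setUp could then call wait_for_port("127.0.0.1", self.config["http_port"]) instead of sleeping, which also surfaces startup failures as a timeout rather than a confusing downstream assertion.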
Example #4
class PGHoard(object):
    def __init__(self, config_path):
        self.log = logging.getLogger("pghoard")
        self.log_level = None
        self.running = True
        self.config_path = config_path
        self.compression_queue = Queue()
        self.transfer_queue = Queue()
        self.syslog_handler = None
        self.config = {}
        self.site_transfers = {}
        self.state = {
            "backup_sites": {},
            "startup_time": datetime.datetime.utcnow().isoformat(),
            }
        self.load_config()

        if not os.path.exists(self.config["backup_location"]):
            os.makedirs(self.config["backup_location"])

        signal.signal(signal.SIGHUP, self.load_config)
        signal.signal(signal.SIGINT, self.quit)
        signal.signal(signal.SIGTERM, self.quit)
        self.time_of_last_backup = {}
        self.time_of_last_backup_check = {}
        self.basebackups = {}
        self.basebackups_callbacks = {}
        self.receivexlogs = {}
        self.compressors = []
        self.transfer_agents = []
        self.requested_basebackup_sites = {}

        self.inotify = InotifyWatcher(self.compression_queue)
        self.webserver = WebServer(
            self.config,
            self.requested_basebackup_sites,
            self.compression_queue,
            self.transfer_queue)

        for _ in range(self.config["compression"]["thread_count"]):
            compressor = CompressorThread(self.config, self.compression_queue, self.transfer_queue)
            self.compressors.append(compressor)

        for _ in range(self.config["transfer"]["thread_count"]):
            ta = TransferAgent(self.config, self.compression_queue, self.transfer_queue)
            self.transfer_agents.append(ta)

        if daemon:  # If we can import systemd we always notify it
            daemon.notify("READY=1")
            self.log.info("Sent startup notification to systemd that pghoard is READY")
        self.log.info("pghoard initialized, own_hostname: %r, cwd: %r", socket.gethostname(), os.getcwd())

    def check_pg_versions_ok(self, pg_version_server, command):
        if not pg_version_server or pg_version_server < 90300:
            self.log.error("pghoard does not support versions earlier than 9.3, found: %r", pg_version_server)
            self.create_alert_file("version_unsupported_error")
            return False
        command_path = self.config.get(command + "_path", "/usr/bin/" + command)
        output = os.popen(command_path + " --version").read().strip()
        pg_version_client = convert_pg_command_version_to_number(output)
        if pg_version_server // 100 != pg_version_client // 100:
            self.log.error("Server version: %r does not match %s version: %r",
                           pg_version_server, command_path, pg_version_client)
            self.create_alert_file("version_mismatch_error")
            return False
        return True

    def create_basebackup(self, site, connection_string, basebackup_path, callback_queue=None):
        pg_version_server = self.check_pg_server_version(connection_string)
        if not self.check_pg_versions_ok(pg_version_server, "pg_basebackup"):
            if callback_queue:
                callback_queue.put({"success": False})
            return None

        # Note that this xlog file value will only be correct if no other basebackups are
        # run in parallel. PGHoard itself never does this, but if the user starts one on
        # their own and tablespaces are set to False, we'll get an incorrect start-wal-time
        # since the pg_basebackup from pghoard will not generate a new checkpoint. This
        # means the xlog information would not be the oldest required to restore from this
        # basebackup.
        current_xlog = wal.get_current_wal_from_identify_system(connection_string)

        thread = PGBaseBackup(
            config=self.config,
            site=site,
            connection_string=connection_string,
            basebackup_path=basebackup_path,
            compression_queue=self.compression_queue,
            transfer_queue=self.transfer_queue,
            callback_queue=callback_queue,
            start_wal_segment=current_xlog)
        thread.start()
        self.basebackups[site] = thread

    def check_pg_server_version(self, connection_string):
        pg_version = None
        try:
            with closing(psycopg2.connect(connection_string)) as c:
                pg_version = c.server_version
        except psycopg2.OperationalError as ex:
            self.log.warning("%s (%s) connecting to DB at: %r",
                             ex.__class__.__name__, ex, connection_string)
            if "password authentication" in str(ex) or "authentication failed" in str(ex):
                self.create_alert_file("authentication_error")
            else:
                self.create_alert_file("configuration_error")
        except Exception:  # log all errors and return None; pylint: disable=broad-except
            self.log.exception("Problem in getting PG server version")
        return pg_version

    def receivexlog_listener(self, cluster, xlog_location, connection_string, slot):
        pg_version_server = self.check_pg_server_version(connection_string)
        if not self.check_pg_versions_ok(pg_version_server, "pg_receivexlog"):
            return
        command = [
            self.config.get("pg_receivexlog_path", "/usr/bin/pg_receivexlog"),
            "--dbname", connection_string,
            "--status-interval", "1",
            "--verbose",
            "--directory", xlog_location,
        ]

        if pg_version_server >= 90400 and slot:
            command.extend(["--slot", slot])

        self.inotify.add_watch(xlog_location)
        thread = PGReceiveXLog(command)
        thread.start()
        self.receivexlogs[cluster] = thread

    def create_backup_site_paths(self, site):
        site_path = os.path.join(self.config["backup_location"], self.config.get("path_prefix", ""), site)
        xlog_path = os.path.join(site_path, "xlog")
        basebackup_path = os.path.join(site_path, "basebackup")

        paths_to_create = [
            site_path,
            xlog_path,
            xlog_path + "_incoming",
            basebackup_path,
            basebackup_path + "_incoming",
        ]

        for path in paths_to_create:
            if not os.path.exists(path):
                os.makedirs(path)

        return xlog_path, basebackup_path

    def delete_remote_wal_before(self, wal_segment, site):
        self.log.debug("Starting WAL deletion from: %r before: %r", site, wal_segment)
        storage = self.site_transfers.get(site)
        valid_timeline = True
        tli, log, seg = wal.name_to_tli_log_seg(wal_segment)
        while True:
            if valid_timeline:
                # Decrement one segment if we're on a valid timeline
                if seg == 0 and log == 0:
                    break
                seg, log = wal.get_previous_wal_on_same_timeline(seg, log)

            wal_path = os.path.join(self.config.get("path_prefix", ""), site, "xlog",
                                    wal.name_for_tli_log_seg(tli, log, seg))
            self.log.debug("Deleting wal_file: %r", wal_path)
            try:
                storage.delete_key(wal_path)
                valid_timeline = True
            except FileNotFoundFromStorageError:
                if not valid_timeline or tli <= 1:
                    # if we didn't find any WALs to delete on this timeline or we're already at
                    # timeline 1 there's no need or possibility to try older timelines, break.
                    self.log.info("Could not delete wal_file: %r, returning", wal_path)
                    break
                # let's try the same segment number on a previous timeline, but flag that timeline
                # as "invalid" until we're able to delete at least one segment on it.
                valid_timeline = False
                tli -= 1
                self.log.info("Could not delete wal_file: %r, trying the same segment on a previous "
                              "timeline (%s)", wal_path, wal.name_for_tli_log_seg(tli, log, seg))
            except:  # FIXME: don't catch all exceptions; pylint: disable=bare-except
                self.log.exception("Problem deleting: %r", wal_path)

    def delete_remote_basebackup(self, site, basebackup):
        storage = self.site_transfers.get(site)
        obj_key = os.path.join(self.config.get("path_prefix", ""),
                               site,
                               "basebackup",
                               basebackup)
        try:
            storage.delete_key(obj_key)
        except FileNotFoundFromStorageError:
            self.log.info("Tried to delete non-existent basebackup %r", obj_key)
        except:  # FIXME: don't catch all exceptions; pylint: disable=bare-except
            self.log.exception("Problem deleting: %r", obj_key)

    def get_remote_basebackups_info(self, site):
        storage = self.site_transfers.get(site)
        if not storage:
            storage_type, storage_config = get_object_storage_config(self.config, site)
            storage = get_transfer(storage_type, storage_config)
            self.site_transfers[site] = storage

        results = storage.list_path(os.path.join(self.config.get("path_prefix", ""),
                                                 site,
                                                 "basebackup"))
        for entry in results:
            # drop path from resulting list and convert timestamps
            entry["name"] = os.path.basename(entry["name"])
            entry["last_modified"] = entry["last_modified"].timestamp()

        results.sort(key=lambda entry: entry["metadata"]["start-time"])
        return results

    def check_backup_count_and_state(self, site):
        allowed_basebackup_count = self.config['backup_sites'][site]['basebackup_count']
        basebackups = self.get_remote_basebackups_info(site)
        self.log.debug("Found %r basebackups", basebackups)

        # Needs to be the m_time of the newest basebackup
        m_time = basebackups[-1]["last_modified"] if basebackups else 0

        while len(basebackups) > allowed_basebackup_count:
            self.log.warning("Too many basebackups: %d > %d, %r, starting to get rid of %r",
                             len(basebackups), allowed_basebackup_count, basebackups, basebackups[0]["name"])
            basebackup_to_be_deleted = basebackups.pop(0)

            last_wal_segment_still_needed = 0
            if basebackups:
                last_wal_segment_still_needed = basebackups[0]["metadata"]["start-wal-segment"]

            if last_wal_segment_still_needed:
                self.delete_remote_wal_before(last_wal_segment_still_needed, site)
            self.delete_remote_basebackup(site, basebackup_to_be_deleted["name"])
        self.state["backup_sites"][site]['basebackups'] = basebackups
        return m_time

    def set_state_defaults(self, site):
        if site not in self.state["backup_sites"]:
            self.state['backup_sites'][site] = {"basebackups": []}

    def startup_walk_for_missed_files(self):
        for site in self.config["backup_sites"]:
            xlog_path, basebackup_path = self.create_backup_site_paths(site)  # pylint: disable=unused-variable
            for filename in os.listdir(xlog_path):
                if not filename.endswith(".partial"):
                    compression_event = {
                        "delete_file_after_compression": True,
                        "full_path": os.path.join(xlog_path, filename),
                        "site": site,
                        "type": "CLOSE_WRITE",
                    }
                    self.log.debug("Found: %r when starting up, adding to compression queue", compression_event)
                    self.compression_queue.put(compression_event)

    def start_threads_on_startup(self):
        # Startup threads
        self.inotify.start()
        self.webserver.start()
        for compressor in self.compressors:
            compressor.start()
        for ta in self.transfer_agents:
            ta.start()

    def handle_site(self, site, site_config):
        self.set_state_defaults(site)
        xlog_path, basebackup_path = self.create_backup_site_paths(site)

        if not site_config.get("active", True):
            #  If a site has been marked inactive, don't bother checking anything
            return

        chosen_backup_node = random.choice(site_config["nodes"])

        if site not in self.receivexlogs and site_config.get("active_backup_mode") == "pg_receivexlog":
            connection_string, slot = replication_connection_string_using_pgpass(chosen_backup_node)
            self.receivexlog_listener(site, xlog_path + "_incoming", connection_string, slot)

        if site not in self.time_of_last_backup_check or \
                time.monotonic() - self.time_of_last_backup_check[site] > 60:
            self.time_of_last_backup[site] = self.check_backup_count_and_state(site)
            self.time_of_last_backup_check[site] = time.monotonic()

        # check if a basebackup is running, or if a basebackup has just completed
        if site in self.basebackups:
            try:
                result = self.basebackups_callbacks[site].get(block=False)
            except Empty:
                # previous basebackup (or its compression and upload) still in progress
                return
            if self.basebackups[site].is_alive():
                self.basebackups[site].join()
            del self.basebackups[site]
            del self.basebackups_callbacks[site]
            self.log.debug("Basebackup has finished for %r: %r", site, result)
            self.time_of_last_backup[site] = self.check_backup_count_and_state(site)
            self.time_of_last_backup_check[site] = time.monotonic()

        time_since_last_backup = time.time() - self.time_of_last_backup[site]
        if time_since_last_backup > site_config["basebackup_interval_hours"] * 3600 or \
           self.requested_basebackup_sites.pop(site, False):
            self.log.debug("Starting to create a new basebackup for: %r since time from previous: %r",
                           site, time_since_last_backup)
            connection_string, slot = replication_connection_string_using_pgpass(chosen_backup_node)
            self.basebackups_callbacks[site] = Queue()
            self.create_basebackup(site, connection_string, basebackup_path, self.basebackups_callbacks[site])

    def run(self):
        self.start_threads_on_startup()
        self.startup_walk_for_missed_files()
        while self.running:
            try:
                for site, site_config in self.config['backup_sites'].items():
                    self.handle_site(site, site_config)
                self.write_backup_state_to_json_file()
            except:  # pylint: disable=bare-except
                self.log.exception("Problem in PGHoard main loop")
            time.sleep(5.0)

    def write_backup_state_to_json_file(self):
        """Periodically write a JSON state file to disk"""
        start_time = time.time()
        state_file_path = self.config.get("json_state_file_path", "/tmp/pghoard_state.json")
        self.state["pg_receivexlogs"] = {
            key: {"latest_activity": value.latest_activity.isoformat(), "running": value.running}
            for key, value in self.receivexlogs.items()
        }
        self.state["pg_basebackups"] = {
            key: {"latest_activity": value.latest_activity.isoformat(), "running": value.running}
            for key, value in self.basebackups.items()
        }
        self.state["compressors"] = [compressor.state for compressor in self.compressors]
        self.state["transfer_agents"] = [ta.state for ta in self.transfer_agents]
        self.state["queues"] = {
            "compression_queue": self.compression_queue.qsize(),
            "transfer_queue": self.transfer_queue.qsize(),
            }
        json_to_dump = json.dumps(self.state, indent=4, sort_keys=True)
        self.log.debug("Writing JSON state file to: %r, file_size: %r", state_file_path, len(json_to_dump))
        with open(state_file_path + ".tmp", "w") as fp:
            fp.write(json_to_dump)
        os.rename(state_file_path + ".tmp", state_file_path)
        self.log.debug("Wrote JSON state file to disk, took %.4fs", time.time() - start_time)

    def create_alert_file(self, filename):
        filepath = os.path.join(self.config.get("alert_file_dir", os.getcwd()), filename)
        self.log.debug("Creating alert file: %r", filepath)
        with open(filepath, "w") as fp:
            fp.write("alert")

    def delete_alert_file(self, filename):
        filepath = os.path.join(self.config.get("alert_file_dir", os.getcwd()), filename)
        if os.path.exists(filepath):
            self.log.debug("Deleting alert file: %r", filepath)
            os.unlink(filepath)

    def load_config(self, _signal=None, _frame=None):
        self.log.debug("Loading JSON config from: %r, signal: %r, frame: %r",
                       self.config_path, _signal, _frame)
        try:
            with open(self.config_path, "r") as fp:
                self.config = json.load(fp)
        except (IOError, ValueError) as ex:
            self.log.exception("Invalid JSON config %r: %s", self.config_path, ex)
            # if we were called by a signal handler we'll ignore (and log)
            # the error and hope the user fixes the configuration before
            # restarting pghoard.
            if _signal is not None:
                return
            raise InvalidConfigurationError(self.config_path)

        # default to 5 compression and transfer threads
        self.config.setdefault("compression", {}).setdefault("thread_count", 5)
        self.config.setdefault("transfer", {}).setdefault("thread_count", 5)
        # default to prefetching min(#compressors, #transferagents) - 1 objects so all
        # operations where prefetching is used run fully in parallel without waiting to start
        self.config.setdefault("restore_prefetch", min(
            self.config["compression"]["thread_count"],
            self.config["transfer"]["thread_count"]) - 1)

        if self.config.get("syslog") and not self.syslog_handler:
            self.syslog_handler = set_syslog_handler(self.config.get("syslog_address", "/dev/log"),
                                                     self.config.get("syslog_facility", "local2"),
                                                     self.log)
        self.log_level = getattr(logging, self.config.get("log_level", "DEBUG"))
        try:
            logging.getLogger().setLevel(self.log_level)
        except ValueError:
            self.log.exception("Problem with log_level: %r", self.log_level)
        # we need the failover_command to be converted into subprocess [] format
        self.log.debug("Loaded config: %r from: %r", self.config, self.config_path)

    def quit(self, _signal=None, _frame=None):
        self.log.warning("Quitting, signal: %r, frame: %r", _signal, _frame)
        self.running = False
        self.inotify.running = False
        all_threads = [self.webserver]
        all_threads.extend(self.basebackups.values())
        all_threads.extend(self.receivexlogs.values())
        all_threads.extend(self.compressors)
        all_threads.extend(self.transfer_agents)
        for t in all_threads:
            t.running = False
        for t in all_threads:
            if t.is_alive():
                t.join()
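
The timeline-walking loop in delete_remote_wal_before above leans on wal.name_to_tli_log_seg and wal.name_for_tli_log_seg. Those helpers are not shown in these examples, but PostgreSQL WAL segment names are 24 hex digits encoding the timeline, log and segment numbers in 8-digit groups, so a minimal sketch of them (an assumption about the module's internals, ignoring the version-specific segment numbering that get_previous_wal_on_same_timeline handles) could look like:

def name_to_tli_log_seg(wal_name):
    # "00000002000000A1000000FF" -> (2, 0xA1, 0xFF): timeline, log, segment
    tli = int(wal_name[0:8], 16)
    log = int(wal_name[8:16], 16)
    seg = int(wal_name[16:24], 16)
    return tli, log, seg

def name_for_tli_log_seg(tli, log, seg):
    # Inverse of the above: format the three counters back into a WAL file name
    return "{:08X}{:08X}{:08X}".format(tli, log, seg)

Decrementing seg/log walks backwards through history on one timeline; when a delete misses, the loop retries the same segment with tli - 1, which matches how a timeline switch leaves older segments stored under the previous timeline's name.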
Example #5
class PGHoard:
    def __init__(self, config_path):
        self.metrics = None
        self.log = logging.getLogger("pghoard")
        self.log_level = None
        self.running = True
        self.config_path = config_path
        self.compression_queue = Queue()
        self.transfer_queue = Queue()
        self.syslog_handler = None
        self.basebackups = {}
        self.basebackups_callbacks = {}
        self.receivexlogs = {}
        self.compressors = []
        self.walreceivers = {}
        self.transfer_agents = []
        self.config = {}
        self.mp_manager = None
        self.site_transfers = {}
        self.state = {
            "backup_sites": {},
            "startup_time": datetime.datetime.utcnow().isoformat(),
        }
        self.load_config()
        if self.config["transfer"]["thread_count"] > 1:
            self.mp_manager = multiprocessing.Manager()

        if not os.path.exists(self.config["backup_location"]):
            os.makedirs(self.config["backup_location"])

        signal.signal(signal.SIGHUP, self.load_config)
        signal.signal(signal.SIGINT, self.quit)
        signal.signal(signal.SIGTERM, self.quit)
        self.time_of_last_backup = {}
        self.time_of_last_backup_check = {}
        self.requested_basebackup_sites = set()

        self.inotify = InotifyWatcher(self.compression_queue)
        self.webserver = WebServer(
            self.config,
            self.requested_basebackup_sites,
            self.compression_queue,
            self.transfer_queue,
            self.metrics)

        for _ in range(self.config["compression"]["thread_count"]):
            compressor = CompressorThread(
                config_dict=self.config,
                compression_queue=self.compression_queue,
                transfer_queue=self.transfer_queue,
                metrics=self.metrics)
            self.compressors.append(compressor)

        transfer_agent_state = {}  # shared among transfer agents
        for _ in range(self.config["transfer"]["thread_count"]):
            ta = TransferAgent(
                config=self.config,
                compression_queue=self.compression_queue,
                mp_manager=self.mp_manager,
                transfer_queue=self.transfer_queue,
                metrics=self.metrics,
                shared_state_dict=transfer_agent_state)
            self.transfer_agents.append(ta)

        logutil.notify_systemd("READY=1")
        self.log.info("pghoard initialized, own_hostname: %r, cwd: %r", socket.gethostname(), os.getcwd())

    def check_pg_versions_ok(self, site, pg_version_server, command):
        if pg_version_server is None:
            # remote pg version not available, don't create version alert in this case
            return False
        if not pg_version_server:
            self.log.error("pghoard does not support versions earlier than 9.3, found: %r", pg_version_server)
            create_alert_file(self.config, "version_unsupported_error")
            return False
        pg_version_client = self.config["backup_sites"][site][command + "_version"]
        if pg_version_server // 100 != pg_version_client // 100:
            self.log.error("Server version: %r does not match %s version: %r",
                           pg_version_server, self.config[command + "_path"], pg_version_client)
            create_alert_file(self.config, "version_mismatch_error")
            return False
        return True

    def create_basebackup(self, site, connection_info, basebackup_path, callback_queue=None):
        connection_string, _ = replication_connection_string_and_slot_using_pgpass(connection_info)
        pg_version_server = self.check_pg_server_version(connection_string, site)
        if not self.check_pg_versions_ok(site, pg_version_server, "pg_basebackup"):
            if callback_queue:
                callback_queue.put({"success": False})
            return

        thread = PGBaseBackup(
            config=self.config,
            site=site,
            connection_info=connection_info,
            basebackup_path=basebackup_path,
            compression_queue=self.compression_queue,
            transfer_queue=self.transfer_queue,
            callback_queue=callback_queue,
            pg_version_server=pg_version_server,
            metrics=self.metrics)
        thread.start()
        self.basebackups[site] = thread

    def check_pg_server_version(self, connection_string, site):
        if "pg_version" in self.config["backup_sites"][site]:
            return self.config["backup_sites"][site]["pg_version"]

        pg_version = None
        try:
            with closing(psycopg2.connect(connection_string)) as c:
                pg_version = c.server_version  # pylint: disable=no-member
                # Cache pg_version so we don't have to query it again, note that this means that for major
                # version upgrades you want to restart pghoard.
                self.config["backup_sites"][site]["pg_version"] = pg_version
        except psycopg2.OperationalError as ex:
            self.log.warning("%s (%s) connecting to DB at: %r",
                             ex.__class__.__name__, ex, connection_string)
            if "password authentication" in str(ex) or "authentication failed" in str(ex):
                create_alert_file(self.config, "authentication_error")
            else:
                create_alert_file(self.config, "configuration_error")
        except Exception as ex:  # log all errors and return None; pylint: disable=broad-except
            self.log.exception("Problem in getting PG server version")
            self.metrics.unexpected_exception(ex, where="check_pg_server_version")
        return pg_version

    def receivexlog_listener(self, site, connection_info, wal_directory):
        connection_string, slot = replication_connection_string_and_slot_using_pgpass(connection_info)
        pg_version_server = self.check_pg_server_version(connection_string, site)
        if not self.check_pg_versions_ok(site, pg_version_server, "pg_receivexlog"):
            return

        self.inotify.add_watch(wal_directory)
        thread = PGReceiveXLog(
            config=self.config,
            connection_string=connection_string,
            wal_location=wal_directory,
            site=site,
            slot=slot,
            pg_version_server=pg_version_server)
        thread.start()
        self.receivexlogs[site] = thread

    def start_walreceiver(self, site, chosen_backup_node, last_flushed_lsn):
        connection_string, slot = replication_connection_string_and_slot_using_pgpass(chosen_backup_node)
        pg_version_server = self.check_pg_server_version(connection_string, site)
        if not WALReceiver:
            self.log.error("Could not import WALReceiver, incorrect psycopg2 version?")
            return

        thread = WALReceiver(
            config=self.config,
            connection_string=connection_string,
            compression_queue=self.compression_queue,
            replication_slot=slot,
            pg_version_server=pg_version_server,
            site=site,
            last_flushed_lsn=last_flushed_lsn,
            metrics=self.metrics)
        thread.start()
        self.walreceivers[site] = thread

    def create_backup_site_paths(self, site):
        site_path = os.path.join(self.config["backup_location"], self.config["backup_sites"][site]["prefix"])
        xlog_path = os.path.join(site_path, "xlog")
        basebackup_path = os.path.join(site_path, "basebackup")

        paths_to_create = [
            site_path,
            xlog_path,
            xlog_path + "_incoming",
            basebackup_path,
            basebackup_path + "_incoming",
        ]

        for path in paths_to_create:
            if not os.path.exists(path):
                os.makedirs(path)

        return xlog_path, basebackup_path

    def delete_remote_wal_before(self, wal_segment, site, pg_version):
        self.log.info("Starting WAL deletion from: %r before: %r, pg_version: %r",
                      site, wal_segment, pg_version)
        storage = self.site_transfers.get(site)
        valid_timeline = True
        tli, log, seg = wal.name_to_tli_log_seg(wal_segment)
        while True:
            if valid_timeline:
                # Decrement one segment if we're on a valid timeline
                if seg == 0 and log == 0:
                    break
                seg, log = wal.get_previous_wal_on_same_timeline(seg, log, pg_version)
            wal_path = os.path.join(self.config["backup_sites"][site]["prefix"], "xlog",
                                    wal.name_for_tli_log_seg(tli, log, seg))
            self.log.debug("Deleting wal_file: %r", wal_path)
            try:
                storage.delete_key(wal_path)
                valid_timeline = True
            except FileNotFoundFromStorageError:
                if not valid_timeline or tli <= 1:
                    # if we didn't find any WALs to delete on this timeline or we're already at
                    # timeline 1 there's no need or possibility to try older timelines, break.
                    self.log.info("Could not delete wal_file: %r, returning", wal_path)
                    break
                # let's try the same segment number on a previous timeline, but flag that timeline
                # as "invalid" until we're able to delete at least one segment on it.
                valid_timeline = False
                tli -= 1
                self.log.info("Could not delete wal_file: %r, trying the same segment on a previous "
                              "timeline (%s)", wal_path, wal.name_for_tli_log_seg(tli, log, seg))
            except Exception as ex:  # FIXME: don't catch all exceptions; pylint: disable=broad-except
                self.log.exception("Problem deleting: %r", wal_path)
                self.metrics.unexpected_exception(ex, where="delete_remote_wal_before")

    def delete_remote_basebackup(self, site, basebackup, metadata):
        start_time = time.monotonic()
        storage = self.site_transfers.get(site)
        main_backup_key = os.path.join(self.config["backup_sites"][site]["prefix"], "basebackup", basebackup)
        basebackup_data_files = [main_backup_key]

        if metadata.get("format") == "pghoard-bb-v2":
            bmeta_compressed = storage.get_contents_to_string(main_backup_key)[0]
            with rohmufile.file_reader(fileobj=io.BytesIO(bmeta_compressed), metadata=metadata,
                                       key_lookup=config.key_lookup_for_site(self.config, site)) as input_obj:
                bmeta = extract_pghoard_bb_v2_metadata(input_obj)
                self.log.debug("PGHoard chunk metadata: %r", bmeta)
                for chunk in bmeta["chunks"]:
                    basebackup_data_files.append(os.path.join(
                        self.config["backup_sites"][site]["prefix"],
                        "basebackup_chunk",
                        chunk["chunk_filename"],
                    ))

        self.log.debug("Deleting basebackup datafiles: %r", ', '.join(basebackup_data_files))
        for obj_key in basebackup_data_files:
            try:
                storage.delete_key(obj_key)
            except FileNotFoundFromStorageError:
                self.log.info("Tried to delete non-existent basebackup %r", obj_key)
            except Exception as ex:  # FIXME: don't catch all exceptions; pylint: disable=broad-except
                self.log.exception("Problem deleting: %r", obj_key)
                self.metrics.unexpected_exception(ex, where="delete_remote_basebackup")
        self.log.info("Deleted basebackup datafiles: %r, took: %.2fs",
                      ', '.join(basebackup_data_files), time.monotonic() - start_time)

    def get_remote_basebackups_info(self, site):
        storage = self.site_transfers.get(site)
        if not storage:
            storage_config = get_object_storage_config(self.config, site)
            storage = get_transfer(storage_config)
            self.site_transfers[site] = storage

        results = storage.list_path(os.path.join(self.config["backup_sites"][site]["prefix"], "basebackup"))
        for entry in results:
            # drop path from resulting list and convert timestamps
            entry["name"] = os.path.basename(entry["name"])
            entry["metadata"]["start-time"] = dates.parse_timestamp(entry["metadata"]["start-time"])

        results.sort(key=lambda entry: entry["metadata"]["start-time"])
        return results

    def check_backup_count_and_state(self, site):
        """Look up basebackups from the object store, prune any extra
        backups and return the datetime of the latest backup."""
        basebackups = self.get_remote_basebackups_info(site)
        self.log.debug("Found %r basebackups", basebackups)

        if basebackups:
            last_backup_time = basebackups[-1]["metadata"]["start-time"]
        else:
            last_backup_time = None

        allowed_basebackup_count = self.config["backup_sites"][site]["basebackup_count"]
        if allowed_basebackup_count is None:
            allowed_basebackup_count = len(basebackups)

        while len(basebackups) > allowed_basebackup_count:
            self.log.warning("Too many basebackups: %d > %d, %r, starting to get rid of %r",
                             len(basebackups), allowed_basebackup_count, basebackups, basebackups[0]["name"])
            basebackup_to_be_deleted = basebackups.pop(0)
            pg_version = basebackup_to_be_deleted["metadata"].get("pg-version")
            last_wal_segment_still_needed = 0
            if basebackups:
                last_wal_segment_still_needed = basebackups[0]["metadata"]["start-wal-segment"]

            if last_wal_segment_still_needed:
                self.delete_remote_wal_before(last_wal_segment_still_needed, site, pg_version)
            self.delete_remote_basebackup(site, basebackup_to_be_deleted["name"], basebackup_to_be_deleted["metadata"])
        self.state["backup_sites"][site]["basebackups"] = basebackups

        return last_backup_time

    def set_state_defaults(self, site):
        if site not in self.state["backup_sites"]:
            self.state["backup_sites"][site] = {"basebackups": []}

    def startup_walk_for_missed_files(self):
        """Check xlog and xlog_incoming directories for files that receivexlog has received but not yet
        compressed as well as the files we have compressed but not yet uploaded and process them."""
        for site in self.config["backup_sites"]:
            compressed_xlog_path, _ = self.create_backup_site_paths(site)
            uncompressed_xlog_path = compressed_xlog_path + "_incoming"

            # Process uncompressed files (i.e. WAL files pg_receivexlog has received)
            for filename in os.listdir(uncompressed_xlog_path):
                full_path = os.path.join(uncompressed_xlog_path, filename)
                if not wal.WAL_RE.match(filename) and not wal.TIMELINE_RE.match(filename):
                    self.log.warning("Found invalid file %r from incoming xlog directory", full_path)
                    continue
                compression_event = {
                    "delete_file_after_compression": True,
                    "full_path": full_path,
                    "site": site,
                    "src_path": "{}.partial",
                    "type": "MOVE",
                }
                self.log.debug("Found: %r when starting up, adding to compression queue", compression_event)
                self.compression_queue.put(compression_event)

            # Process compressed files (i.e. files we've compressed but not yet uploaded)
            for filename in os.listdir(compressed_xlog_path):
                if filename.endswith(".metadata"):
                    continue  # silently ignore .metadata files, they're expected and processed below
                full_path = os.path.join(compressed_xlog_path, filename)
                metadata_path = full_path + ".metadata"
                is_xlog = wal.WAL_RE.match(filename)
                is_timeline = wal.TIMELINE_RE.match(filename)
                if not ((is_xlog or is_timeline) and os.path.exists(metadata_path)):
                    self.log.warning("Found invalid file %r from compressed xlog directory", full_path)
                    continue
                with open(metadata_path, "r") as fp:
                    metadata = json.load(fp)

                transfer_event = {
                    "file_size": os.path.getsize(full_path),
                    "filetype": "xlog" if is_xlog else "timeline",
                    "local_path": full_path,
                    "metadata": metadata,
                    "site": site,
                    "type": "UPLOAD",
                }
                self.log.debug("Found: %r when starting up, adding to transfer queue", transfer_event)
                self.transfer_queue.put(transfer_event)

    def start_threads_on_startup(self):
        # Startup threads
        self.inotify.start()
        self.webserver.start()
        for compressor in self.compressors:
            compressor.start()
        for ta in self.transfer_agents:
            ta.start()

    def _cleanup_inactive_receivexlogs(self, site):
        if site in self.receivexlogs:
            if not self.receivexlogs[site].running:
                if self.receivexlogs[site].is_alive():
                    self.receivexlogs[site].join()
                del self.receivexlogs[site]

    def handle_site(self, site, site_config):
        self.set_state_defaults(site)
        xlog_path, basebackup_path = self.create_backup_site_paths(site)

        if not site_config["active"]:
            return  # If a site has been marked inactive, don't bother checking anything

        self._cleanup_inactive_receivexlogs(site)

        chosen_backup_node = random.choice(site_config["nodes"])

        if site not in self.receivexlogs and site not in self.walreceivers:
            if site_config["active_backup_mode"] == "pg_receivexlog":
                self.receivexlog_listener(site, chosen_backup_node, xlog_path + "_incoming")
            elif site_config["active_backup_mode"] == "walreceiver":
                state_file_path = self.config["json_state_file_path"]
                walreceiver_state = {}
                with suppress(FileNotFoundError):
                    with open(state_file_path, "r") as fp:
                        old_state_file = json.load(fp)
                        walreceiver_state = old_state_file.get("walreceivers", {}).get(site, {})
                self.start_walreceiver(
                    site=site,
                    chosen_backup_node=chosen_backup_node,
                    last_flushed_lsn=walreceiver_state.get("last_flushed_lsn"))

        last_check_time = self.time_of_last_backup_check.get(site)
        if not last_check_time or (time.monotonic() - self.time_of_last_backup_check[site]) > 300:
            self.time_of_last_backup[site] = self.check_backup_count_and_state(site)
            self.time_of_last_backup_check[site] = time.monotonic()

        # check if a basebackup is running, or if a basebackup has just completed
        if site in self.basebackups:
            try:
                result = self.basebackups_callbacks[site].get(block=False)
            except Empty:
                # previous basebackup (or its compression and upload) still in progress
                return
            if self.basebackups[site].is_alive():
                self.basebackups[site].join()
            del self.basebackups[site]
            del self.basebackups_callbacks[site]
            self.log.debug("Basebackup has finished for %r: %r", site, result)
            self.time_of_last_backup[site] = self.check_backup_count_and_state(site)
            self.time_of_last_backup_check[site] = time.monotonic()

        new_backup_needed = False
        if site in self.requested_basebackup_sites:
            self.log.info("Creating a new basebackup for %r due to request", site)
            self.requested_basebackup_sites.discard(site)
            new_backup_needed = True
        elif site_config["basebackup_interval_hours"] is None:
                # Basebackups are disabled for this site (but they can still be requested over the API).
            pass
        elif self.time_of_last_backup.get(site) is None:
            self.log.info("Creating a new basebackup for %r because there are currently none", site)
            new_backup_needed = True
        else:
            delta_since_last_backup = datetime.datetime.now(datetime.timezone.utc) - self.time_of_last_backup[site]
            if delta_since_last_backup >= datetime.timedelta(hours=site_config["basebackup_interval_hours"]):
                self.log.info("Creating a new basebackup for %r by schedule (%s from previous)",
                              site, delta_since_last_backup)
                new_backup_needed = True

        if new_backup_needed and not os.path.exists(self.config["maintenance_mode_file"]):
            self.basebackups_callbacks[site] = Queue()
            self.create_basebackup(site, chosen_backup_node, basebackup_path, self.basebackups_callbacks[site])

    def run(self):
        self.start_threads_on_startup()
        self.startup_walk_for_missed_files()
        while self.running:
            try:
                for site, site_config in self.config["backup_sites"].items():
                    self.handle_site(site, site_config)
                self.write_backup_state_to_json_file()
            except subprocess.CalledProcessError as ex:
                self.log.error("main loop: %s: %s, retrying...", ex.__class__.__name__, ex)
            except Exception as ex:  # pylint: disable=broad-except
                self.log.exception("Unexpected exception in PGHoard main loop")
                self.metrics.unexpected_exception(ex, where="pghoard_run")
            time.sleep(5.0)

    def write_backup_state_to_json_file(self):
        """Periodically write a JSON state file to disk"""
        start_time = time.time()
        state_file_path = self.config["json_state_file_path"]
        self.state["walreceivers"] = {
            key: {"latest_activity": value.latest_activity, "running": value.running,
                  "last_flushed_lsn": value.last_flushed_lsn}
            for key, value in self.walreceivers.items()
        }
        self.state["pg_receivexlogs"] = {
            key: {"latest_activity": value.latest_activity, "running": value.running}
            for key, value in self.receivexlogs.items()
        }
        self.state["pg_basebackups"] = {
            key: {"latest_activity": value.latest_activity, "running": value.running}
            for key, value in self.basebackups.items()
        }
        self.state["compressors"] = [compressor.state for compressor in self.compressors]
        self.state["transfer_agents"] = [ta.state for ta in self.transfer_agents]
        self.state["queues"] = {
            "compression_queue": self.compression_queue.qsize(),
            "transfer_queue": self.transfer_queue.qsize(),
        }
        self.log.debug("Writing JSON state file to %r", state_file_path)
        write_json_file(state_file_path, self.state)
        self.log.debug("Wrote JSON state file to disk, took %.4fs", time.time() - start_time)

    def load_config(self, _signal=None, _frame=None):  # pylint: disable=unused-argument
        self.log.debug("Loading JSON config from: %r, signal: %r", self.config_path, _signal)
        try:
            new_config = config.read_json_config_file(self.config_path)
        except (InvalidConfigurationError, subprocess.CalledProcessError, UnicodeDecodeError) as ex:
            self.log.exception("Invalid config file %r: %s: %s", self.config_path, ex.__class__.__name__, ex)
            # if we were called by a signal handler we'll ignore (and log)
            # the error and hope the user fixes the configuration before
            # restarting pghoard.
            if _signal is not None:
                return
            if isinstance(ex, InvalidConfigurationError):
                raise
            raise InvalidConfigurationError(self.config_path)

        self.config = new_config
        if self.config.get("syslog") and not self.syslog_handler:
            self.syslog_handler = logutil.set_syslog_handler(
                address=self.config.get("syslog_address", "/dev/log"),
                facility=self.config.get("syslog_facility", "local2"),
                logger=logging.getLogger(),
            )
        # NOTE: getLevelName() also converts level names to numbers
        self.log_level = logging.getLevelName(self.config["log_level"])
        try:
            logging.getLogger().setLevel(self.log_level)
        except ValueError:
            self.log.exception("Problem with log_level: %r", self.log_level)

        # Setup monitoring clients
        self.metrics = metrics.Metrics(
            statsd=self.config.get("statsd", None),
            pushgateway=self.config.get("pushgateway", None),
            prometheus=self.config.get("prometheus", None))

        for thread in self._get_all_threads():
            thread.config = new_config
            thread.site_transfers = {}

        self.log.debug("Loaded config: %r from: %r", self.config, self.config_path)

    def _get_all_threads(self):
        all_threads = []
        if hasattr(self, "webserver"):  # on first config load webserver isn't initialized yet
            all_threads.append(self.webserver)
        all_threads.extend(self.basebackups.values())
        all_threads.extend(self.receivexlogs.values())
        all_threads.extend(self.walreceivers.values())
        all_threads.extend(self.compressors)
        all_threads.extend(self.transfer_agents)
        return all_threads

    def quit(self, _signal=None, _frame=None):  # pylint: disable=unused-argument
        self.log.warning("Quitting, signal: %r", _signal)
        self.running = False
        self.inotify.running = False
        all_threads = self._get_all_threads()
        for t in all_threads:
            t.running = False
        # Write state file in the end so we get the last known state
        self.write_backup_state_to_json_file()
        for t in all_threads:
            if t.is_alive():
                t.join()
        if self.mp_manager:
            self.mp_manager.shutdown()
            self.mp_manager = None
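
None of the examples show how the class is launched. As a minimal, hypothetical entry point (the project's real CLI wrapper is not included here, and the config path is illustrative), wiring it up only takes a constructor call and run():

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    pghoard = PGHoard("/etc/pghoard/pghoard.json")
    # run() starts the worker threads, replays missed files, then loops until
    # quit() flips self.running; SIGHUP reloads config and SIGINT/SIGTERM quit,
    # via the signal handlers installed in __init__.
    pghoard.run()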