示例#1
0
class LoadBalancingServer(ValidateModelMixin, TimeStampedModel):
    """
    A model representing a configured load-balancing server.

    The server's haproxy configuration is assembled from the backends of all load-balanced
    instances that reference it (see `get_configuration`) and pushed to the server by running an
    Ansible playbook (see `run_playbook`).
    """
    objects = LoadBalancingServerManager()

    domain = models.CharField(max_length=100, unique=True)
    # The username used to ssh into the server
    ssh_username = models.CharField(max_length=32)
    # Whether new backends can be assigned to this load-balancing server
    accepts_new_backends = models.BooleanField(default=True)
    # A random postfix appended to the haproxy configuration file names (and backend names) so that
    # multiple instance managers sharing the same load balancer are very unlikely to clash.
    fragment_name_postfix = models.CharField(max_length=8, blank=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-object logger adapter; this object is passed in the adapter's extra context.
        self.logger = ModelLoggerAdapter(logger, {'obj': self})

    def set_field_defaults(self):
        """
        Generate a random `fragment_name_postfix` if none is set, then apply the parent defaults.
        """
        if not self.fragment_name_postfix:
            # Pick a random fragment_name_postfix to avoid clashes between multiple instance
            # managers sharing the same load balancer. Each hex digit encodes 4 bits, so
            # max_length * 4 random bits format to at most max_length hex characters.
            bits = self._meta.get_field("fragment_name_postfix").max_length * 4
            self.fragment_name_postfix = format(random.getrandbits(bits), "x")
        super().set_field_defaults()

    def __str__(self):
        return self.domain

    def get_log_message_annotation(self):
        """
        Annotate log messages for the load-balancing server.

        The domain is truncated to its first 15 characters.
        """
        return "load_balancer={} ({!s:.15})".format(self.pk, self.domain)

    def get_instances(self):
        """
        Yield all instances configured to use this load balancer.

        Walks every one-to-many relation on this model whose related model is a
        LoadBalancedInstance subclass and yields the related objects.
        """
        # Local import to avoid problems with circular dependencies.
        from instance.models.mixins.load_balanced import LoadBalancedInstance

        for field in self._meta.get_fields():
            if field.one_to_many and issubclass(field.related_model, LoadBalancedInstance):
                yield from getattr(self, field.get_accessor_name()).iterator()

    def get_configuration(self, triggering_instance_id=None):
        """
        Collect the backend maps and configuration fragments from all associated instances.

        This function also appends fragment_name_postfix to all backend names to avoid name clashes
        between multiple instance managers using the same load balancer (e.g. for the integration
        tests).

        The triggering_instance_id indicates the id of the instance reference that initiated the
        reconfiguration of the load balancer.

        Returns a `(backend_map, backend_conf)` tuple of newline-joined strings.
        """
        backend_map = []
        backend_conf = []
        for instance in self.get_instances():
            triggered_by_instance = instance.ref.pk == triggering_instance_id
            map_entries, conf_entries = instance.get_load_balancer_configuration(
                triggered_by_instance
            )
            backend_map.extend(
                " ".join([domain.lower(), backend + self.fragment_name_postfix])
                for domain, backend in map_entries
            )
            backend_conf.extend(
                "backend {}\n{}\n".format(backend + self.fragment_name_postfix, conf)
                for backend, conf in conf_entries
            )
        return "\n".join(backend_map), "\n".join(backend_conf)

    def get_ansible_vars(self, triggering_instance_id=None):
        """
        Render the configuration script to be executed on the load balancer.

        The triggering_instance_id indicates the id of the instance reference that initiated the
        reconfiguration of the load balancer.
        """
        backend_map, backend_conf = self.get_configuration(triggering_instance_id)
        fragment_name = settings.LOAD_BALANCER_FRAGMENT_NAME_PREFIX + self.fragment_name_postfix
        # The fragments are indented by two spaces so they nest under the YAML block scalars (`|`).
        return (
            "FRAGMENT_NAME: {fragment_name}\n"
            "BACKEND_CONFIG_FRAGMENT: |\n"
            "{backend_conf}\n"
            "BACKEND_MAP_FRAGMENT: |\n"
            "{backend_map}\n"
        ).format(
            fragment_name=fragment_name,
            backend_conf=textwrap.indent(backend_conf, "  "),
            backend_map=textwrap.indent(backend_map, "  "),
        )

    def run_playbook(self, ansible_vars):
        """
        Run the playbook to perform the server reconfiguration.

        This is factored out into a separate method so it can be mocked out in the tests.

        The caller is expected to hold `_configuration_lock()`. The lock is taken by `reconfigure`
        and `deconfigure` rather than here, so that `reconfigure` can hold a single lock across
        both rendering and deploying the configuration.

        Raises ReconfigurationFailed if the playbook exits with a non-zero return code.
        """
        playbook_path = pathlib.Path(settings.SITE_ROOT) / "playbooks/load_balancer_conf/load_balancer_conf.yml"
        returncode = ansible.capture_playbook_output(
            requirements_path=str(playbook_path.parent / "requirements.txt"),
            inventory_str=self.domain,
            vars_str=ansible_vars,
            playbook_path=str(playbook_path),
            username=self.ssh_username,
            logger_=self.logger,
        )
        if returncode != 0:
            self.logger.error("Playbook to reconfigure load-balancing server %s failed.", self)
            raise ReconfigurationFailed

    def _configuration_lock(self):
        """
        Return a cache lock serializing (re)configurations of this load-balancing server.
        """
        return cache.lock("load_balancer_reconfigure:{}".format(self.domain), timeout=900)

    def reconfigure(self, triggering_instance_id=None):
        """
        Regenerate the configuration fragments on the load-balancing server.

        The triggering_instance_id indicates the id of the instance reference that initiated the
        reconfiguration of the load balancer.
        """
        self.logger.info("Reconfiguring load-balancing server %s", self.domain)
        # The lock must cover rendering the configuration as well as deploying it. Otherwise one
        # process could render its configuration, be overtaken by another process deploying a newer
        # one, and then overwrite the server with its stale configuration.
        with self._configuration_lock():
            self.run_playbook(self.get_ansible_vars(triggering_instance_id))

    def deconfigure(self):
        """
        Remove the configuration fragment from the load-balancing server.
        """
        fragment_name = settings.LOAD_BALANCER_FRAGMENT_NAME_PREFIX + self.fragment_name_postfix
        with self._configuration_lock():
            self.run_playbook(
                "FRAGMENT_NAME: {fragment_name}\nREMOVE_FRAGMENT: True".format(fragment_name=fragment_name)
            )

    def delete(self, *args, **kwargs):
        """
        Delete the LoadBalancingServer from the database.

        The configuration fragment is removed from the server first; if that raises, the database
        row is not deleted.
        """
        self.deconfigure()
        super().delete(*args, **kwargs)  # pylint: disable=no-member
示例#2
0
 def __init__(self, *args, **kwargs):
     """
     Build the model instance, then attach a per-object logger adapter as `self.logger`.

     The adapter's extra context holds this object under the 'obj' key.
     """
     super().__init__(*args, **kwargs)
     log_context = {'obj': self}
     self.logger = ModelLoggerAdapter(logger, log_context)
示例#3
0
class WatchedPullRequest(models.Model):
    """
    Represents a single watched pull request; holds a reference to the Instance created for that
    PR, if any.
    """
    # TODO: Remove 'ref_type' ?
    # TODO: Remove parameters from 'update_instance_from_pr'; make it fetch PR details from the
    # api (including the head commit sha hash, which does not require a separate API call as
    # is currently used.)
    branch_name = models.CharField(max_length=50, default='master')
    ref_type = models.CharField(max_length=50, default='heads')
    # Owner and repository of the fork the PR comes from; together they form `fork_name`.
    github_organization_name = models.CharField(max_length=200, db_index=True)
    github_repository_name = models.CharField(max_length=200, db_index=True)
    github_pr_url = models.URLField(blank=False)
    # Sandbox instance for this PR; set to NULL if that instance is deleted.
    instance = models.OneToOneField('instance.OpenEdXInstance', null=True, blank=True, on_delete=models.SET_NULL)

    objects = WatchedPullRequestQuerySet.as_manager()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.logger = ModelLoggerAdapter(logger, {'obj': self})

    @property
    def fork_name(self):
        """
        Fork name (eg. 'open-craft/edx-platform')
        """
        return '{0.github_organization_name}/{0.github_repository_name}'.format(self)

    @property
    def reference_name(self):
        """
        A descriptive name for the PR, which includes meaningful attributes
        """
        return '{0.github_organization_name}/{0.branch_name}'.format(self)

    @property
    def github_base_url(self):
        """
        Base GitHub URL of the fork (eg. 'https://github.com/open-craft/edx-platform')
        """
        return 'https://github.com/{0.fork_name}'.format(self)

    @property
    def github_branch_url(self):
        """
        GitHub URL of the branch tree
        """
        return '{0.github_base_url}/tree/{0.branch_name}'.format(self)

    def _github_pr_url_parts(self):
        """
        Split `github_pr_url` on '/', ignoring any trailing slashes.

        Without stripping, a URL such as 'https://github.com/edx/edx-platform/pull/123/' would
        produce an empty last component and break the PR number / target fork parsing.
        """
        return self.github_pr_url.rstrip('/').split('/')

    @property
    def github_pr_number(self):
        """
        Get the PR number from the URL of the PR, or None if no URL is set.
        """
        if not self.github_pr_url:
            return None
        return int(self._github_pr_url_parts()[-1])

    @property
    def target_fork_name(self):
        """
        Get the full name of the target repo/fork (e.g. 'edx/edx-platform')
        """
        # Split up a URL like https://github.com/edx/edx-platform/pull/12345678
        org, repo, pull, dummy = self._github_pr_url_parts()[-4:]
        assert pull == "pull"
        return "{}/{}".format(org, repo)

    @property
    def repository_url(self):
        """
        URL of the git repository (eg. 'https://github.com/open-craft/edx-platform.git')
        """
        return '{0.github_base_url}.git'.format(self)

    @property
    def updates_feed(self):
        """
        RSS/Atom feed of commits made on the repository/branch
        """
        return '{0.github_base_url}/commits/{0.branch_name}.atom'.format(self)

    def get_log_message_annotation(self):
        """
        Format a log message annotation for this PR.

        Delegates to the associated instance; returns None when there is no instance.
        """
        if self.instance:
            return self.instance.get_log_message_annotation()
        return None

    def get_branch_tip(self):
        """
        Get the `commit_id` of the current tip of the branch

        Logs and re-raises github.ObjectDoesNotExist if the branch cannot be found.
        """
        self.logger.info('Fetching commit ID of the tip of branch %s', self.branch_name)
        try:
            new_commit_id = github.get_commit_id_from_ref(
                self.fork_name,
                self.branch_name,
                ref_type=self.ref_type)
        except github.ObjectDoesNotExist:
            self.logger.error("Branch '%s' not found. Has it been deleted on GitHub?",
                              self.branch_name)
            raise

        return new_commit_id

    def set_fork_name(self, fork_name):
        """
        Set the organization and repository based on the GitHub fork name

        May only be called while both fields are still empty.
        """
        assert not self.github_organization_name
        assert not self.github_repository_name
        self.logger.info('Setting fork name: %s', fork_name)
        fork_org, fork_repo = github.fork_name2tuple(fork_name)
        self.github_organization_name = fork_org
        self.github_repository_name = fork_repo

    def update_instance_from_pr(self, pr):
        """
        Update/create the associated sandbox instance with settings from the given pull request.

        `pr` is expected to describe the same PR as this object (URL, fork and branch are
        asserted to match) and to provide `number`, `truncated_title`, `username`,
        `extra_settings`, `use_ephemeral_databases()` and `get_extra_setting()`.

        This will not spawn a new AppServer.
        This method will automatically save this WatchedPullRequest's 'instance' field.
        """
        # The following fields should never change:
        assert self.github_pr_url == pr.github_pr_url
        assert self.fork_name == pr.fork_name
        assert self.branch_name == pr.branch_name
        # Create an instance if necessary:
        instance = self.instance or OpenEdXInstance()
        instance.internal_lms_domain = generate_internal_lms_domain('pr{number}.sandbox'.format(number=pr.number))
        instance.edx_platform_repository_url = self.repository_url
        instance.edx_platform_commit = self.get_branch_tip()
        instance.name = (
            'PR#{pr.number}: {pr.truncated_title} ({pr.username}) - {i.reference_name} ({commit_short_id})'
            .format(pr=pr, i=self, commit_short_id=instance.edx_platform_commit[:7])
        )
        instance.configuration_extra_settings = pr.extra_settings
        instance.use_ephemeral_databases = pr.use_ephemeral_databases(instance.domain)
        # PR extra settings override the instance's current values, which act as the defaults.
        instance.configuration_source_repo_url = pr.get_extra_setting(
            'edx_ansible_source_repo', default=instance.configuration_source_repo_url
        )
        instance.configuration_version = pr.get_extra_setting(
            'configuration_version', default=instance.configuration_version
        )
        # Save atomically. (because if the instance gets created but self.instance failed to
        # update, then any subsequent call to update_instance_from_pr() would try to create
        # another instance, which would fail due to unique domain name constraints.)
        with transaction.atomic():
            instance.save()
            if not self.instance:
                self.instance = instance
                self.save(update_fields=["instance"])  # pylint: disable=no-member
示例#4
0
 def __init__(self, *args, **kwargs):
     """
     Run the parent initializer, then expose a logger adapter bound to this object.

     The adapter is stored on `self.logger`; its extra context maps 'obj' to this instance.
     """
     super().__init__(*args, **kwargs)
     adapter_extra = {'obj': self}
     self.logger = ModelLoggerAdapter(logger, adapter_extra)
示例#5
0
class LoadBalancingServer(ValidateModelMixin, TimeStampedModel):
    """
    A load-balancing server whose haproxy configuration fragment is managed by this application.

    The configuration is gathered from every load-balanced instance that references this server
    and deployed by running an Ansible playbook against `domain` as `ssh_username`.
    """
    objects = LoadBalancingServerManager()

    # Host name of the server; also used as the Ansible inventory.
    domain = models.CharField(max_length=100, unique=True)
    # Account used for SSH access when running the playbook.
    ssh_username = models.CharField(max_length=32)
    # Whether new backends may be assigned to this server.
    accepts_new_backends = models.BooleanField(default=True)
    # Random suffix added to the haproxy fragment file name and to every backend name, so several
    # instance managers (or concurrently running integration tests) can share one load balancer
    # without clashing.
    fragment_name_postfix = models.CharField(
        max_length=8,
        blank=True,
        default=functools.partial(generate_fragment_name, length=8),
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        log_context = {'obj': self}
        self.logger = ModelLoggerAdapter(logger, log_context)

    def __str__(self):
        return self.domain

    def _fragment_name(self):
        """
        Name of this server's configuration fragment: the configured prefix plus the postfix.
        """
        return settings.LOAD_BALANCER_FRAGMENT_NAME_PREFIX + self.fragment_name_postfix

    def get_log_message_annotation(self):
        """
        Build the annotation attached to log messages about this server.

        Only the first 15 characters of the domain are included.
        """
        template = "load_balancer={} ({!s:.15})"
        return template.format(self.pk, self.domain)

    def get_instances(self):
        """
        Generate every instance that is set up to use this load balancer.

        Any one-to-many relation pointing at a LoadBalancedInstance subclass is followed.
        """
        # Imported here instead of at module level to sidestep circular imports.
        from instance.models.mixins.load_balanced import LoadBalancedInstance

        balanced_relations = (
            field for field in self._meta.get_fields()
            if field.one_to_many and issubclass(field.related_model, LoadBalancedInstance)
        )
        for relation in balanced_relations:
            related_manager = getattr(self, relation.get_accessor_name())
            yield from related_manager.iterator()

    def get_configuration(self, triggering_instance_id=None):
        """
        Gather backend map lines and backend configuration sections from all related instances.

        Every backend name gets `fragment_name_postfix` appended so that several instance managers
        (e.g. integration test runs) can share the same load balancer without name clashes.

        `triggering_instance_id` is the id of the instance reference that caused this
        reconfiguration; the matching instance is told it triggered it.

        Returns a pair of newline-joined strings: (backend map, backend configuration).
        """
        postfix = self.fragment_name_postfix
        map_lines = []
        conf_sections = []
        for instance in self.get_instances():
            is_trigger = instance.ref.pk == triggering_instance_id
            map_entries, conf_entries = instance.get_load_balancer_configuration(is_trigger)
            for domain, backend in map_entries:
                map_lines.append(" ".join([domain.lower(), backend + postfix]))
            for backend, conf in conf_entries:
                conf_sections.append("backend {}\n{}\n".format(backend + postfix, conf))
        return "\n".join(map_lines), "\n".join(conf_sections)

    def get_ansible_vars(self, triggering_instance_id=None):
        """
        Produce the YAML variables passed to the load balancer playbook.

        `triggering_instance_id` is the id of the instance reference that caused this
        reconfiguration.
        """
        backend_map, backend_conf = self.get_configuration(triggering_instance_id)
        template = (
            "FRAGMENT_NAME: {fragment_name}\n"
            "BACKEND_CONFIG_FRAGMENT: |\n"
            "{backend_conf}\n"
            "BACKEND_MAP_FRAGMENT: |\n"
            "{backend_map}\n"
        )
        # Two-space indentation nests the fragments under the YAML block scalars (`|`).
        return template.format(
            fragment_name=self._fragment_name(),
            backend_conf=textwrap.indent(backend_conf, "  "),
            backend_map=textwrap.indent(backend_map, "  "),
        )

    def run_playbook(self, ansible_vars):
        """
        Execute the load balancer playbook with the given YAML variables.

        Kept as a separate method so tests can mock it. Raises ReconfigurationFailed when the
        playbook exits with a non-zero code.
        """
        playbook = pathlib.Path(settings.SITE_ROOT) / "playbooks/load_balancer_conf/load_balancer_conf.yml"
        requirements = playbook.parent / "requirements.txt"
        returncode = ansible.capture_playbook_output(
            requirements_path=str(requirements),
            inventory_str=self.domain,
            vars_str=ansible_vars,
            playbook_path=str(playbook),
            username=self.ssh_username,
            logger_=self.logger,
        )
        if returncode == 0:
            return
        self.logger.error("Playbook to reconfigure load-balancing server %s failed.", self)
        raise ReconfigurationFailed

    def _configuration_lock(self):
        """
        Cache lock serializing configuration changes of this load balancer (900 s timeout).
        """
        lock_key = "load_balancer_reconfigure:{}".format(self.domain)
        return cache.lock(lock_key, timeout=900)

    def reconfigure(self, triggering_instance_id=None):
        """
        Rebuild and deploy this server's configuration fragment.

        `triggering_instance_id` is the id of the instance reference that caused this
        reconfiguration.
        """
        self.logger.info("Reconfiguring load-balancing server %s", self.domain)
        # Rendering happens under the lock too, so a stale rendering can't be deployed last.
        with self._configuration_lock():
            ansible_vars = self.get_ansible_vars(triggering_instance_id)
            self.run_playbook(ansible_vars)

    def deconfigure(self):
        """
        Delete this server's configuration fragment from the load balancer.
        """
        removal_vars = "FRAGMENT_NAME: {fragment_name}\nREMOVE_FRAGMENT: True".format(
            fragment_name=self._fragment_name()
        )
        with self._configuration_lock():
            self.run_playbook(removal_vars)

    def delete(self, *args, **kwargs):
        """
        Remove the configuration fragment from the server, then delete the database row.
        """
        self.deconfigure()
        super().delete(*args, **kwargs)
示例#6
0
class WatchedPullRequest(models.Model):
    """
    Represents a single watched pull request; holds a reference to the Instance created for that
    PR, if any.
    """
    # TODO: Remove 'ref_type' ?
    # TODO: Remove parameters from 'update_instance_from_pr'; make it fetch PR details from the
    # api (including the head commit sha hash, which does not require a separate API call as
    # is currently used.)
    # Fork this PR was discovered through; None marks an "external" PR (see update_instance_from_pr).
    watched_fork = models.ForeignKey(WatchedFork,
                                     blank=True,
                                     null=True,
                                     on_delete=models.CASCADE)
    branch_name = models.CharField(max_length=255, default='master')
    ref_type = models.CharField(max_length=50, default='heads')
    # Owner and repository of the fork the PR comes from; together they form `fork_name`.
    github_organization_name = models.CharField(max_length=200, db_index=True)
    github_repository_name = models.CharField(max_length=200, db_index=True)
    github_pr_url = models.URLField(blank=False)
    # Sandbox instance for this PR; set to NULL if that instance is deleted.
    instance = models.OneToOneField('instance.OpenEdXInstance',
                                    null=True,
                                    blank=True,
                                    on_delete=models.SET_NULL)

    objects = WatchedPullRequestQuerySet.as_manager()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.logger = ModelLoggerAdapter(logger, {'obj': self})

    @property
    def fork_name(self):
        """
        Fork name (eg. 'open-craft/edx-platform')
        """
        return '{0.github_organization_name}/{0.github_repository_name}'.format(
            self)

    @property
    def reference_name(self):
        """
        A descriptive name for the PR, which includes meaningful attributes
        """
        return '{0.github_organization_name}/{0.branch_name}'.format(self)

    @property
    def github_base_url(self):
        """
        Base GitHub URL of the fork (eg. 'https://github.com/open-craft/edx-platform')
        """
        return 'https://github.com/{0.fork_name}'.format(self)

    @property
    def github_branch_url(self):
        """
        GitHub URL of the branch tree
        """
        return '{0.github_base_url}/tree/{0.branch_name}'.format(self)

    def _github_pr_url_parts(self):
        """
        Split `github_pr_url` on '/', ignoring any trailing slashes.

        Without stripping, a URL such as 'https://github.com/edx/edx-platform/pull/123/' would
        produce an empty last component and break the PR number / target fork parsing.
        """
        return self.github_pr_url.rstrip('/').split('/')

    @property
    def github_pr_number(self):
        """
        Get the PR number from the URL of the PR, or None if no URL is set.
        """
        if not self.github_pr_url:
            return None
        return int(self._github_pr_url_parts()[-1])

    @property
    def target_fork_name(self):
        """
        Get the full name of the target repo/fork (e.g. 'edx/edx-platform')
        """
        # Split up a URL like https://github.com/edx/edx-platform/pull/12345678
        org, repo, pull, dummy = self._github_pr_url_parts()[-4:]
        assert pull == "pull"
        return "{}/{}".format(org, repo)

    @property
    def repository_url(self):
        """
        URL of the git repository (eg. 'https://github.com/open-craft/edx-platform.git')
        """
        return '{0.github_base_url}.git'.format(self)

    @property
    def updates_feed(self):
        """
        RSS/Atom feed of commits made on the repository/branch
        """
        return '{0.github_base_url}/commits/{0.branch_name}.atom'.format(self)

    def get_log_message_annotation(self):
        """
        Format a log message annotation for this PR.

        Delegates to the associated instance; returns None when there is no instance.
        """
        if self.instance:
            return self.instance.get_log_message_annotation()
        return None

    def get_branch_tip(self):
        """
        Get the `commit_id` of the current tip of the branch

        Logs and re-raises github.ObjectDoesNotExist if the branch cannot be found.
        """
        self.logger.info('Fetching commit ID of the tip of branch %s',
                         self.branch_name)
        try:
            new_commit_id = github.get_commit_id_from_ref(
                self.fork_name, self.branch_name, ref_type=self.ref_type)
        except github.ObjectDoesNotExist:
            self.logger.error(
                "Branch '%s' not found. Has it been deleted on GitHub?",
                self.branch_name)
            raise

        return new_commit_id

    def set_fork_name(self, fork_name):
        """
        Set the organization and repository based on the GitHub fork name

        May only be called while both fields are still empty.
        """
        assert not self.github_organization_name
        assert not self.github_repository_name
        self.logger.info('Setting fork name: %s', fork_name)
        fork_org, fork_repo = github.fork_name2tuple(fork_name)
        self.github_organization_name = fork_org
        self.github_repository_name = fork_repo

    def update_instance_from_pr(self, pr):
        """
        Update/create the associated sandbox instance with settings from the given pull request.

        `pr` is expected to describe the same PR as this object (URL, fork and branch are
        asserted to match) and to provide `number`, `truncated_title`, `username`,
        `extra_settings` and `get_extra_setting()`.

        This will not spawn a new AppServer.
        This method will automatically save this WatchedPullRequest's 'instance' field.
        """
        # The following fields should never change:
        assert self.github_pr_url == pr.github_pr_url
        assert self.fork_name == pr.fork_name
        assert self.branch_name == pr.branch_name
        # Create an instance if necessary:
        instance = self.instance or OpenEdXInstance()
        # PRs not tied to a WatchedFork get 'ext'/'EXT' prefixes on their domain and name, and
        # don't inherit any fork-level settings.
        is_external_pr = self.watched_fork is None
        instance.internal_lms_domain = generate_internal_lms_domain(
            '{prefix}pr{number}.sandbox'.format(
                prefix='ext' if is_external_pr else '',
                number=pr.number,
            ))
        instance.edx_platform_repository_url = self.repository_url
        instance.edx_platform_commit = self.get_branch_tip()
        instance.name = (
            '{prefix}PR#{pr.number}: {pr.truncated_title} ({pr.username}) - {i.reference_name} ({commit_short_id})'
            .format(
                pr=pr,
                i=self,
                commit_short_id=instance.edx_platform_commit[:7],
                prefix='EXT' if is_external_pr else '',
            ))
        if is_external_pr:
            instance.configuration_extra_settings = pr.extra_settings
        else:
            # PR settings are merged on top of the fork-level settings.
            instance.configuration_extra_settings = yaml_merge(
                self.watched_fork.configuration_extra_settings,
                pr.extra_settings)
        if not instance.ref.creator or not instance.ref.owner:
            try:
                user = UserProfile.objects.get(github_username=pr.username)
                instance.ref.creator = user
                instance.ref.owner = user.organization
            except UserProfile.DoesNotExist:
                # PR is not associated with an Ocim user
                pass
        # Configuration repo and version and edx release follow this precedence:
        # 1) PR settings. 2) WatchedFork settings. 3) instance model defaults
        instance.configuration_source_repo_url = pr.get_extra_setting(
            'edx_ansible_source_repo',
            default=((self.watched_fork
                      and self.watched_fork.configuration_source_repo_url)
                     or instance.configuration_source_repo_url))
        instance.configuration_version = pr.get_extra_setting(
            'configuration_version',
            default=((self.watched_fork
                      and self.watched_fork.configuration_version)
                     or instance.configuration_version))
        instance.openedx_release = pr.get_extra_setting(
            'openedx_release',
            default=((self.watched_fork and self.watched_fork.openedx_release)
                     or instance.openedx_release))
        # Save atomically. (because if the instance gets created but self.instance failed to
        # update, then any subsequent call to update_instance_from_pr() would try to create
        # another instance, which would fail due to unique domain name constraints.)
        with transaction.atomic():
            instance.save()
            if not self.instance:
                self.instance = instance
                self.save(update_fields=["instance"])
示例#7
0
class LoadBalancingServer(ValidateModelMixin, TimeStampedModel):
    """
    A model representing a configured load-balancing server.

    Configuration changes bump `configuration_version`; a successful deployment records the
    version it deployed in `deployed_configuration_version`. While the two differ, the load
    balancer is considered dirty.
    """
    objects = LoadBalancingServerManager()

    domain = models.CharField(max_length=100, unique=True)

    ssh_username = models.CharField(
        max_length=32,
        help_text='The username used to SSH into the server.'
    )

    accepts_new_backends = models.BooleanField(
        default=False,
        help_text='Whether new backends can be assigned to this load-balancing server.'
    )

    fragment_name_postfix = models.CharField(
        max_length=8,
        blank=True,
        default=functools.partial(generate_fragment_name, length=8),
        help_text=(
            'A random postfix appended to the haproxy configuration file names to avoid clashes between '
            'multiple instance managers (or multiple concurrently running integration tests) sharing the '
            'same load balancer.'
        )
    )

    configuration_version = models.PositiveIntegerField(
        default=1,
        help_text=(
            'The current version of configuration for this load balancer. '
            'The version value is the total number of requests ever made to reconfigure the load balancer.'
        )
    )

    deployed_configuration_version = models.PositiveIntegerField(
        default=1,
        help_text=(
            'The currently active configuration version of the load balancer. '
            'If it is less than the configuration version, the load balancer is dirty. '
            'If it is equal to it, then no new reconfiguration is currently required.'
        )
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-object logger adapter; this object is passed in the adapter's extra context.
        self.logger = ModelLoggerAdapter(logger, {'obj': self})

    def __str__(self):
        return self.domain

    def get_log_message_annotation(self):
        """
        Annotate log messages for the load-balancing server.

        The domain is truncated to its first 15 characters.
        """
        return "load_balancer={} ({!s:.15})".format(self.pk, self.domain)

    def get_instances(self):
        """
        Yield all instances configured to use this load balancer.

        Walks every one-to-many relation on this model whose related model is a
        LoadBalancedInstance subclass and yields the related objects.
        """
        # Local import to avoid problems with circular dependencies.
        from instance.models.mixins.load_balanced import LoadBalancedInstance

        for field in self._meta.get_fields():
            if field.one_to_many and issubclass(field.related_model, LoadBalancedInstance):
                yield from getattr(self, field.get_accessor_name()).iterator()

    def get_configuration(self, triggering_instance_id=None):
        """
        Collect the backend maps and configuration fragments from all associated instances.

        This function also appends fragment_name_postfix to all backend names to avoid name clashes
        between multiple instance managers using the same load balancer (e.g. for the integration
        tests).

        The triggering_instance_id indicates the id of the instance reference that initiated the
        reconfiguration of the load balancer.

        Returns a `(backend_map, backend_conf)` tuple of newline-joined strings.
        """
        backend_map = []
        backend_conf = []
        for instance in self.get_instances():
            triggered_by_instance = instance.ref.pk == triggering_instance_id
            map_entries, conf_entries = instance.get_load_balancer_configuration(
                triggered_by_instance
            )
            backend_map.extend(
                " ".join([domain.lower(), backend + self.fragment_name_postfix])
                for domain, backend in map_entries
            )
            backend_conf.extend(
                "backend {}\n{}\n".format(backend + self.fragment_name_postfix, conf)
                for backend, conf in conf_entries
            )
        return "\n".join(backend_map), "\n".join(backend_conf)

    def get_ansible_vars(self, triggering_instance_id=None):
        """
        Render the configuration script to be executed on the load balancer.

        The triggering_instance_id indicates the id of the instance reference that initiated the
        reconfiguration of the load balancer.
        """
        backend_map, backend_conf = self.get_configuration(triggering_instance_id)
        fragment_name = settings.LOAD_BALANCER_FRAGMENT_NAME_PREFIX + self.fragment_name_postfix
        # The fragments are indented by two spaces so they nest under the YAML block scalars (`|`).
        return (
            "FRAGMENT_NAME: {fragment_name}\n"
            "BACKEND_CONFIG_FRAGMENT: |\n"
            "{backend_conf}\n"
            "BACKEND_MAP_FRAGMENT: |\n"
            "{backend_map}\n"
        ).format(
            fragment_name=fragment_name,
            backend_conf=textwrap.indent(backend_conf, "  "),
            backend_map=textwrap.indent(backend_map, "  "),
        )

    def run_playbook(self, ansible_vars):
        """
        Run the playbook to perform the server reconfiguration.

        This is factored out into a separate method so it can be mocked out in the tests.

        Raises ReconfigurationFailed if the playbook exits with a non-zero return code.
        """
        playbook_path = pathlib.Path(settings.SITE_ROOT) / "playbooks/load_balancer_conf/load_balancer_conf.yml"
        returncode = ansible.capture_playbook_output(
            requirements_path=str(playbook_path.parent / "requirements.txt"),
            inventory_str=self.domain,
            vars_str=ansible_vars,
            playbook_path=str(playbook_path),
            username=self.ssh_username,
            logger_=self.logger,
        )
        if returncode != 0:
            self.logger.error("Playbook to reconfigure load-balancing server %s failed.", self)
            raise ReconfigurationFailed

    def reconfigure(self, triggering_instance_id=None, mark_dirty=True):
        """
        Regenerate the configuration fragments on the load-balancing server.

        The triggering_instance_id indicates the id of the instance reference that initiated the
        reconfiguration of the load balancer.

        The mark_dirty flag indicates whether the LB configuration should be marked as dirty.  If
        this method is called because the configuration changed, the flag should be set to True (the
        default).  If this method is called because the LB was marked dirty earlier, the flag
        should be set to False.

        If another reconfiguration holds the lock, this call returns without deploying (the skip
        is logged).
        """
        if mark_dirty:
            # We need to use an F expression here.  The problem is not other processes trying to
            # increase this counter concurrently – that wouldn't matter, since we don't care whether
            # we increase this counter by one or by two, since both mark the LB as dirty.  However, if
            # another process is making a completely unrelated change to the LB object we might lose
            # the increment altogether.
            LoadBalancingServer.objects.filter(pk=self.pk).update(
                configuration_version=models.F("configuration_version") + 1
            )

        try:
            with self._configuration_lock(blocking=False):
                # Memorize the configuration version, in case new threads change it.  It is read
                # *before* the configuration is rendered, so any change made afterwards bumps
                # configuration_version past this candidate and the LB stays dirty.
                self.refresh_from_db()
                candidate_configuration_version = self.configuration_version
                self.logger.info("Reconfiguring load-balancing server %s", self.domain)
                self.run_playbook(self.get_ansible_vars(triggering_instance_id))
                LoadBalancingServer.objects.filter(pk=self.pk).update(
                    deployed_configuration_version=candidate_configuration_version
                )
                self.refresh_from_db()
        except OtherReconfigurationInProgress:
            # Another process holds the lock. Either it has not read configuration_version yet and
            # will deploy this change too, or it records an older candidate version and the LB stays
            # dirty (deployed_configuration_version < configuration_version).  Log the skip instead
            # of dropping it silently.
            self.logger.info(
                "Reconfiguration of load-balancing server %s already in progress; skipping.",
                self.domain,
            )

    def deconfigure(self):
        """
        Remove the configuration fragment from the load-balancing server.

        Waits for the configuration lock (blocking) before running the playbook.
        """
        self.logger.info("Deconfiguring load-balancing server %s", self.domain)
        fragment_name = settings.LOAD_BALANCER_FRAGMENT_NAME_PREFIX + self.fragment_name_postfix
        with self._configuration_lock():
            self.run_playbook(
                "FRAGMENT_NAME: {fragment_name}\nREMOVE_FRAGMENT: True".format(fragment_name=fragment_name)
            )

    def delete(self, *args, **kwargs):
        """
        Delete the LoadBalancingServer from the database.

        The configuration fragment is removed from the server first; if that raises, the database
        row is not deleted.
        """
        self.deconfigure()
        super().delete(*args, **kwargs)

    @contextlib.contextmanager
    def _configuration_lock(self, *, blocking=True):
        """
        A Redis lock to protect reconfigurations of this load balancer instance.

        Raises OtherReconfigurationInProgress if the lock cannot be acquired (only possible when
        `blocking` is False). The lock is released when the `with` block exits.
        """
        lock = cache.lock(
            "load_balancer_reconfigure:{}".format(self.domain),
            timeout=settings.REDIS_LOCK_TIMEOUT,
        )
        if not lock.acquire(blocking):
            raise OtherReconfigurationInProgress
        try:
            yield lock
        finally:
            lock.release()
示例#8
0
class Server(ValidateModelMixin, TimeStampedModel):
    """
    A single server VM.

    Abstract base model: concrete subclasses must implement ``update_status()``.
    The lifecycle state is stored in the ``status`` database column. It is exposed
    through the ``status`` descriptor and changed via the ``_status_to_*``
    transitions declared below.
    """
    # Prefix of the slug-friendly server name; see the ``name`` property.
    name_prefix = models.SlugField(max_length=20, blank=False)

    # Expose the status classes on the model (e.g. ``Server.Status.Ready``).
    Status = Status
    status = ModelResourceStateDescriptor(state_classes=Status.states,
                                          default_state=Status.Pending,
                                          model_field_name='_status')
    # Backing model field for the ``status`` descriptor (database column "status").
    _status = models.CharField(
        max_length=20,
        default=status.default_state_class.state_id,
        choices=status.model_field_choices,
        db_index=True,
        db_column='status',
    )
    # State transitions (from_states -> to_state):
    _status_to_building = status.transition(from_states=(Status.Pending,
                                                         Status.Unknown),
                                            to_state=Status.Building)
    _status_to_build_failed = status.transition(from_states=(Status.Building,
                                                             Status.Unknown),
                                                to_state=Status.BuildFailed)
    _status_to_booting = status.transition(from_states=(Status.Building,
                                                        Status.Ready,
                                                        Status.Unknown),
                                           to_state=Status.Booting)
    _status_to_ready = status.transition(from_states=(Status.Booting,
                                                      Status.Unknown),
                                         to_state=Status.Ready)
    # No from_states given; presumably allowed from any state (confirm in ModelResourceStateDescriptor).
    _status_to_terminated = status.transition(to_state=Status.Terminated)
    _status_to_unknown = status.transition(from_states=(Status.Building,
                                                        Status.Booting,
                                                        Status.Ready),
                                           to_state=Status.Unknown)

    objects = ServerQuerySet().as_manager()

    class Meta:
        # No table of its own; concrete server models subclass this.
        abstract = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-instance logger adapter bound to this object (presumably it uses
        # get_log_message_annotation() to prefix messages).
        self.logger = ModelLoggerAdapter(logger, {'obj': self})

    @property
    def name(self):
        """
        Get a name for this server (slug-friendly): "<name_prefix>-<id>".

        Only valid once the server has been saved (``id`` must not be None).
        """
        assert self.id is not None
        return "{prefix}-{num}".format(prefix=self.name_prefix, num=self.id)

    @property
    def event_context(self):
        """
        Context dictionary to include in events: ``{'server_id': <pk>}``.
        """
        return {'server_id': self.pk}

    def get_log_message_annotation(self):
        """
        Format a log line annotation for this server: "server=<pk> (<name, max 20 chars>)".
        """
        return 'server={} ({!s:.20})'.format(self.pk, self.name)

    def sleep_until(self, condition, timeout=3600):
        """
        Sleep in a loop until condition related to server status is fulfilled,
        or until timeout (provided in seconds) is reached.

        Each iteration refreshes the status via ``update_status()``, evaluates
        ``condition()``, then sleeps one second. ``timeout`` is therefore the maximum
        number of iterations, and the wall-clock wait can exceed it when
        ``update_status()`` itself is slow.

        Raises an exception if the desired condition can not be fulfilled.
        This can happen if the server is in a steady state (i.e., a state that is not expected to change)
        that does not fulfill the desired condition.

        The default timeout is 1h.

        Use as follows:

            server.sleep_until(lambda: server.status.is_steady_state)
            server.sleep_until(lambda: server.status.accepts_ssh_commands)

        :param condition: zero-argument callable; a truthy result ends the wait.
        :param timeout: polling budget in seconds (one-second sleeps); must be > 0.
        :raises SteadyStateException: if the condition is not met and the current
            status is a steady state.
        :raises TimeoutError: if the budget is exhausted. The message reports the
            full requested wait in minutes.
        """
        # Check if we received a valid timeout value
        # to avoid the possibility of entering an infinite loop (if timeout is negative)
        # or reaching the timeout right away (if timeout is zero)
        assert timeout > 0, "Timeout must be greater than 0 to be able to do anything useful"

        self.logger.info(
            'Waiting to reach status from which we can proceed...')

        # Count down a copy so `timeout` still holds the requested value for the error message.
        remaining = timeout
        while remaining > 0:
            self.update_status()
            if condition():
                self.logger.info(
                    'Reached appropriate status ({name}). Proceeding.'.format(
                        name=self.status.name))
                return
            if self.status.is_steady_state:
                raise SteadyStateException(
                    "The current status ({name}) does not fulfill the desired condition "
                    "and is not expected to change.".format(
                        name=self.status.name))
            time.sleep(1)
            remaining -= 1

        # If we get here, this means we've reached the timeout
        raise TimeoutError(
            "Waited {minutes:.2f} minutes to reach appropriate status, and got nowhere. "
            "Aborting with a status of {status}.".format(
                minutes=timeout / 60, status=self.status.name))

    def save(self, *args, **kwargs):
        """
        Save this Server, then publish a 'server_update' notification carrying its pk.
        """
        super().save(*args, **kwargs)
        publish_data('notification', {
            'type': 'server_update',
            'server_pk': self.pk,
        })

    def update_status(self):
        """
        Check the current status and update it if it has changed.

        Must be implemented by concrete subclasses.
        """
        raise NotImplementedError
示例#9
0
class Server(ValidateModelMixin, TimeStampedModel):
    """
    A single server VM.

    Abstract base model: concrete subclasses must implement ``update_status()``.
    The lifecycle state is stored in the ``status`` database column. It is exposed
    through the ``status`` descriptor and changed via the ``_status_to_*``
    transitions declared below.
    """

    # Prefix of the slug-friendly server name; see the ``name`` property.
    name_prefix = models.SlugField(max_length=20, blank=False)

    # Expose the status classes on the model (e.g. ``Server.Status.Ready``).
    Status = Status
    status = ModelResourceStateDescriptor(
        state_classes=Status.states, default_state=Status.Pending, model_field_name="_status"
    )
    # Backing model field for the ``status`` descriptor (database column "status").
    _status = models.CharField(
        max_length=20,
        default=status.default_state_class.state_id,
        choices=status.model_field_choices,
        db_index=True,
        db_column="status",
    )
    # State transitions (from_states -> to_state):
    _status_to_building = status.transition(from_states=(Status.Pending, Status.Unknown), to_state=Status.Building)
    _status_to_build_failed = status.transition(
        from_states=(Status.Building, Status.Unknown), to_state=Status.BuildFailed
    )
    _status_to_booting = status.transition(
        from_states=(Status.Building, Status.Ready, Status.Unknown), to_state=Status.Booting
    )
    _status_to_ready = status.transition(from_states=(Status.Booting, Status.Unknown), to_state=Status.Ready)
    # No from_states given; presumably allowed from any state (confirm in ModelResourceStateDescriptor).
    _status_to_terminated = status.transition(to_state=Status.Terminated)
    _status_to_unknown = status.transition(
        from_states=(Status.Building, Status.Booting, Status.Ready), to_state=Status.Unknown
    )

    objects = ServerQuerySet().as_manager()

    class Meta:
        # No table of its own; concrete server models subclass this.
        abstract = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-instance logger adapter bound to this object (presumably it uses
        # get_log_message_annotation() to prefix messages).
        self.logger = ModelLoggerAdapter(logger, {"obj": self})

    @property
    def name(self):
        """
        Get a name for this server (slug-friendly): "<name_prefix>-<id>".

        Only valid once the server has been saved (``id`` must not be None).
        """
        assert self.id is not None
        return "{prefix}-{num}".format(prefix=self.name_prefix, num=self.id)

    @property
    def event_context(self):
        """
        Context dictionary to include in events: ``{"server_id": <pk>}``.
        """
        return {"server_id": self.pk}

    def get_log_message_annotation(self):
        """
        Format a log line annotation for this server: "server=<name, max 20 chars>".
        """
        return "server={!s:.20}".format(self.name)

    def sleep_until(self, condition, timeout=3600):
        """
        Sleep in a loop until condition related to server status is fulfilled,
        or until timeout (provided in seconds) is reached.

        Each iteration refreshes the status via ``update_status()``, evaluates
        ``condition()``, then sleeps one second. ``timeout`` is therefore the maximum
        number of iterations, and the wall-clock wait can exceed it when
        ``update_status()`` itself is slow.

        Raises an exception if the desired condition can not be fulfilled.
        This can happen if the server is in a steady state (i.e., a state that is not expected to change)
        that does not fulfill the desired condition.

        The default timeout is 1h.

        Use as follows:

            server.sleep_until(lambda: server.status.is_steady_state)
            server.sleep_until(lambda: server.status.accepts_ssh_commands)

        :param condition: zero-argument callable; a truthy result ends the wait.
        :param timeout: polling budget in seconds (one-second sleeps); must be > 0.
        :raises SteadyStateException: if the condition is not met and the current
            status is a steady state.
        :raises TimeoutError: if the budget is exhausted. The message reports the
            full requested wait in minutes.
        """
        # Check if we received a valid timeout value
        # to avoid the possibility of entering an infinite loop (if timeout is negative)
        # or reaching the timeout right away (if timeout is zero)
        assert timeout > 0, "Timeout must be greater than 0 to be able to do anything useful"

        self.logger.info("Waiting to reach status from which we can proceed...")

        # Count down a copy so `timeout` still holds the requested value for the error message.
        remaining = timeout
        while remaining > 0:
            self.update_status()
            if condition():
                self.logger.info("Reached appropriate status ({name}). Proceeding.".format(name=self.status.name))
                return
            if self.status.is_steady_state:
                raise SteadyStateException(
                    "The current status ({name}) does not fulfill the desired condition "
                    "and is not expected to change.".format(name=self.status.name)
                )
            time.sleep(1)
            remaining -= 1

        # If we get here, this means we've reached the timeout
        raise TimeoutError(
            "Waited {minutes:.2f} minutes to reach appropriate status, and got nowhere. "
            "Aborting with a status of {status}.".format(minutes=timeout / 60, status=self.status.name)
        )

    def save(self, *args, **kwargs):
        """
        Save this Server, then publish a "server_update" notification carrying its pk.
        """
        super().save(*args, **kwargs)
        publish_data("notification", {"type": "server_update", "server_pk": self.pk})

    def update_status(self):
        """
        Check the current status and update it if it has changed.

        Must be implemented by concrete subclasses.
        """
        raise NotImplementedError