Example #1
0
class LogMixin(CommandMixin('log')):
    """Persist command runs and their output messages through the log facade."""

    def log_init(self, options, task=None):
        """Create and save the log entry for this command run.

        :param options: command options stored as the entry config (``{}`` when falsy)
        :param task: optional task object; when given, the entry is flagged as
            scheduled and records ``task.request.id``
        """
        self.log_entry = entry = self._log.create(None, command=self.get_full_name())
        entry.user = self.active_user
        entry.config = options if options else {}

        if task:
            entry.scheduled = True
            entry.task_id = task.request.id

        entry.save()

    def log_message(self, data):
        """Attach ``data`` as a message to this command's log entry and to every ``exec_parent`` ancestor's entry.

        Commands without a ``log_entry`` attribute (or with a falsy one) are skipped,
        but the walk still continues to their ``exec_parent``.
        """
        command = self
        while True:
            if getattr(command, 'log_entry', None):
                command.log_entry.messages.create(data=data)

            if not command.exec_parent:
                break
            command = command.exec_parent

    def log_status(self, status):
        """Set ``status`` on the current log entry and save it."""
        entry = self.log_entry
        entry.set_status(status)
        entry.save()
Example #2
0
class KubeExecMixin(CommandMixin('kube_exec')):
    """Run Kubernetes tooling executables against the manager's kubeconfig."""

    def kube_exec(self, filesystem, executable, options = None, env = None):
        """Run ``executable`` from the module directory with ``KUBECONFIG`` set.

        The argv is ``executable``, then ``options`` (normalized with ``ensure_list``
        when truthy), then the ``args`` command option.  ``KUBECONFIG`` points at the
        path returned by ``filesystem.link`` for ``<data_dir>/.kube/config``; when
        ``env`` is a non-empty dict its entries are layered over it.  A falsy result
        from ``self.sh`` is reported through ``self.error``.
        """
        passthrough_args = self.options.get('args', [])
        extra_options = ensure_list(options) if options else []
        command = [ executable ] + extra_options + passthrough_args

        kube_config = os.path.join(self.manager.data_dir, '.kube', 'config')
        command_env = { "KUBECONFIG": filesystem.link(kube_config, '.kube') }
        if env and isinstance(env, dict):
            # caller-supplied variables win over KUBECONFIG on key collision
            command_env.update(env)

        if not self.sh(command, env = command_env, cwd = self.manager.module_dir, display = True):
            self.error("Command {} failed: {}".format(executable, " ".join(command)))
Example #3
0
File: config.py  Project: zimagi/zimagi
class ConfigMixin(CommandMixin('config')):
    """Read configuration values through the config facade."""

    def get_config(self, name, default=None, required=False):
        """Return the ``value`` of config ``name``.

        ``default`` is returned when ``name`` is falsy or ``get_instance`` yields
        ``None``; ``required`` is forwarded to ``get_instance`` unchanged.
        """
        config = self.get_instance(self._config, name, required=required) if name else None
        return default if config is None else config.value
Example #4
0
class LogMixin(CommandMixin('log')):
    """Persist command runs, their messages and status through the log facade."""

    # Class attribute: one lock shared by every command instance, serializing
    # all log entry reads/writes performed by the methods below.
    log_lock = threading.Lock()

    def log_init(self, options, task=None, log_key=None):
        """Create (or reuse) the log entry for this run and mark it running.

        Nothing is recorded when ``self.log_result`` is false.

        :param options: command options stored as the entry config (``{}`` when falsy)
        :param task: optional task object; when given, the entry is flagged as
            scheduled and records ``task.request.id``
        :param log_key: existing entry key to retrieve instead of creating a new entry
        :returns: the entry's ``name``, or ``'<none>'`` when logging is disabled
        """
        if not options:
            options = {}

        if self.log_result:
            with self.log_lock:
                if log_key is None:
                    self.log_entry = self._log.create(
                        None, command=self.get_full_name())
                else:
                    self.log_entry = self._log.retrieve(log_key)

                self.log_entry.user = self.active_user
                self.log_entry.config = options
                self.log_entry.status = self._log.model.STATUS_RUNNING
                if task:
                    self.log_entry.scheduled = True
                    self.log_entry.task_id = task.request.id

                self.log_entry.save()

        return self.log_entry.name if self.log_result else '<none>'

    def log_message(self, data, log=True):
        """Attach ``data`` to this command's log entry and to each ``exec_parent`` ancestor's entry.

        :param log: when false, this command's own entry is skipped; ancestors
            are always written.  Nothing happens when ``self.log_result`` is false.
        """
        def _create_log_message(command, data, _log):
            if getattr(command, 'log_entry', None) and _log:
                command.log_entry.messages.create(data=data)

            if command.exec_parent:
                _create_log_message(command.exec_parent, data, True)

        if self.log_result:
            with self.log_lock:
                _create_log_message(self, data, log)

    def log_status(self, status, check_log_result=False):
        """Set and save ``status`` on this command's log entry.

        :param check_log_result: when true, skip the update if logging is disabled.
            When false (the default) the update is attempted regardless, e.g. to mark
            an entry untracked after logging was switched off.

        No-op when no log entry was ever created (``log_init`` ran with logging
        disabled or was never called), instead of raising ``AttributeError``.
        """
        if check_log_result and not self.log_result:
            return

        with self.log_lock:
            log_entry = getattr(self, 'log_entry', None)
            if log_entry is None:
                return
            log_entry.set_status(status)
            log_entry.save()

    def get_status(self):
        """Return the log entry status, or ``None`` when logging is disabled or no entry exists."""
        if not self.log_result:
            return None

        log_entry = getattr(self, 'log_entry', None)
        return log_entry.status if log_entry is not None else None
Example #5
0
File: message.py  Project: zimagi/zimagi
class MessageMixin(CommandMixin('message')):
    """Permission checks for message communication channels."""

    def check_channel_permission(self):
        """Return whether the active user may use ``self.communication_channel``.

        The channel must exist in the manager's ``channels`` spec (otherwise
        ``self.error`` is called).  The keys of the cleaned channel spec are the
        allowed roles: no roles means open access, otherwise the user needs at least
        one matching environment group.
        """
        # deep copy so later processing cannot touch the manager's spec
        channel_specs = copy.deepcopy(self.manager.get_spec('channels', {}))

        if self.communication_channel not in channel_specs:
            self.error("Communication channel {} does not exist".format(
                self.communication_channel))

        channel_spec = channel_specs[self.communication_channel] or {}
        allowed_roles = list(clean_dict(channel_spec, False).keys())

        if allowed_roles:
            return self.active_user.env_groups.filter(name__in=allowed_roles).exists()
        return True
Example #6
0
File: base.py  Project: mbeacom/zimagi
class BaseCommand(
    TerminalMixin,
    renderer.RendererMixin,
    CommandMixin('user'),
    CommandMixin('environment'),
    CommandMixin('group'),
    CommandMixin('config'),
    CommandMixin('module')
):
    """Shared foundation for commands: option parsing, message output, access checks.

    Subclasses declare options by overriding ``parse`` and implement execution by
    overriding ``handle``.  Messages produced by ``info``/``data``/``notice``/
    ``success``/``warning``/``error``/``table`` are queued (see ``queue``) and,
    outside API execution, rendered to the terminal.
    """

    # Class attributes: each lock is shared by every command instance.
    # display_lock serializes message creation, queueing and rendering;
    # thread_lock is not referenced within this class.
    display_lock = threading.Lock()
    thread_lock = threading.Lock()


    def __init__(self, name, parent = None):
        """Initialize a command called ``name`` whose parent command is ``parent``."""
        self.facade_index = {}

        self.name = name
        self.parent_instance = parent
        self.exec_parent = None

        self.confirmation_message = 'Are you absolutely sure?'
        self.messages = queue.Queue()
        self.parent_messages = None
        self.mute = False

        self.schema = {}
        self.parser = None
        self.options = options.AppOptions(self)
        self.option_map = {}
        self.descriptions = help.CommandDescriptions()

        super().__init__()


    @property
    def manager(self):
        """The global application manager (``settings.MANAGER``)."""
        return settings.MANAGER


    def queue(self, msg):
        """Render ``msg`` and queue the result; return the rendered data.

        The data goes on this command's ``messages`` queue and on the
        ``parent_messages`` queue of this command and of every ``parent_instance``
        ancestor that has one.
        """
        def _queue_parents(command, data):
            if command.parent_messages:
                command.parent_messages.put(data)
            if command.parent_instance:
                _queue_parents(command.parent_instance, data)

        data = msg.render()
        logger.debug("Adding command queue message: {}".format(data))

        self.messages.put(data)
        _queue_parents(self, data)
        return data

    def flush(self):
        """Queue the ``None`` sentinel that terminates iteration in ``get_messages``."""
        logger.debug("Flushing command queue")
        self.messages.put(None)

    def create_message(self, data, decrypt = True):
        """Rebuild an application message object from rendered ``data``."""
        return messages.AppMessage.get(data, decrypt = decrypt)

    def get_messages(self, flush = True):
        """Drain queued messages up to the ``None`` sentinel and return them as a list.

        With ``flush`` true the sentinel is queued first; otherwise this blocks
        until another caller queues it.
        """
        if flush:
            self.flush()

        # local name avoids shadowing the ``messages`` module used elsewhere
        return list(iter(self.messages.get, None))


    def add_schema_field(self, name, field, optional = True):
        """Register ``field`` as a form field ``name`` in the command's API schema."""
        self.schema[name] = coreapi.Field(
            name = name,
            location = 'form',
            required = not optional,
            schema = field_to_schema(field),
            type = type(field).__name__.lower()
        )

    def get_schema(self):
        """Return the command schema with a whitespace-collapsed description."""
        return command.CommandSchema(list(self.schema.values()), re.sub(r'\s+', ' ', self.get_description(False)))


    def create_parser(self):
        """Build the argument parser; parse errors print a warning and help, then exit(1)."""

        def display_error(message):
            self.warning(message + "\n")
            self.print_help()
            self.exit(1)

        epilog = self.get_epilog()
        if epilog:
            epilog = "\n".join(wrap_page(epilog))

        parser = CommandParser(
            prog = self.command_color('{} {}'.format(settings.APP_NAME, self.get_full_name())),
            description = "\n".join(wrap_page(
                self.get_description(False),
                init_indent = ' ',
                init_style = self.header_color(),
                indent = '  '
            )),
            epilog = epilog,
            formatter_class = argparse.RawTextHelpFormatter,
            called_from_command_line = True
        )
        parser.error = display_error

        self.add_arguments(parser)
        return parser

    def add_arguments(self, parser):
        """Store ``parser`` and register the base and command-specific options."""
        self.parser = parser
        self.parse_base()


    def parse(self):
        """Declare command-specific options (override in subclass)."""
        pass

    def parse_base(self):
        """Register the standard options, then call ``parse``; skipped for passthrough commands."""
        self.option_map = {}

        if not self.parse_passthrough():
            self.parse_verbosity()
            self.parse_no_parallel()
            self.parse_no_color()
            self.parse_debug()
            self.parse_display_width()

            # host selection and version reporting only make sense outside the API
            if not settings.API_EXEC:
                self.parse_environment_host()
                self.parse_version()

            self.parse()

    def parse_passthrough(self):
        """Return True to hand raw argv to the command as ``args`` instead of parsing."""
        return False


    def parse_environment_host(self):
        self.parse_variable('environment_host',
            '--host', str,
            "environment host name (default: {})".format(settings.DEFAULT_HOST_NAME),
            value_label = 'NAME',
            default = settings.DEFAULT_HOST_NAME
        )

    @property
    def environment_host(self):
        return self.options.get('environment_host', settings.DEFAULT_HOST_NAME)


    def parse_verbosity(self):
        self.parse_variable('verbosity',
            '--verbosity', int,
            "\n".join(wrap("verbosity level; 0=no output, 1=minimal output, 2=normal output, 3=verbose output", 60)),
            value_label = 'LEVEL',
            default = 2,
            choices = (0, 1, 2, 3)
        )

    @property
    def verbosity(self):
        return self.options.get('verbosity', 2)


    def parse_version(self):
        self.parse_flag('version', '--version', "show environment runtime version information")

    def parse_display_width(self):
        """Register ``--display-width``, defaulting to the current terminal width."""
        columns = shutil.get_terminal_size(fallback = (settings.DISPLAY_WIDTH, 25)).columns
        self.parse_variable('display_width',
            '--display-width', int,
            "CLI display width (default {} characters)".format(columns),
            value_label = 'WIDTH',
            default = columns
        )

    @property
    def display_width(self):
        return self.options.get('display_width', Runtime.width())

    def parse_no_color(self):
        self.parse_flag('no_color', '--no-color', "don't colorize the command output")

    @property
    def no_color(self):
        return self.options.get('no_color', not Runtime.color())

    def parse_debug(self):
        self.parse_flag('debug', '--debug', 'run in debug mode with error tracebacks')

    @property
    def debug(self):
        return self.options.get('debug', Runtime.debug())

    def parse_no_parallel(self):
        self.parse_flag('no_parallel', '--no-parallel', 'disable parallel processing')

    @property
    def no_parallel(self):
        return self.options.get('no_parallel', not Runtime.parallel())


    def interpolate_options(self):
        return True


    def server_enabled(self):
        return True

    def remote_exec(self):
        return self.server_enabled()

    def groups_allowed(self):
        return False


    def get_version(self):
        return version.VERSION

    def get_priority(self):
        return 1


    def get_parent_name(self):
        """Return the parent's full name, or '' for top-level commands (parent 'root' or none)."""
        if self.parent_instance and self.parent_instance.name != 'root':
            return self.parent_instance.get_full_name()
        return ''

    def get_full_name(self):
        """Return the space-separated command path, e.g. 'parent child'."""
        return "{} {}".format(self.get_parent_name(), self.name).strip()

    def get_description(self, overview = False):
        return self.descriptions.get(self.get_full_name(), overview)

    def get_epilog(self):
        return None


    @property
    def active_user(self):
        return self._user.active_user

    def check_access(self, instance, reset = False):
        """Check access to ``instance`` against its own ``access_groups``."""
        return self.check_access_by_groups(instance, instance.access_groups(reset))

    def check_access_by_groups(self, instance, groups):
        """Return whether the active user holds admin or one of ``groups`` in this environment.

        No groups, or the configured admin user, always grants access.  Entries of
        ``groups`` may be single names or lists/tuples of names.  On denial a
        warning is emitted and False is returned.
        """
        if not groups or self.active_user.name == settings.ADMIN_USER:
            return True

        user_groups = [ Roles.admin ]
        for group in groups:
            if isinstance(group, (list, tuple)):
                user_groups.extend(list(group))
            else:
                user_groups.append(group)

        # user_groups always contains Roles.admin, so the role lookup always runs
        if not self.active_user.env_groups.filter(name__in = user_groups).exists():
            self.warning("Operation {} {} {} access requires at least one of the following roles in environment: {}".format(
                self.get_full_name(),
                instance.facade.name,
                instance.name,
                ", ".join(user_groups)
            ))
            return False

        return True


    def get_provider(self, type, name, *args, **options):
        """Instantiate plugin provider ``name`` for plugin ``type`` (``'type'`` or ``'type:subtype'``).

        ``None``, ``'help'`` and ``'base'`` select the plugin's base provider.  An
        unknown provider, or any exception while constructing it, is reported via
        ``self.error`` (which raises by default).
        """
        type_components = type.split(':')
        type = type_components[0]
        subtype = type_components[1] if len(type_components) > 1 else None

        base_provider = self.manager.index.get_plugin_base(type)
        providers = self.manager.index.get_plugin_providers(type, True)

        if name is None or name in ('help', 'base'):
            provider_class = base_provider
        elif name in providers.keys():
            provider_class = providers[name]
        else:
            self.error("Plugin {} provider {} not supported".format(type, name))

        try:
            return provider_class(type, name, self, *args, **options).context(subtype, self.test)
        except Exception as e:
            self.error("Plugin {} provider {} error: {}".format(type, name, e))


    def print_help(self):
        parser = self.create_parser()
        self.info(parser.format_help())


    def _display_message(self, msg):
        """Render ``msg`` to the terminal unless running under the API or verbosity is 0."""
        if not settings.API_EXEC and self.verbosity > 0:
            msg.display(
                debug = self.debug,
                disable_color = self.no_color,
                width = self.display_width
            )

    def info(self, message, name = None, prefix = None):
        """Queue and display an informational message (suppressed when muted)."""
        with self.display_lock:
            if not self.mute:
                msg = messages.InfoMessage(str(message),
                    name = name,
                    prefix = prefix,
                    silent = False
                )
                self.queue(msg)
                self._display_message(msg)

    def data(self, label, value, name = None, prefix = None, silent = False):
        """Queue and display a labelled data value (suppressed when muted)."""
        with self.display_lock:
            if not self.mute:
                msg = messages.DataMessage(str(label), value,
                    name = name,
                    prefix = prefix,
                    silent = silent
                )
                self.queue(msg)
                self._display_message(msg)

    def silent_data(self, name, value):
        self.data(name, value,
            name = name,
            silent = True
        )

    def notice(self, message, name = None, prefix = None):
        """Queue and display a notice message (suppressed when muted)."""
        with self.display_lock:
            if not self.mute:
                msg = messages.NoticeMessage(str(message),
                    name = name,
                    prefix = prefix,
                    silent = False
                )
                self.queue(msg)
                self._display_message(msg)

    def success(self, message, name = None, prefix = None):
        """Queue and display a success message (suppressed when muted)."""
        with self.display_lock:
            if not self.mute:
                msg = messages.SuccessMessage(str(message),
                    name = name,
                    prefix = prefix,
                    silent = False
                )
                self.queue(msg)
                self._display_message(msg)

    def warning(self, message, name = None, prefix = None):
        """Queue and display a warning; unlike info-style messages this ignores ``mute``."""
        with self.display_lock:
            msg = messages.WarningMessage(str(message),
                name = name,
                prefix = prefix,
                silent = False
            )
            self.queue(msg)
            self._display_message(msg)

    def error(self, message, name = None, prefix = None, terminate = True, traceback = None, error_cls = CommandError, silent = False):
        """Queue and display an error, then raise ``error_cls`` when ``terminate`` is true.

        Errors ignore ``mute`` and verbosity; ``silent`` suppresses terminal display.
        When no ``traceback`` is supplied the current one is captured.
        """
        with self.display_lock:
            msg = messages.ErrorMessage(str(message),
                traceback = traceback,
                name = name,
                prefix = prefix,
                silent = silent
            )
            if not traceback:
                msg.traceback = format_traceback()

            self.queue(msg)

            if not settings.API_EXEC and not silent:
                msg.display(
                    debug = self.debug,
                    disable_color = self.no_color,
                    width = self.display_width
                )

        if terminate:
            raise error_cls('')

    def table(self, data, name = None, prefix = None, silent = False, row_labels = False):
        """Queue and display tabular ``data`` (suppressed when muted)."""
        with self.display_lock:
            if not self.mute:
                msg = messages.TableMessage(data,
                    name = name,
                    prefix = prefix,
                    silent = silent,
                    row_labels = row_labels
                )
                self.queue(msg)
                self._display_message(msg)

    def silent_table(self, name, data):
        self.table(data,
            name = name,
            silent = True
        )

    def confirmation(self, message = None):
        """Interactively require the user to type YES; abort via ``error`` otherwise.

        Skipped under API execution or when ``self.force`` is set.
        """
        if not settings.API_EXEC and not self.force:
            if not message:
                message = self.confirmation_message

            confirmation = input("{} (type YES to confirm): ".format(message))

            if re.match(r'^[Yy][Ee][Ss]$', confirmation):
                return True

            self.error("User aborted", 'abort')


    def format_fields(self, data, process_func = None):
        """Convert raw string values in ``data`` to the types declared in the schema.

        ``process_func(key, value)``, if callable, may rewrite each pair first.
        Dict/list/boolean fields are decoded as JSON, integer/float fields are cast;
        ``None`` and ``''`` become ``None``; undeclared keys pass through unchanged.
        """
        fields = self.get_schema().get_fields()
        params = {}

        for key, value in data.items():
            if process_func and callable(process_func):
                key, value = process_func(key, value)

            if value is not None and value != '':
                if key in fields:
                    field_type = fields[key].type

                    if field_type in ('dictfield', 'listfield'):
                        params[key] = json.loads(value)
                    elif field_type == 'booleanfield':
                        params[key] = json.loads(value.lower())
                    elif field_type == 'integerfield':
                        params[key] = int(value)
                    elif field_type == 'floatfield':
                        params[key] = float(value)

                if key not in params:
                    params[key] = value
            else:
                params[key] = None

        return params


    def run_list(self, items, callback):
        """Run ``callback`` over ``items`` via ``Parallel.list``; report thread errors and fail if aborted."""
        results = Parallel.list(items, callback, disable_parallel = self.no_parallel)

        if results.aborted:
            for thread in results.errors:
                self.error(thread.error, prefix = "[ {} ]".format(thread.name), traceback = thread.traceback, terminate = False)

            self.error("Parallel run failed", silent = True)

        return results

    def run_exclusive(self, lock_id, callback, error_on_locked = False, wait = True, timeout = 600, interval = 2):
        """Run ``callback`` while holding the database mutex ``lock_id``.

        Without a ``lock_id`` the callback runs directly.  Otherwise acquisition is
        retried every ``interval`` seconds for up to ``timeout`` seconds; if the
        lock is never obtained the callback is NOT run and no error is raised unless
        ``error_on_locked`` is set.  With ``wait`` false a single attempt is made.
        If the callback raises, the mutex row is deleted and the exception re-raised.
        """
        if not lock_id:
            callback()
        else:
            start_time = time.time()
            current_time = start_time

            while (current_time - start_time) <= timeout:
                try:
                    with db_mutex(lock_id):
                        callback()
                        break

                except DBMutexError:
                    if error_on_locked:
                        self.error("Could not obtain lock for {}".format(lock_id))
                    if not wait:
                        break

                except DBMutexTimeoutError:
                    logger.warning("Task {} completed but the lock timed out".format(lock_id))
                    break

                except Exception:
                    DBMutex.objects.filter(lock_id = lock_id).delete()
                    raise

                time.sleep(interval)
                current_time = time.time()


    def ensure_resources(self):
        """Run ``_ensure`` on every indexed facade in sorted key order, except environment and user."""
        for facade_index_name in sorted(self.facade_index.keys()):
            if facade_index_name not in ['00_environment', '00_user']:
                self.facade_index[facade_index_name]._ensure(self)

    def set_options(self, options):
        """Replace the command options with ``options``.

        ``environment_host`` is popped from ``options`` (mutating it) and, when
        truthy, added first with a third argument of False.
        """
        self.options.clear()

        host = options.pop('environment_host', None)
        if host:
            self.options.add('environment_host', host, False)

        for key, value in options.items():
            self.options.add(key, value)


    def bootstrap_ensure(self):
        return True

    def bootstrap(self, options, primary = False):
        """Prepare the command to run with ``options``.

        For the primary command, the debug/parallel/color/width options are also
        pushed into ``Runtime`` and, if ``bootstrap_ensure`` allows, all indexed
        facades are ensured after the environment and user facades.
        """
        if primary:
            if options.get('debug', False):
                Runtime.debug(True)

            if options.get('no_parallel', False):
                Runtime.parallel(False)

            if options.get('no_color', False):
                Runtime.color(False)

            if options.get('display_width', False):
                Runtime.width(options.get('display_width'))

        self._environment._ensure(self)
        self._user._ensure(self)

        self.set_options(options)
        if primary and self.bootstrap_ensure():
            self.ensure_resources()

    def handle(self, options, primary = False):
        """Execute the command (override in subclass).

        ``primary`` is accepted because ``run_from_argv`` calls
        ``handle(options, True)``; without it the base signature raised TypeError.
        """
        pass


    def run_from_argv(self, argv):
        """Parse ``argv`` and run this command as the primary command.

        ``--version`` delegates to the ``version`` command and ``-h``/``--help``
        prints help.  Database connections are closed afterwards regardless of outcome.
        """
        parser = self.create_parser()
        # strip the program name plus each word of the command path
        args = argv[(len(self.get_full_name().split(' ')) + 1):]

        if not self.parse_passthrough():
            if '--version' in argv:
                return self.manager.index.find_command(
                    'version',
                    main = True
                ).run_from_argv([])

            elif '-h' in argv or '--help' in argv:
                return self.print_help()

            options = vars(parser.parse_args(args))
        else:
            options = { 'args': args }

        try:
            self.bootstrap(options, True)
            self.handle(options, True)
        finally:
            try:
                connections.close_all()
            except ImproperlyConfigured:
                pass
Example #7
0
class NotificationMixin(CommandMixin('notification')):
    """Resolve e-mail recipients for a command run and dispatch outcome notifications."""

    def normalize_notify_command(self, command):
        """Return ``command`` with every run of whitespace replaced by ':'."""
        return re.sub(r'\s+', ':', command)

    def collect_notify_groups(self, group_names):
        """Return group records for ``group_names``.

        A name whose ``self._group.retrieve`` lookup is falsy is created through
        ``self.group_provider.store(name, {})``.
        """
        groups = []
        for name in group_names:
            group = self._group.retrieve(name)
            if not group:
                group = self.group_provider.store(name, {})
            groups.append(group)
        return groups


    def load_notification_users(self, success):
        """Collect recipient e-mails into ``self.notification_users`` and return them.

        Recipients are the active user plus members (with an e-mail) of the groups
        configured on this command's notification record and in ``command_notify``.
        On failure (``success`` falsy) the notification's failure groups and
        ``command_notify_failure`` are added too.  Users are keyed by name, so each
        address appears once per user.
        """
        self.notification_users = {}

        def load_groups(groups):
            for group in ensure_list(groups):
                for user in self._user.filter(groups__name = group):
                    if user.email:
                        self.notification_users[user.name] = user.email

        def load_related_groups(relation):
            # relation is a group relation of the notification record
            group_names = list(relation.values_list(
                'group__name', flat = True
            ))
            if group_names:
                load_groups(group_names)

        if self.active_user and self.active_user.email:
            self.notification_users[self.active_user.name] = self.active_user.email

        # reuse the shared normalizer so lookups match how notification keys are built
        command = self.normalize_notify_command(self.get_full_name())
        notification = self._notification.retrieve(command)

        if notification:
            load_related_groups(notification.groups)

        groups = self.command_notify
        if groups:
            load_groups(groups)

        if not success:
            if notification:
                load_related_groups(notification.failure_groups)

            failure_groups = self.command_notify_failure
            if failure_groups:
                load_groups(failure_groups)

        return list(self.notification_users.values())


    def format_notification_subject(self, success):
        """Return '<EMAIL_SUBJECT_PREFIX> SUCCESS|FAILED: <command name>'."""
        status = 'SUCCESS' if success else 'FAILED'
        return "{} {}: {}".format(
            settings.EMAIL_SUBJECT_PREFIX,
            status,
            self.get_full_name()
        )

    def format_notification_body(self):
        """Return the plain-text body: command, exported options, messages and a log lookup hint."""
        option_lines = [
            "> {}: {}".format(key, val)
            for key, val in self.options.export().items()
        ]

        return "Command: {}\n\nOptions:\n\n{}\n\nMessages:\n\n{}\n\nMore Information:\n\n{}".format(
            self.get_full_name(),
            "\n".join(option_lines),
            "\n".join(self.notification_messages),
            "zimagi log get {}".format(self.log_entry.get_id())
        )


    def send_notifications(self, success):
        """E-mail the run outcome to all resolved recipients.

        Only runs when results are logged and a Celery broker is configured.  Each
        mail is queued with ``send_notification.delay``; if queueing raises
        ``OperationalError`` the mail is sent synchronously instead.
        """
        if not (self.log_result and settings.CELERY_BROKER_URL):
            return

        recipients = self.load_notification_users(success)
        subject = self.format_notification_subject(success)
        body = self.format_notification_body()

        # defined after subject/body exist so the closure never sees unbound names
        def send_mail(recipient):
            try:
                logger.debug("Sending '{}' notification in the background".format(subject))
                send_notification.delay(recipient, subject, body)
            except OperationalError as e:
                logger.debug("Sending '{}' notification now: {}".format(subject, e))
                send_notification(recipient, subject, body)

        self.run_list(recipients, send_mail)
Example #8
0
class ActionCommand(exec.ExecMixin, CommandMixin('log'),
                    CommandMixin('schedule'), CommandMixin('notification'),
                    base.BaseCommand):
    """
    Base class for executable (action) commands.

    Adds execution, locking, scheduling and notification options to the base
    command. It dispatches execution locally or to a remote host, records the
    outcome in the command log, sends notifications, and streams messages
    back to API callers from a worker thread (``handle_api``).

    Subclasses implement ``exec`` and may override ``preprocess``,
    ``postprocess``, ``confirm``, ``display_header`` and ``generate``.
    """
    # Class attribute: a single lock shared by every ActionCommand instance.
    # It is taken by disable_logging / disconnect / connected below.
    lock = threading.Lock()

    @classmethod
    def generate(cls, command, generator):
        """Hook for subclasses to customise command generation; no-op here."""
        # Override in subclass if needed
        pass

    def __init__(self, name, parent=None):
        """
        Initialise per-run execution state on top of the base command.

        Args:
            name: command name, forwarded to the base command.
            parent: optional parent command, forwarded to the base command.
        """
        super().__init__(name, parent)

        self.disconnected = False
        self.log_result = True
        # Plain-text copies of every queued message, used in notification bodies.
        self.notification_messages = []

        self.action_result = self.get_action_result()

    def disable_logging(self):
        """Stop publishing/logging queued messages and mark the log entry untracked."""
        with self.lock:
            self.log_result = False
        self.log_status(self._log.model.STATUS_UNTRACKED)

    def disconnect(self):
        """Flag the consumer of this command's output as disconnected."""
        with self.lock:
            self.disconnected = True

    def connected(self):
        """Return True until ``disconnect`` has been called."""
        with self.lock:
            return not self.disconnected

    def queue(self, msg, log=True):
        """
        Queue ``msg`` and record it for publishing, logging and notifications.

        Args:
            msg: message object to queue.
            log: forwarded to ``publish_message`` (as ``include``) and to
                ``log_message``.

        Returns:
            The data returned by the parent ``queue`` implementation.
        """
        data = super().queue(msg)
        if self.log_result:
            self.publish_message(data, include=log)
            self.log_message(data, log=log)

        self.notification_messages.append(
            self.raw_text(msg.format(disable_color=True)))
        self.action_result.add(msg)
        return data

    def get_action_result(self):
        """Create the response object that accumulates queued messages."""
        return zimagi.command.CommandResponse()

    def display_header(self):
        """Whether the environment / command header is displayed; override to hide."""
        return True

    def parse_base(self, addons=None):
        """
        Register the action-command options, then any caller ``addons``.

        Which options exist depends on ``settings.QUEUE_COMMANDS``,
        ``settings.API_EXEC`` and ``self.server_enabled()``.
        """
        def action_addons():
            # Operations
            if settings.QUEUE_COMMANDS:
                self.parse_push_queue()
                self.parse_async_exec()

            if settings.QUEUE_COMMANDS or self.server_enabled():
                self.parse_worker_type()

            self.parse_local()

            if not settings.API_EXEC:
                self.parse_reverse_status()

            # Locking
            self.parse_lock_id()
            self.parse_lock_error()
            self.parse_lock_timeout()
            self.parse_lock_interval()
            self.parse_run_once()

            if self.server_enabled():
                # Scheduling
                self.parse_schedule()
                self.parse_schedule_begin()
                self.parse_schedule_end()

                # Notifications
                self.parse_command_notify()
                self.parse_command_notify_failure()

            if callable(addons):
                addons()

        super().parse_base(action_addons)

    def parse_worker_type(self):
        """Register the ``--worker`` option (worker machine type)."""
        self.parse_variable('worker_type',
                            '--worker',
                            str,
                            'machine type of worker processor to run command',
                            value_label='MACHINE',
                            tags=['system'])

    @property
    def worker_type(self):
        """Requested worker machine type, or None."""
        return self.options.get('worker_type', None)

    def parse_push_queue(self):
        """Register the ``--queue`` flag."""
        self.parse_flag(
            'push_queue',
            '--queue',
            "run command in the background and follow execution results",
            tags=['system'])

    @property
    def push_queue(self):
        """True when ``--queue`` was given."""
        return self.options.get('push_queue', False)

    def parse_async_exec(self):
        """Register the ``--async`` flag."""
        self.parse_flag(
            'async_exec',
            '--async',
            "return immediately letting command run in the background",
            tags=['system'])

    @property
    def async_exec(self):
        """True when ``--async`` was given."""
        return self.options.get('async_exec', False)

    @property
    def background_process(self):
        """True when command queueing is enabled and ``--queue`` or ``--async`` was given."""
        return settings.QUEUE_COMMANDS and (self.push_queue or self.async_exec)

    def parse_local(self):
        """Register the ``--local`` flag."""
        self.parse_flag('local',
                        '--local',
                        "force command to run in local environment",
                        tags=['system'])

    @property
    def local(self):
        """True when ``--local`` was given."""
        return self.options.get('local', False)

    def parse_reverse_status(self):
        """Register the ``--reverse-status`` flag."""
        self.parse_flag('reverse_status',
                        '--reverse-status',
                        "reverse exit status of command (error on success)",
                        tags=['system'])

    @property
    def reverse_status(self):
        """True when ``--reverse-status`` was given."""
        return self.options.get('reverse_status', False)

    def parse_lock_id(self):
        """Register the ``--lock`` option (exclusive execution lock id)."""
        self.parse_variable(
            'lock_id',
            '--lock',
            str,
            'command lock id to prevent simultaneous duplicate execution',
            value_label='UNIQUE_NAME',
            tags=['lock'])

    @property
    def lock_id(self):
        """Lock id passed to ``run_exclusive``, or None."""
        return self.options.get('lock_id', None)

    def parse_lock_error(self):
        """Register the ``--lock-error`` flag."""
        self.parse_flag(
            'lock_error',
            '--lock-error',
            'raise an error and abort if command lock can not be established',
            tags=['lock'])

    @property
    def lock_error(self):
        """True when ``--lock-error`` was given."""
        return self.options.get('lock_error', False)

    def parse_lock_timeout(self):
        """Register the ``--lock-timeout`` option (seconds, default 600)."""
        self.parse_variable('lock_timeout',
                            '--lock-timeout',
                            int,
                            'command lock wait timeout in seconds',
                            value_label='SECONDS',
                            default=600,
                            tags=['lock'])

    @property
    def lock_timeout(self):
        """Lock wait timeout in seconds (default 600)."""
        return self.options.get('lock_timeout', 600)

    def parse_lock_interval(self):
        """Register the ``--lock-interval`` option (seconds, default 2)."""
        self.parse_variable('lock_interval',
                            '--lock-interval',
                            int,
                            'command lock check interval in seconds',
                            value_label='SECONDS',
                            default=2,
                            tags=['lock'])

    @property
    def lock_interval(self):
        """Lock check interval in seconds (default 2)."""
        return self.options.get('lock_interval', 2)

    def parse_run_once(self):
        """Register the ``--run-once`` flag."""
        self.parse_flag(
            'run_once',
            '--run-once',
            "persist the lock id as a state flag to prevent duplicate executions",
            tags=['lock'])

    @property
    def run_once(self):
        """True when ``--run-once`` was given."""
        return self.options.get('run_once', False)

    def confirm(self):
        """Hook to ask the user for confirmation before execution (CLI only)."""
        # Override in subclass
        pass

    def prompt(self):
        """
        Interactively resolve prompt placeholders in the command options.

        Option values equal to a placeholder are replaced by user input. The
        search recurses through dicts and lists. Changed top-level options are
        written back with ``self.options.add``. Placeholders:

        * ``+prompt+``     -- visible input
        * ``++prompt++``   -- visible input, split on commas into a list
        * ``+private+``    -- hidden input, entered twice and compared
        * ``++private++``  -- hidden verified input, split on commas

        A list holding a single prompt placeholder is expanded in place into
        the entered (comma separated) values.
        """
        def _split_csv(value):
            return re.split(r'\s*,\s*', value)

        def _standard_prompt(parent, split=False):
            try:
                self.info('-' * self.display_width)
                value = input("Enter {}{}: ".format(parent,
                                                    " (csv)" if split else ""))
                if split:
                    value = _split_csv(value)
            except Exception:
                # e.g. EOFError when input is closed.
                self.error("User aborted", 'abort')

            return value

        def _hidden_verify_prompt(parent, split=False):
            try:
                self.info('-' * self.display_width)
                value1 = getpass.getpass(prompt="Enter {}{}: ".format(
                    parent, " (csv)" if split else ""))
                value2 = getpass.getpass(prompt="Re-enter {}{}: ".format(
                    parent, " (csv)" if split else ""))
            except Exception:
                self.error("User aborted", 'abort')

            if value1 != value2:
                self.error("Prompted {} values do not match".format(parent))

            if split:
                value1 = _split_csv(value1)

            return value1

        def _option_prompt(parent, option, top_level=False):
            # Returns (any_override, value); parent is the list of key/index
            # path components used to label the prompt.
            any_override = False

            if isinstance(option, dict):
                for name, value in option.items():
                    override, value = _option_prompt(parent + [str(name)],
                                                     value)
                    if override:
                        option[name] = value
                        any_override = True

            elif isinstance(option, (list, tuple)):
                if isinstance(option, tuple):
                    # Tuples cannot be updated in place; work on a list copy.
                    option = list(option)
                process_list = True

                if len(option) == 1:
                    override, value = _option_prompt(parent, option[0])
                    if isinstance(option[0], str) and option[0] != value:
                        # "++" prompts already return a list; only split
                        # plain string answers.
                        option.extend(value if isinstance(value, list)
                                      else _split_csv(value))
                        option.pop(0)
                        process_list = False
                        any_override = True

                if process_list:
                    for index, value in enumerate(option):
                        override, value = _option_prompt(
                            parent + [str(index)], value)
                        if override:
                            option[index] = value
                            any_override = True

            elif isinstance(option, str):
                parent = " ".join(parent).replace("_", " ")

                if option == '+prompt+':
                    option = _standard_prompt(parent)
                    any_override = True
                elif option == '++prompt++':
                    option = _standard_prompt(parent, True)
                    any_override = True
                elif option == '+private+':
                    option = _hidden_verify_prompt(parent)
                    any_override = True
                elif option == '++private++':
                    option = _hidden_verify_prompt(parent, True)
                    any_override = True

            return any_override, option

        for name, value in self.options.export().items():
            override, value = _option_prompt([name], value, True)
            if override:
                self.options.add(name, value)

    def exec(self):
        """Command implementation; override in subclasses."""
        # Override in subclass
        pass

    def exec_local(self, name, options=None, task=None, primary=False):
        """
        Run command ``name`` in this process as a child of this command.

        Display settings (debug, parallelism, color, width) default to this
        command's values. ``_log_key`` and ``_wait_keys`` are popped from the
        options and used for logging and for waiting on prior tasks.

        Returns:
            The log key returned by the child's ``handle``.
        """
        if not options:
            options = {}

        command = self.manager.index.find_command(name, self)
        command.mute = self.mute

        options = command.format_fields(copy.deepcopy(options))
        options.setdefault('debug', self.debug)
        options.setdefault('no_parallel', self.no_parallel)
        options.setdefault('no_color', self.no_color)
        options.setdefault('display_width', self.display_width)
        options['local'] = not self.server_enabled() or self.local

        log_key = options.pop('_log_key', None)
        wait_keys = options.pop('_wait_keys', None)

        command.wait_for_tasks(wait_keys)
        command.set_options(options)
        return command.handle(options,
                              primary=primary,
                              task=task,
                              log_key=log_key)

    def exec_remote(self, host, name, options=None, display=True):
        """
        Execute command ``name`` on ``host`` through its command API.

        Remote messages are re-created locally and queued on the command
        object. They are displayed when ``display`` is set and verbosity > 0;
        error messages are always displayed. The run is logged as failed when
        the response is aborted or when any exception escapes.

        Returns:
            The API response object.

        Raises:
            CommandError: if the remote response was aborted.
        """
        if not options:
            options = {}

        command = self.manager.index.find_command(name, self)
        command.mute = self.mute
        success = True

        # Local-only options are not forwarded to the remote host.
        options = {
            key: options[key]
            for key in options
            if key not in ('no_color', 'environment_host', 'local', 'version',
                           'reverse_status')
        }
        options['environment_host'] = self.environment_host
        options.setdefault('debug', self.debug)
        options.setdefault('no_parallel', self.no_parallel)
        options.setdefault('display_width', self.display_width)

        command.set_options(options)
        command.log_init(options)

        def message_callback(message):
            message = self.create_message(message.render(), decrypt=False)

            if (display and self.verbosity > 0) or isinstance(
                    message, messages.ErrorMessage):
                message.display(debug=self.debug,
                                disable_color=self.no_color,
                                width=self.display_width)
            command.queue(message)

        try:
            api = host.command_api(options_callback=command.preprocess_handler,
                                   message_callback=message_callback)
            response = api.execute(name, **options)
            command.postprocess_handler(response)

            if response.aborted:
                success = False
                raise CommandError()
        except BaseException:
            # Connection/transport failures (or interruption) must not be
            # logged as a successful run; the exception is re-raised as-is.
            success = False
            raise
        finally:
            command.log_status(success, True)

        return response

    def preprocess(self, options):
        """Hook run before execution (locally or as the remote options callback)."""
        # Override in subclass
        pass

    def preprocess_handler(self, options, primary=False):
        """Run ``preprocess`` wrapped in the ``preprocess`` profiler section."""
        self.start_profiler('preprocess', primary)
        self.preprocess(options)
        self.stop_profiler('preprocess', primary)

    def postprocess(self, response):
        """Hook run after execution with the command response."""
        # Override in subclass
        pass

    def postprocess_handler(self, response, primary=False):
        """Run ``postprocess`` (profiled) unless the response was aborted."""
        if not response.aborted:
            self.start_profiler('postprocess', primary)
            self.postprocess(response)
            self.stop_profiler('postprocess', primary)

    def handle(self, options, primary=False, task=None, log_key=None):
        """
        Run the command and return its log key.

        When not ``--local``, a host is selected, the server is enabled and
        ``remote_exec()`` allows it, the command runs through ``exec_remote``.
        Otherwise it checks execute permission, optionally prompts/confirms
        (CLI, no task) and runs ``exec`` under the configured lock, unless
        ``set_periodic_task`` / ``set_queue_task`` took over. The local
        outcome is logged for foreground runs. For the primary command,
        notifications are sent.

        With ``--reverse-status`` on a foreground run the outcome is
        inverted: an error returns the log key, and success raises
        ``ReverseStatusError``.

        Args:
            options: raw options (forwarded as-is to ``exec_remote``).
            primary: True for the top-level command of this process.
            task: task object passed to ``log_init`` (presumably set when
                run from a background worker -- confirm).
            log_key: existing log key to reuse, if any.
        """
        # Exit-status reversal is disabled for background (queued/async) runs.
        reverse_status = self.reverse_status and not self.background_process

        try:
            width = self.display_width
            env = self.get_env()
            host = self.get_host()
            success = True

            log_key = self.log_init(self.options.export(),
                                    task=task,
                                    log_key=log_key)
            if primary:
                self.check_abort()
                self._register_signal_handlers()

            if primary and (settings.CLI_EXEC or settings.SERVICE_INIT):
                self.info("-" * width, log=False)

            if not self.local and host and \
                (settings.CLI_EXEC or host.name != settings.DEFAULT_HOST_NAME) and \
                self.server_enabled() and self.remote_exec():

                if primary and self.display_header(
                ) and self.verbosity > 1 and not task:
                    self.data("> env ({})".format(self.key_color(host.host)),
                              env.name,
                              'environment',
                              log=False)

                if primary and settings.CLI_EXEC and not task:
                    self.prompt()
                    self.confirm()

                self.exec_remote(host,
                                 self.get_full_name(),
                                 options,
                                 display=True)
            else:
                if not self.check_execute():
                    self.error(
                        "User {} does not have permission to execute command: {}"
                        .format(self.active_user.name, self.get_full_name()))

                if primary and self.display_header(
                ) and self.verbosity > 1 and not task:
                    self.data('> env', env.name, 'environment', log=False)

                if primary and not task:
                    if settings.CLI_EXEC:
                        self.prompt()
                        self.confirm()

                    if settings.CLI_EXEC or settings.SERVICE_INIT:
                        self.info("=" * width, log=False)
                        self.data("> {}".format(
                            self.key_color(self.get_full_name())),
                                  log_key,
                                  'log_key',
                                  log=False)
                        self.info("-" * width, log=False)
                try:
                    self.preprocess_handler(self.options, primary)
                    if not self.set_periodic_task(
                    ) and not self.set_queue_task(log_key):
                        self.start_profiler('exec', primary)
                        self.run_exclusive(self.lock_id,
                                           self.exec,
                                           error_on_locked=self.lock_error,
                                           timeout=self.lock_timeout,
                                           interval=self.lock_interval,
                                           run_once=self.run_once)
                        self.stop_profiler('exec', primary)

                except BaseException:
                    # Includes interruptions (e.g. KeyboardInterrupt) so they
                    # are never recorded as a successful run; re-raised as-is.
                    success = False
                    raise
                finally:
                    self.postprocess_handler(self.action_result, primary)

                    # Same foreground-only reversal as the exit status, so a
                    # background dispatch is never logged/notified inverted.
                    success = not success if reverse_status else success
                    if not self.background_process:
                        self.log_status(success, True)

                    if primary:
                        self.send_notifications(success)

        except Exception:
            if reverse_status:
                return log_key
            raise

        finally:
            if not self.background_process:
                self.publish_exit()

            if primary:
                self.flush()
                self.manager.cleanup()

        if reverse_status:
            raise ReverseStatusError()

        return log_key

    def _exec_wrapper(self, options):
        """
        Thread target used by ``handle_api`` to run the command.

        Errors other than ``CommandError``/``CommandAborted`` are reported as
        error messages (with traceback) instead of propagating. Afterwards
        notifications are sent, the status is logged, the exit is published
        and resources are cleaned up, even if that reporting fails.
        """
        # Only flipped to True once the body completes, so any escape
        # (including non-Exception interruptions) is logged as a failure.
        success = False
        try:
            width = self.display_width
            log_key = self.log_init(options)

            self.check_abort()

            if self.display_header() and self.verbosity > 1:
                self.info("=" * width)
                self.data("> {}".format(self.get_full_name()), log_key,
                          'log_key')
                self.data("> active user", self.active_user.name,
                          'active_user')
                self.info("-" * width)

            if not self.set_periodic_task() and not self.set_queue_task(
                    log_key):
                self.run_exclusive(self.lock_id,
                                   self.exec,
                                   error_on_locked=self.lock_error,
                                   timeout=self.lock_timeout,
                                   interval=self.lock_interval,
                                   run_once=self.run_once)
            success = True

        except Exception as e:
            if not isinstance(e, (CommandError, CommandAborted)):
                self.error(e,
                           terminate=False,
                           traceback=display.format_exception_info())
        finally:
            try:
                self.send_notifications(success)
                self.log_status(success, True)

            except Exception as e:
                self.error(e,
                           terminate=False,
                           traceback=display.format_exception_info())

            finally:
                self.publish_exit()
                self.manager.cleanup()
                self.flush()

    def handle_api(self, options):
        """
        Run the command in a worker thread and yield message packages.

        Every 0.25 seconds, queued messages are drained until
        ``self.messages.get`` returns None. Each one is yielded as a package.
        The generator stops once the worker thread has exited. When the
        generator finishes or is closed, the command is marked disconnected.
        """
        self._register_signal_handlers()

        logger.debug("Running API command: {}\n\n{}".format(
            self.get_full_name(), yaml.dump(options)))

        action = threading.Thread(target=self._exec_wrapper, args=(options, ))
        action.start()

        logger.debug("Command thread started: {}".format(self.get_full_name()))

        try:
            while True:
                self.sleep(0.25)
                logger.debug("Checking messages")

                for data in iter(self.messages.get, None):
                    logger.debug("Receiving data: {}".format(data))

                    msg = self.create_message(data, decrypt=False)
                    package = msg.to_package()
                    yield package

                if not action.is_alive():
                    logger.debug("Command thread is no longer active")
                    break
        except Exception as e:
            logger.warning("Command transport exception: {}".format(e))
            raise
        finally:
            logger.debug("User disconnected")
            self.disconnect()
示例#9
0
文件: module.py 项目: zimagi/zimagi
class ModuleMixin(CommandMixin('module')):
    """
    Command mixin that provisions module files from template packages.

    A template package provides an ``index.yml`` (itself rendered as a
    template). Its ``variables``, ``map``, ``directories`` and ``commands``
    entries control which variables are accepted/required, which files are
    generated and where, which directories are created, and which shell
    commands run afterwards.
    """

    # Class-level lock (shared by all instances) serialising file and
    # directory generation across concurrent provisioning runs.
    template_lock = threading.Lock()

    def provision_template(self,
                           module,
                           package_name,
                           template_fields,
                           display_only=False):
        """
        Render template package ``package_name`` into ``module``.

        Args:
            module: target module; ``initialize(self)`` is called on it first.
            package_name: name of the template package to render.
            template_fields: user-provided template variable values.
            display_only: when True, only display what would be generated.
                Missing required variables get placeholders, and no files,
                directories or commands are written/run.

        Raises:
            TemplateException: missing required variable (when not
                display_only) or a failing package command.
        """
        module.initialize(self)
        self.manager.load_templates()

        index, template_fields = self._load_package(package_name,
                                                    template_fields,
                                                    display_only)
        with self.template_lock:
            self._store_template_map(module, index, template_fields,
                                     display_only)
            self._create_template_directories(module, index, display_only)

        self._run_package_commands(index, display_only)

    def _load_package(self, package_name, template_fields, display_only):
        """
        Load the package index and resolve its template variables.

        The index is rendered with the raw fields to discover its variable
        definitions, then rendered again with the resolved fields.

        Returns:
            ``(index, template_fields)``: the final index and resolved values.
        """
        template_fields = self._prepare_template_fields(
            self._load_package_index(package_name, template_fields),
            normalize_value(template_fields,
                            strip_quotes=True,
                            parse_json=True), display_only)
        return self._load_package_index(package_name,
                                        template_fields), template_fields

    def _load_package_index(self, package_name, template_fields):
        """Render ``<package>/index.yml`` and wrap it in a Collection with its name."""
        index_config = oyaml.safe_load(
            self._render_package_template(package_name, 'index.yml',
                                          template_fields))
        index_config['name'] = package_name
        return Collection(**index_config)

    def _prepare_template_fields(self, index, template_fields, display_only):
        """
        Resolve a value for every variable declared in ``index.variables``.

        Each variable takes the provided value, else its normalised
        ``default``, else None. Undeclared fields are dropped.

        Raises:
            TemplateException: a ``required`` variable is missing and
                ``display_only`` is False. In display mode a ``<{name}>``
                placeholder is used instead.
        """
        processed_fields = OrderedDict()

        for field, info in index.variables.items():
            if info.get('required', False) and field not in template_fields:
                if display_only:
                    template_fields[field] = "<{{{}}}>".format(field)
                else:
                    raise TemplateException(
                        "Field {} is required for template {}".format(
                            field, index.name))

            if field in template_fields:
                processed_fields[field] = template_fields[field]

            elif info.get('default', None) is not None:
                processed_fields[field] = normalize_value(info['default'],
                                                          strip_quotes=True,
                                                          parse_json=True)
            else:
                processed_fields[field] = None

        return processed_fields

    def _store_template_map(self, module, index, template_fields,
                            display_only):
        """
        Display the variables and generate every file in ``index.map``.

        A map entry is either a target path string or a dict with ``target``
        and optional ``when``, ``template`` (render vs. copy raw; default
        True) and ``location``. For ``.yml`` files, ``location`` is a dotted
        key path under which the rendered content is deep-merged into the
        existing target file. Files are only written when not display_only.
        """
        self.notice('Template variables:')
        self.table([['Variable', 'Value', 'Help']] +
                   [[key, value, index.variables[key].get('help', 'NA')]
                    for key, value in template_fields.items()], 'variables')
        self.info('')

        for path, info in index.map.items():
            target = None

            if isinstance(info, str):
                target = info
                info = {}

            elif info.get('when', True) and 'target' in info:
                target = info['target']

            if target:
                path_components = os.path.split(
                    self.manager.get_module_path(module, target))
                target_path = os.path.join(*path_components)

                if info.get('template', True):
                    file_content = self._render_package_template(
                        index.name, path, template_fields)
                else:
                    file_content = load_file(
                        self.manager.get_template_path(index.name, path))

                if info.get('location', None) and path.endswith('.yml'):
                    file_data = normalize_value(load_yaml(target_path),
                                                strip_quotes=True,
                                                parse_json=True)
                    if not file_data:
                        file_data = {}

                    location = info['location'].split('.')
                    embed_data = normalize_value(oyaml.safe_load(file_content),
                                                 strip_quotes=True,
                                                 parse_json=True)
                    merge_data = {}
                    iter_data = merge_data

                    # Build {k1: {k2: ... {kN: embed_data}}}. The loop counter
                    # must not reuse the name ``index``: that would clobber the
                    # package index needed by later map entries.
                    for position, key in enumerate(location):
                        if (position + 1) == len(location):
                            iter_data[key] = embed_data
                        else:
                            iter_data[key] = {}

                        iter_data = iter_data[key]

                    file_content = oyaml.dump(deep_merge(
                        file_data, merge_data))

                self.data('Path', path, 'path')
                self.data('Target', target, 'target')
                if info.get('location', None):
                    self.data('location', info['location'], 'location')
                self.notice('-' * self.display_width)
                if info.get('template', True):
                    self.info(file_content)
                self.info('')

                if not display_only:
                    create_dir(path_components[0])
                    save_file(target_path, file_content)

    def _render_package_template(self, package_name, file_path,
                                 template_fields):
        """Render ``<package_name>/<file_path>`` with ``template_fields``."""
        template = self.manager.template_engine.get_template(
            os.path.join(package_name, file_path))
        return template.render(**template_fields)

    def _create_template_directories(self, module, index, display_only):
        """Display and (unless display_only) create ``index.directories`` in the module."""
        module_path = self.manager.get_module_path(module)

        if index.directories:
            self.notice('Directories:')
            self.info('')
            for directory in ensure_list(index.directories):
                module_dir = os.path.join(module_path, directory)
                self.info(module_dir)
                if not display_only:
                    create_dir(module_dir)
            self.info('')

    def _run_package_commands(self, index, display_only):
        """
        Display and (unless display_only) run ``index.commands`` via ``self.sh``.

        String commands are split on whitespace.

        Raises:
            TemplateException: a command reported failure.
        """
        if index.commands:
            for command in ensure_list(index.commands):
                if isinstance(command, str):
                    command = re.split(r'\s+', command)

                self.data('Command', " ".join(command), 'command')
                self.notice('-' * self.display_width)
                if not display_only:
                    if not self.sh(command):
                        raise TemplateException(
                            "Template package command execution failed: {}".
                            format(command))
            self.info('')
示例#10
0
class ScheduleMixin(CommandMixin('schedule')):
    """
    Command mixin for scheduling commands periodically or queueing them.

    Schedule representations understood by
    ``get_schedule_from_representation`` (tried in this order):

    * interval: ``<number><D|H|M|S>`` (case-insensitive), e.g. ``15M``
    * datetime: ``YYYY-MM-DD HH:MM:SS``
    * crontab:  five space-separated fields made of ``* digits - / ,``
    """

    def get_schedule_from_representation(self, representation):
        """
        Return the stored schedule object matching ``representation``.

        Reports an error via ``self.error`` when no format matches.
        """
        schedule = self.get_interval_schedule(representation)

        if not schedule:
            schedule = self.get_datetime_schedule(representation)
        if not schedule:
            schedule = self.get_crontab_schedule(representation)

        if not schedule:
            self.error("'{}' is not a valid schedule format.  See --help for more information".format(representation))

        return schedule

    def normalize_schedule_time(self, time_string):
        """
        Parse ``YYYY-MM-DD HH:MM:SS`` and pass it through ``make_aware``.

        Raises:
            ValueError: the string is not a valid date/time in that format.
        """
        return make_aware(datetime.strptime(time_string, "%Y-%m-%d %H:%M:%S"))


    def set_periodic_task(self):
        """
        Store a periodic task for this command if ``self.schedule`` is set.

        The task runs ``zimagi.command.exec`` with the command name and the
        exported options (plus ``_user``), bounded by the optional schedule
        begin/end.

        Returns:
            True if a task was scheduled, otherwise False.
        """
        schedule = self.schedule

        if schedule:
            begin = self.schedule_begin
            end = self.schedule_end

            schedule_map = {
                'task_interval': 'interval',
                'task_crontab': 'crontab',
                'task_datetime': 'clocked'
            }
            options = self.options.export()
            options['_user'] = self.active_user.name
            task = {
                schedule_map[schedule.facade.name]: schedule,
                'task': 'zimagi.command.exec',
                'user': self.active_user,
                'args': json.dumps([self.get_full_name()]),
                'kwargs': json.dumps(options)
            }
            if begin:
                task['start_time'] = begin
            if end:
                task['expires'] = end

            self._scheduled_task.store(self.get_schedule_name(), **task)

            self.success("Task '{}' has been scheduled to execute periodically".format(self.get_full_name()))
            return True

        return False


    def set_queue_task(self):
        """
        Push this command onto the background queue when ``push_queue`` is set.

        Returns:
            True if the command was queued, otherwise False.
        """
        if self.push_queue:
            options = self.options.export()
            options['_user'] = self.active_user.name
            exec_command.delay(self.get_full_name(), **options)
            self.success("Task '{}' has been pushed to the queue to execute in the background".format(self.get_full_name()))
            return True

        return False


    def get_schedule_name(self):
        """Name made of the dashed command name, a local timestamp and one random lowercase letter."""
        return "{}:{}{}".format(
            self.get_full_name().replace(' ', '-'),
            datetime.now().strftime("%Y%m%d%H%M%S"),
            random.SystemRandom().choice(string.ascii_lowercase)
        )


    def get_interval_schedule(self, representation):
        """Return a stored interval schedule for ``<n><D|H|M|S>``, else None."""
        interval = self._task_interval.model
        schedule = None
        period_map = {
            'D': interval.DAYS,
            'H': interval.HOURS,
            'M': interval.MINUTES,
            'S': interval.SECONDS
        }

        match = re.match(r'^(\d+)([DHMS])$', representation, flags = re.IGNORECASE)
        if match:
            schedule, created = self._task_interval.store(representation,
                every = match.group(1),
                period = period_map[match.group(2).upper()],
            )
        return schedule

    def get_crontab_schedule(self, representation):
        """Return a stored crontab schedule for a five-field expression, else None."""
        schedule = None

        match = re.match(r'^([\*\d\-\/\,]+) ([\*\d\-\/\,]+) ([\*\d\-\/\,]+) ([\*\d\-\/\,]+) ([\*\d\-\/\,]+)$', representation)
        if match:
            # NOTE(review): fields map to minute, hour, day_of_week,
            # day_of_month, month_of_year. This differs from standard crontab
            # order (minute hour day-of-month month day-of-week); confirm it
            # matches the documented --schedule format before relying on it.
            schedule, created = self._task_crontab.store(representation,
                minute = match.group(1),
                hour = match.group(2),
                day_of_week = match.group(3),
                day_of_month = match.group(4),
                month_of_year = match.group(5)
            )
        return schedule

    def get_datetime_schedule(self, representation):
        """
        Return a stored one-off (clocked) schedule for ``YYYY-MM-DD HH:MM:SS``.

        Returns None when the string does not have that shape or is not a
        real date/time (e.g. month 13). The caller then reports an invalid
        schedule instead of crashing.
        """
        schedule = None

        match = re.match(r'^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}$', representation)
        if match:
            try:
                clocked_time = self.normalize_schedule_time(representation)
            except ValueError:
                return None

            schedule, created = self._task_datetime.store(representation,
                clocked_time = clocked_time,
            )
        return schedule
示例#11
0
class DatabaseMixin(CommandMixin('db')):
    """Command mixin exposing a lazily created database manager."""

    @property
    def db(self):
        """
        Return this command's ``manager.DatabaseManager``, creating it on first use.

        The instance is cached on ``self._cached_db_manager``. A missing or
        falsy cached value is replaced with a new manager.
        """
        cached_manager = getattr(self, '_cached_db_manager', None)
        if cached_manager:
            return cached_manager

        self._cached_db_manager = manager.DatabaseManager()
        return self._cached_db_manager
示例#12
0
class RendererMixin(CommandMixin('config'), BaseMixin):
    """Turn facade querysets and instances into table-style row data.

    The fields and relations shown for a facade can be narrowed through the
    ``<facade>_list_fields`` / ``<facade>_display_fields`` config values
    (read with ``get_config``) or by passing ``allowed_fields`` explicitly.
    """

    def render(self, facade, fields, queryset):
        """Build ``[fields, row, row, ...]`` for the instances in *queryset*.

        Each instance is re-fetched through ``get_instance_by_id`` and skipped
        when that returns nothing. A facade method named
        ``get_field_<field>_display`` formats its field when present;
        otherwise datetime values are rendered in local time.
        """
        fields = list(fields)
        data = [fields]

        for instance in queryset:
            instance = self.get_instance_by_id(facade,
                                               instance.get_id(),
                                               required=False)
            if instance:
                record = []

                for field in fields:
                    display_method = getattr(
                        facade, "get_field_{}_display".format(field), None)
                    value = getattr(instance, field, None)

                    if display_method and callable(display_method):
                        value = display_method(instance, value, True)

                    elif isinstance(value, datetime.datetime):
                        value = localtime(value).strftime(
                            "%Y-%m-%d %H:%M:%S %Z")

                    record.append(value)

                data.append(record)

        return data

    def get_default_fields(self, facade):
        """Map field names to title-cased verbose names for *facade*.

        The key field comes first. It is followed, sorted by name, by every
        other non-pk field for which ``facade.check_field_related`` is false
        (presumably the non-relation fields).

        NOTE(review): raises AttributeError if no field matches
        ``facade.key()``.
        """
        info = OrderedDict()
        fields = []
        key = None

        for field in facade.field_instances:
            if field.name == facade.key():
                key = field
            elif field.name != facade.pk:
                fields.append(field)

        info[key.name] = key.verbose_name.title()

        for field in sorted(fields, key=lambda x: x.name):
            if not facade.check_field_related(field):
                info[field.name] = field.verbose_name.title()

        return info

    def get_related_fields(self, facade):
        """Map each relation name of *facade* to its configured label."""
        return {
            field_name: field_info['label']
            for field_name, field_info in facade.get_all_relations().items()
        }

    def get_config_fields(self, facade, config_name, allowed_fields=None):
        """Return the field -> label map to show for *facade*.

        When *allowed_fields* (or else the config value *config_name*) is set,
        only the listed names that are default fields are kept. The facade
        key always comes first, and the pk is labelled ``'ID'``. Otherwise all
        default fields are returned.
        """
        default_fields = self.get_default_fields(facade)

        if allowed_fields:
            overrides = allowed_fields
        else:
            overrides = self.get_config(config_name, None)

        if overrides:
            fields = OrderedDict()
            fields[facade.key()] = default_fields[facade.key()]

            for field_name in data.ensure_list(overrides):
                if field_name == facade.pk:
                    fields[field_name] = 'ID'
                else:
                    if field_name in default_fields:
                        fields[field_name] = default_fields[field_name]

            return fields
        else:
            return default_fields

    def get_config_relations(self, facade, config_name, allowed_fields=None):
        """Return the relation -> label map to show for *facade*.

        Filtering works like ``get_config_fields``, using the facade's
        relations instead of its plain fields.
        """
        related_fields = self.get_related_fields(facade)

        if allowed_fields:
            overrides = allowed_fields
        else:
            overrides = self.get_config(config_name, None)

        if overrides:
            fields = OrderedDict()
            for field_name in data.ensure_list(overrides):
                if field_name in related_fields:
                    fields[field_name] = related_fields[field_name]
            return fields
        else:
            return related_fields

    def get_list_fields(self, facade, allowed_fields=None):
        """Fields for list views (config ``<facade>_list_fields``)."""
        config_name = "{}_list_fields".format(facade.name)
        return self.get_config_fields(facade, config_name, allowed_fields)

    def get_list_relations(self, facade, allowed_fields=None):
        """Relations for list views (config ``<facade>_list_fields``)."""
        config_name = "{}_list_fields".format(facade.name)
        return self.get_config_relations(facade, config_name, allowed_fields)

    def get_display_fields(self, facade, allowed_fields=None):
        """Fields for detail views (config ``<facade>_display_fields``)."""
        config_name = "{}_display_fields".format(facade.name)
        return self.get_config_fields(facade, config_name, allowed_fields)

    def get_display_relations(self, facade, allowed_fields=None):
        """Relations for detail views (config ``<facade>_display_fields``)."""
        config_name = "{}_display_fields".format(facade.name)
        return self.get_config_relations(facade, config_name, allowed_fields)

    def render_list_fields(self, facade, allowed_fields=None):
        """Return ``(field_names, formatted_labels)`` for a list view."""
        fields = []
        labels = []

        for name, label in self.get_list_fields(facade,
                                                allowed_fields).items():
            label = self.format_label(label)
            fields.append(name)
            labels.append(label)

        return (fields, labels)

    def _add_list_relation_labels(self, facade, labels, allowed_fields=None):
        """Append list-view relation labels to *labels*; return relation names."""
        field_relations = []

        for name, label in self.get_list_relations(facade,
                                                   allowed_fields).items():
            field_relations.append(name)
            labels.append(self.format_label(label))

        return field_relations

    def _render_relation_cell(self, field_info, value):
        """Format one relation value as newline-joined, colored item names."""
        if field_info['multiple']:
            items = [self.relation_color(str(sub_instance))
                     for sub_instance in value.all()]
        else:
            items = [self.relation_color(str(value))]

        return "\n".join(items)

    def render_relation_overview(self, facade, name, instances):
        """Render a nested table for related *instances* of facade *name*.

        :param instances: mapping of ``get_id()`` values to related instances.
        :returns: header row plus one row per instance, or ``[]`` when the
            facade is unknown or nothing renders.
        """
        facade_index = facade.manager.index.get_facade_index()

        if name not in facade_index:
            return []

        facade = facade_index[name]
        relations = facade.get_all_relations()
        fields, labels = self.render_list_fields(facade)
        field_relations = self._add_list_relation_labels(facade, labels)

        data = self.render(facade, [facade.pk] + fields,
                           facade.filter(**{'id__in': instances.keys()}))
        if len(data) <= 1:
            return []

        data[0] = [self.header_color(x) for x in labels]

        # Rows are mutated in place: the leading pk column is replaced by
        # the trailing relation columns.
        for info in data[1:]:
            instance_id = self.raw_text(info.pop(0))
            for field_name in field_relations:
                value = getattr(instances[instance_id], field_name)
                info.append(
                    self._render_relation_cell(relations[field_name], value))

        return data

    def render_list(self, facade, filters=None, allowed_fields=None):
        """Render facade records matching *filters* as table rows.

        The pk column is dropped from every row, and one column per
        configured list relation is appended. Returns ``[]`` when no record
        matches.
        """
        if not filters:
            filters = {}

        relations = facade.get_all_relations()
        data = []
        fields, labels = self.render_list_fields(facade, allowed_fields)
        field_relations = self._add_list_relation_labels(
            facade, labels, allowed_fields)

        if facade.count(**filters):
            data = self.render(facade, [facade.pk] + fields,
                               facade.filter(**filters))
            id_index = data[0].index(facade.pk)

            data[0] = [self.header_color(x) for x in labels]

            for info in data[1:]:
                instance_id = self.raw_text(info.pop(id_index))
                instance = self.get_instance_by_id(facade,
                                                   instance_id,
                                                   required=False)

                for field_name in field_relations:
                    value = getattr(instance, field_name)
                    info.append(
                        self._render_relation_cell(relations[field_name],
                                                   value))

        return data

    def render_display(self, facade, name, allowed_fields=None):
        """Render ``(label, value)`` pairs describing one instance.

        *name* may be a model instance or a name looked up through
        ``get_instance``. Reports an error through ``self.error`` when the
        instance does not exist.
        """
        if isinstance(name, BaseModel):
            instance = name
        else:
            instance = self.get_instance(facade, name, required=False)

        relations = facade.get_all_relations()
        data = []

        if instance:
            for field_name, label in self.get_display_fields(
                    facade, allowed_fields).items():
                label = self.format_label(label)
                display_method = getattr(
                    facade, "get_field_{}_display".format(field_name), None)
                value = getattr(instance, field_name, None)

                if display_method and callable(display_method):
                    value = display_method(instance, value, False)
                else:
                    if isinstance(value, datetime.datetime):
                        value = localtime(value).strftime(
                            "%Y-%m-%d %H:%M:%S %Z")
                    else:
                        value = str(value)

                data.append((self.header_color(label), value))

            for field_name, label in self.get_display_relations(
                    facade, allowed_fields).items():
                label = self.header_color(self.format_label(label))
                field_info = relations[field_name]
                value = getattr(instance, field_name)

                if field_info['multiple']:
                    instances = {x.get_id(): x for x in value.all()}
                    relation_data = self.render_relation_overview(
                        facade, field_info['name'], instances)
                    if relation_data:
                        value = display.format_data(relation_data,
                                                    width=self.display_width)
                        data.append((label, value + "\n"))
                else:
                    data.append(
                        (label, self.relation_color(str(value)) + "\n"))
        else:
            self.error("{} {} does not exist".format(facade.name.title(),
                                                     name))

        return data

    def format_label(self, label):
        """Put each space-separated word of *label* on its own line."""
        return "\n".join(label.split(' '))
示例#13
0
class EnvironmentMixin(CommandMixin('environment')):
    """Environment, host-connection and key/value state helpers for commands."""

    @property
    def curr_env_name(self):
        """Name of the currently selected environment."""
        return self._environment.get_env()

    def get_env(self, host_name=None):
        """Return the current environment record with host connection details.

        ``host``, ``port``, ``user`` and ``token`` are copied from the host
        record named *host_name* (default: ``self.environment_host``). They
        are set to ``None`` when that host cannot be retrieved, or when the
        lookup is skipped: under ``settings.API_EXEC`` with the default host
        name.

        NOTE(review): ``get_instance(..., required=False)`` may return
        ``None``, in which case the attribute assignments fail.
        """
        current_name = self._environment.get_env()
        env = self.get_instance(self._environment, current_name, required=False)

        host_name = host_name or self.environment_host

        if settings.API_EXEC and host_name == settings.DEFAULT_HOST_NAME:
            host = None
        else:
            host = self._host.retrieve(host_name)

        for attribute in ('host', 'port', 'user', 'token'):
            setattr(env, attribute, getattr(host, attribute) if host else None)
        return env

    def create_env(self, name, **fields):
        """Create an environment object seeded from a ``'temp'`` host record.

        Extra keyword *fields* are then set as attributes on the result.
        """
        env = self._environment.create(name)
        template_host = self._host.create('temp')

        for attribute in ('host', 'port', 'user', 'token'):
            setattr(env, attribute, getattr(template_host, attribute))

        for attribute, value in fields.items():
            setattr(env, attribute, value)
        return env

    def set_env(self, name=None, repo=None, image=None):
        """Update the current environment selection and report success."""
        self._environment.set_env(name, repo, image)
        self.success("Successfully updated current environment")

    def update_env_host(self, **fields):
        """Create or update a host record and save it.

        The record is named by ``fields['name']`` or, when that is missing or
        empty, ``self.environment_host``. The remaining *fields* are applied
        to the record.
        """
        host_name = fields.pop('name', None) or self.environment_host

        host = self._host.retrieve(host_name)
        if host:
            for attribute, value in fields.items():
                setattr(host, attribute, value)
        else:
            host = self._host.create(host_name, **fields)
        host.save()

    def delete_env(self):
        """Clear modules, delete the current environment and report success."""
        self.exec_local('module clear')
        self._environment.delete_env()
        self.success("Successfully removed environment")

    def get_state(self, name, default=None):
        """Return the stored state value for *name*, or *default* if absent."""
        record = self.get_instance(self._state, name, required=False)
        return record.value if record else default

    def set_state(self, name, value=None):
        """Store *value* under state key *name*."""
        self._state.store(name, value=value)

    def delete_state(self, name=None, default=None):
        """Delete state *name* and return the value it held (or *default*)."""
        previous = self.get_state(name, default)
        self._state.delete(name)
        return previous

    def clear_state(self):
        """Remove all stored state entries."""
        self._state.clear()
示例#14
0
文件: base.py 项目: zimagi/zimagi
class BaseCommand(TerminalMixin, renderer.RendererMixin, CommandMixin('user'),
                  CommandMixin('environment'), CommandMixin('group'),
                  CommandMixin('config'), CommandMixin('module')):
    """Common base for application commands.

    Bundles option parsing (the ``parse_*`` helpers and their matching
    properties), message output and queueing, access checks, plugin provider
    lookup, parallel and mutex-guarded execution helpers, optional cProfile
    profiling, and the ``bootstrap`` -> ``handle`` path used by
    ``run_from_argv``.
    """

    def __init__(self, name, parent=None):
        self.facade_index = {}

        self.name = name
        self.parent_instance = parent
        self.exec_parent = None

        self.confirmation_message = 'Are you absolutely sure?'
        self.messages = queue.Queue()
        self.parent_messages = None
        self.mute = False

        self.schema = {}
        self.parser = None
        self.options = options.AppOptions(self)
        self.option_lock = threading.Lock()
        self.option_map = {}
        self.option_defaults = {}
        self.descriptions = help.CommandDescriptions()

        self.profilers = {}

        self.signal_list = [signal.SIGHUP, signal.SIGINT, signal.SIGTERM]
        self.signal_handlers = {}

        super().__init__()

    def _signal_handler(self, sig, stack_frame):
        """Release manager locks, restore saved handlers and re-send *sig*.

        Each lock reported by the manager index is exited through
        ``check_mutex(..., force_remove=True)``. The handlers saved by
        ``_register_signal_handlers`` are reinstalled, and the received
        signal is sent to this process again so those handlers run.
        """
        for lock_id in settings.MANAGER.index.get_locks():
            check_mutex(lock_id, force_remove=True).__exit__()

        # Distinct loop name: rebinding ``sig`` here would re-send the last
        # registered signal instead of the one actually received.
        for registered_sig, handler in self.signal_handlers.items():
            signal.signal(registered_sig, handler)

        os.kill(os.getpid(), sig)

    def _register_signal_handlers(self):
        """Install ``_signal_handler`` for each signal, saving previous ones."""
        for sig in self.signal_list:
            # signal.signal() may return None for handlers not set from Python
            self.signal_handlers[sig] = (signal.signal(
                sig, self._signal_handler) or signal.SIG_DFL)

    def sleep(self, seconds):
        time.sleep(seconds)

    @property
    def manager(self):
        return settings.MANAGER

    @property
    def spec(self):
        return self.manager.get_spec(['command'] +
                                     self.get_full_name().split())

    @property
    def base_path(self):
        env = self.get_env()
        return os.path.join(settings.MODULE_BASE_PATH, env.name)

    @property
    def module_path(self):
        return "{}/{}".format(self.base_path, self.spec['_module'])

    def get_path(self, path):
        return os.path.join(self.module_path, path)

    def queue(self, msg, log=True):
        """Render *msg* and queue it on this command and all parent commands.

        :returns: the rendered message data.
        """
        def _queue_parents(command, data):
            if command.parent_messages:
                command.parent_messages.put(data)
            if command.parent_instance:
                _queue_parents(command.parent_instance, data)

        data = msg.render()
        logger.debug("Adding command queue message: {}".format(data))

        self.messages.put(data)
        _queue_parents(self, data)
        return data

    def flush(self):
        """Queue the ``None`` sentinel that ends ``get_messages`` iteration."""
        logger.debug("Flushing command queue")
        self.messages.put(None)

    def _active_user_name(self):
        """Return the active user's name, or ``None`` when no user is active."""
        user = self.active_user
        return user.name if user else None

    def create_message(self, data, decrypt=True):
        """Build an application message from serialized *data*."""
        # Guarded like every other message helper: no active user -> None
        return messages.AppMessage.get(data,
                                       decrypt=decrypt,
                                       user=self._active_user_name())

    def get_messages(self, flush=True):
        """Drain queued messages up to the ``None`` sentinel.

        With *flush* the sentinel is queued first. Without it, the call
        blocks until someone else calls ``flush()``.
        """
        if flush:
            self.flush()

        return list(iter(self.messages.get, None))

    def add_schema_field(self, name, field, optional=True, tags=None):
        if tags is None:
            tags = []

        self.schema[name] = Field(name=name,
                                  location='form',
                                  required=not optional,
                                  schema=field_to_schema(field),
                                  type=type(field).__name__.lower(),
                                  tags=tags)

    def get_schema(self):
        return schema.CommandSchema(
            list(self.schema.values()),
            re.sub(r'\s+', ' ', self.get_description(False)))

    def create_parser(self):
        """Build the argparse parser, with errors routed to help + exit(1)."""
        def display_error(message):
            self.warning(message + "\n")
            self.print_help()
            self.exit(1)

        epilog = self.get_epilog()
        if epilog:
            epilog = "\n".join(wrap_page(epilog))

        parser = CommandParser(prog=self.command_color('{} {}'.format(
            settings.APP_NAME, self.get_full_name())),
                               description="\n".join(
                                   wrap_page(self.get_description(False),
                                             init_indent=' ',
                                             init_style=self.header_color(),
                                             indent='  ')),
                               epilog=epilog,
                               formatter_class=argparse.RawTextHelpFormatter,
                               called_from_command_line=True)
        parser.error = display_error

        self._user._ensure(self)
        self.add_arguments(parser)
        return parser

    def add_arguments(self, parser):
        self.parser = parser
        self.parse_base()

    def parse(self):
        # Override in subclass
        pass

    def parse_base(self, addons=None):
        """Register the shared options, then *addons* and ``parse()``.

        Skipped entirely when ``parse_passthrough()`` is true.
        """
        self.option_map = {}

        if not self.parse_passthrough():
            # Display
            self.parse_verbosity()
            self.parse_debug()
            self.parse_display_width()
            self.parse_no_color()

            if not settings.API_EXEC:
                # Operations
                self.parse_version()

                if self.server_enabled():
                    self.parse_environment_host()

            # Operations
            self.parse_no_parallel()

            if addons and callable(addons):
                addons()

            self.parse()

    def parse_passthrough(self):
        return False

    def parse_environment_host(self):
        self.parse_variable('environment_host',
                            '--host',
                            str,
                            "environment host name",
                            value_label='NAME',
                            default=settings.DEFAULT_HOST_NAME,
                            tags=['system'])

    @property
    def environment_host(self):
        return self.options.get('environment_host', settings.DEFAULT_HOST_NAME)

    def parse_verbosity(self):
        self.parse_variable(
            'verbosity',
            '--verbosity',
            int,
            "verbosity level; 0=silent, 1=minimal, 2=normal, 3=verbose",
            value_label='LEVEL',
            default=2,
            choices=(0, 1, 2, 3),
            tags=['display'])

    @property
    def verbosity(self):
        return self.options.get('verbosity', 2)

    def parse_version(self):
        self.parse_flag('version',
                        '--version',
                        "show environment runtime version information",
                        tags=['system'])

    def parse_display_width(self):
        columns = shutil.get_terminal_size(
            fallback=(settings.DISPLAY_WIDTH, 25)).columns
        self.parse_variable('display_width',
                            '--display-width',
                            int,
                            "CLI display width",
                            value_label='WIDTH',
                            default=columns,
                            tags=['display'])

    @property
    def display_width(self):
        return self.options.get('display_width', Runtime.width())

    def parse_no_color(self):
        self.parse_flag('no_color',
                        '--no-color',
                        "don't colorize the command output",
                        tags=['display'])

    @property
    def no_color(self):
        return self.options.get('no_color', not Runtime.color())

    def parse_debug(self):
        self.parse_flag('debug',
                        '--debug',
                        'run in debug mode with error tracebacks',
                        tags=['display'])

    @property
    def debug(self):
        return self.options.get('debug', Runtime.debug())

    def parse_no_parallel(self):
        self.parse_flag('no_parallel',
                        '--no-parallel',
                        'disable parallel processing',
                        tags=['system'])

    @property
    def no_parallel(self):
        return self.options.get('no_parallel', not Runtime.parallel())

    def interpolate_options(self):
        return True

    def server_enabled(self):
        return True

    def remote_exec(self):
        return self.server_enabled()

    def groups_allowed(self):
        return False

    def get_version(self):
        """Return the version string, read once from ``<app_dir>/VERSION``."""
        # Default needed: ``_version`` is not set in __init__, so the
        # two-argument getattr() raised AttributeError on first use.
        if not getattr(self, '_version', None):
            self._version = load_file(
                os.path.join(self.manager.app_dir, 'VERSION'))
        return self._version

    def get_priority(self):
        return 1

    def get_parent_name(self):
        if self.parent_instance and self.parent_instance.name != 'root':
            return self.parent_instance.get_full_name()
        return ''

    def get_full_name(self):
        return "{} {}".format(self.get_parent_name(), self.name).strip()

    def get_id(self):
        return ".".join(self.get_full_name().split(' '))

    def get_description(self, overview=False):
        return self.descriptions.get(self.get_full_name(), overview)

    def get_epilog(self):
        return None

    @property
    def active_user(self):
        return self._user.active_user if getattr(self, '_user', None) else None

    def check_execute(self, user=None):
        """Return whether *user* (default: active user) may run this command."""
        groups = self.groups_allowed()
        user = self.active_user if user is None else user

        if not user:
            return False
        if user.name == settings.ADMIN_USER:
            return True

        if not groups:
            return True

        return user.env_groups.filter(name__in=groups).exists()

    def check_access(self, instance, reset=False):
        return self.check_access_by_groups(instance,
                                           instance.access_groups(reset))

    def check_access_by_groups(self, instance, groups):
        """Return whether the active user holds a role allowed by *groups*.

        ``'public'`` grants everyone access. An empty *groups* or the admin
        user always passes. Otherwise the user needs one of *groups* (nested
        lists/tuples are flattened) or the admin role; a warning is emitted
        on denial.
        """
        user_groups = [Roles.admin]

        if 'public' in groups:
            return True
        elif self.active_user is None:
            return False

        if not groups or self.active_user.name == settings.ADMIN_USER:
            return True

        for group in groups:
            if isinstance(group, (list, tuple)):
                user_groups.extend(list(group))
            else:
                user_groups.append(group)

        if len(user_groups):
            if not self.active_user.env_groups.filter(
                    name__in=user_groups).exists():
                self.warning(
                    "Operation {} {} {} access requires at least one of the following roles in environment: {}"
                    .format(self.get_full_name(), instance.facade.name,
                            instance.name, ", ".join(user_groups)))
                return False

        return True

    def get_provider(self, type, name, *args, **options):
        """Instantiate plugin provider *name* for plugin *type*.

        *type* may be ``"plugin:subtype"``. A *name* of ``None``, ``'help'``
        or ``'base'`` selects the plugin's base provider. Unknown providers
        and construction failures are reported through ``self.error``.
        """
        type_components = type.split(':')
        type = type_components[0]
        subtype = type_components[1] if len(type_components) > 1 else None

        base_provider = self.manager.index.get_plugin_base(type)
        providers = self.manager.index.get_plugin_providers(type, True)

        if name is None or name in ('help', 'base'):
            provider_class = base_provider
        elif name in providers.keys():
            provider_class = providers[name]
        else:
            self.error("Plugin {} provider {} not supported".format(
                type, name))

        try:
            return provider_class(type, name, self, *args,
                                  **options).context(subtype, self.test)
        except Exception as e:
            self.error("Plugin {} provider {} error: {}".format(type, name, e))

    def print_help(self):
        parser = self.create_parser()
        self.info(parser.format_help())

    def message(self,
                msg,
                mutable=True,
                silent=False,
                log=True,
                verbosity=None):
        """Queue *msg* and, in CLI/service/debug contexts, display it.

        Mutable messages are dropped while ``self.mute`` is set. Nothing is
        queued for silent messages, or at verbosity 0 unless *msg* is an
        error.
        """
        if mutable and self.mute:
            return

        if verbosity is None:
            verbosity = self.verbosity

        if not silent and (verbosity > 0 or msg.is_error()):
            self.queue(msg, log=log)

            if settings.CLI_EXEC or settings.SERVICE_INIT or self.debug:
                display_options = {
                    'debug': self.debug,
                    'disable_color': self.no_color,
                    'width': self.display_width
                }
                if msg.is_error():
                    display_options['traceback'] = (verbosity > 1)

                msg.display(**display_options)

    def info(self, message, name=None, prefix=None, log=True):
        self.message(messages.InfoMessage(
            str(message),
            name=name,
            prefix=prefix,
            silent=False,
            user=self._active_user_name()),
                     log=log)

    def data(self,
             label,
             value,
             name=None,
             prefix=None,
             silent=False,
             log=True):
        self.message(messages.DataMessage(
            str(label),
            value,
            name=name,
            prefix=prefix,
            silent=silent,
            user=self._active_user_name()),
                     log=log)

    def silent_data(self, name, value, log=True):
        self.data(name, value, name=name, silent=True, log=log)

    def notice(self, message, name=None, prefix=None, log=True):
        self.message(messages.NoticeMessage(
            str(message),
            name=name,
            prefix=prefix,
            silent=False,
            user=self._active_user_name()),
                     log=log)

    def success(self, message, name=None, prefix=None, log=True):
        self.message(messages.SuccessMessage(
            str(message),
            name=name,
            prefix=prefix,
            silent=False,
            user=self._active_user_name()),
                     log=log)

    def warning(self, message, name=None, prefix=None, log=True):
        self.message(messages.WarningMessage(
            str(message),
            name=name,
            prefix=prefix,
            silent=False,
            user=self._active_user_name()),
                     mutable=False,
                     log=log)

    def error(self,
              message,
              name=None,
              prefix=None,
              terminate=True,
              traceback=None,
              error_cls=CommandError,
              silent=False):
        """Emit an error message and, if *terminate*, raise *error_cls*.

        When no *traceback* is given, the current one is captured with
        ``format_traceback()``.
        """
        msg = messages.ErrorMessage(
            str(message),
            traceback=traceback,
            name=name,
            prefix=prefix,
            silent=silent,
            user=self._active_user_name())
        if not traceback:
            msg.traceback = format_traceback()

        self.message(msg, mutable=False, silent=silent)
        if terminate:
            raise error_cls(str(message))

    def table(self,
              data,
              name=None,
              prefix=None,
              silent=False,
              row_labels=False,
              log=True):
        self.message(messages.TableMessage(
            data,
            name=name,
            prefix=prefix,
            silent=silent,
            row_labels=row_labels,
            user=self._active_user_name()),
                     log=log)

    def silent_table(self, name, data, log=True):
        self.table(data, name=name, silent=True, log=log)

    def confirmation(self, message=None):
        """Ask for a typed ``YES`` outside API mode unless ``self.force``.

        Returns True on confirmation; any other answer aborts via ``error``.
        """
        if not settings.API_EXEC and not self.force:
            if not message:
                message = self.confirmation_message

            confirmation = input("{} (type YES to confirm): ".format(message))

            if re.match(r'^[Yy][Ee][Ss]$', confirmation):
                return True

            self.error("User aborted", 'abort')

    def format_fields(self, data, process_func=None):
        """Coerce raw option values to the types declared in the schema.

        Empty values (``None`` / ``''``) become ``None``. Dict/list/boolean
        fields are JSON-decoded and integer/float fields are cast; anything
        else is passed through unchanged.
        """
        fields = self.get_schema().get_fields()
        params = {}

        for key, value in data.items():
            if process_func and callable(process_func):
                key, value = process_func(key, value)

            if value is not None and value != '':
                if key in fields:
                    field_type = fields[key].type

                    if field_type in ('dictfield', 'listfield'):
                        params[key] = load_json(value)
                    elif field_type == 'booleanfield':
                        params[key] = load_json(value.lower())
                    elif field_type == 'integerfield':
                        params[key] = int(value)
                    elif field_type == 'floatfield':
                        params[key] = float(value)

                if key not in params:
                    params[key] = value
            else:
                params[key] = None

        return params

    def run_list(self, items, callback):
        """Run *callback* over *items* via ``Parallel.list``.

        If the run is aborted, each thread error is reported (without
        terminating) and then ``ParallelError`` is raised.
        """
        results = Parallel.list(items,
                                callback,
                                disable_parallel=self.no_parallel)

        if results.aborted:
            for thread in results.errors:
                self.error(thread.error,
                           prefix="[ {} ]".format(thread.name),
                           traceback=thread.traceback,
                           terminate=False)
            raise ParallelError()

        return results

    def run_exclusive(self,
                      lock_id,
                      callback,
                      error_on_locked=False,
                      timeout=600,
                      interval=1,
                      run_once=False):
        """Run *callback* while holding mutex *lock_id*, retrying until *timeout*.

        Without *lock_id* the callback runs directly. With *run_once*, state
        ``lock_<lock_id>`` records a completed run and later calls skip the
        callback. A held lock raises through ``error`` when
        *error_on_locked*. A ``MutexTimeoutError`` is logged and ends the
        loop.
        """
        if not lock_id:
            callback()
        else:
            state_id = "lock_{}".format(lock_id)
            start_time = time.time()
            current_time = start_time

            while (current_time - start_time) <= timeout:
                try:
                    if run_once and self.get_state(state_id, None):
                        break

                    with check_mutex(lock_id):
                        callback()
                        if run_once:
                            self.set_state(state_id, current_time)
                        break

                except MutexError:
                    if error_on_locked:
                        self.error(
                            "Could not obtain lock for {}".format(lock_id))
                    if timeout == 0:
                        break

                except MutexTimeoutError:
                    logger.warning(
                        "Task {} completed but the lock timed out".format(
                            lock_id))
                    break

                self.sleep(interval)
                current_time = time.time()

    def get_profiler_path(self, name):
        base_path = os.path.join(settings.PROFILER_PATH, self.curr_env_name)
        pathlib.Path(base_path).mkdir(parents=True, exist_ok=True)
        return os.path.join(base_path,
                            "{}.{}.profile".format(self.get_id(), name))

    def start_profiler(self, name, check=True):
        if settings.COMMAND_PROFILE and settings.CLI_EXEC and check:
            if name not in self.profilers:
                self.profilers[name] = cProfile.Profile()
            self.profilers[name].enable()

    def stop_profiler(self, name, check=True):
        if settings.COMMAND_PROFILE and settings.CLI_EXEC and check:
            profiler = self.profilers.get(name)
            # Tolerate a stop without a matching start instead of KeyError
            if profiler is not None:
                profiler.disable()

    def export_profiler_data(self):
        """Dump every collected profile to ``get_profiler_path(name)``."""
        if settings.COMMAND_PROFILE and settings.CLI_EXEC:
            for name, profiler in self.profilers.items():
                profiler.dump_stats(self.get_profiler_path(name))

    def ensure_resources(self, reinit=False):
        for facade_index_name in sorted(self.facade_index.keys()):
            if facade_index_name not in ['00_user']:
                self.facade_index[facade_index_name]._ensure(self,
                                                             reinit=reinit)

    def set_option_defaults(self):
        self.parse_base()

        for key, value in self.option_defaults.items():
            self.options.add(key, value)

    def validate_options(self, options):
        """Report (via ``error``) any *options* keys that are not registered."""
        allowed_options = list(self.option_map.keys())
        not_found = [key for key in options if key not in allowed_options]

        if not_found:
            self.error(
                "Requested command options not found: {}\n\nAvailable options: {}"
                .format(", ".join(not_found), ", ".join(allowed_options)))

    def set_options(self, options, primary=False):
        """Replace current options with *options*.

        Defaults are applied and unknown keys rejected unless this is the
        primary (CLI) invocation outside API mode. Note: pops
        ``environment_host`` from the caller's dict.
        """
        self.options.clear()

        if not primary or settings.API_EXEC:
            self.set_option_defaults()
            self.validate_options(options)

        host = options.pop('environment_host', None)
        if host:
            self.options.add('environment_host', host, False)

        for key, value in options.items():
            self.options.add(key, value)

    def bootstrap_ensure(self):
        return False

    def initialize_services(self):
        return True

    def bootstrap(self, options):
        """Initialize runtime flags, environment, options and services."""
        Cipher.initialize()

        if options.get('debug', False):
            Runtime.debug(True)

        if options.get('no_parallel', False):
            Runtime.parallel(False)

        if options.get('no_color', False):
            Runtime.color(False)

        if options.get('display_width', False):
            Runtime.width(options.get('display_width'))

        self.init_environment()

        if self.bootstrap_ensure() and settings.CLI_EXEC:
            self._user._ensure(self)

        self.set_options(options, True)

        if self.bootstrap_ensure() and settings.CLI_EXEC:
            self.ensure_resources()

        if self.initialize_services():
            self.manager.initialize_services(settings.STARTUP_SERVICES)
        return self

    def handle(self, options, primary=False):
        # Override in subclass
        pass

    def run_from_argv(self, argv):
        """CLI entry point: parse *argv*, bootstrap and ``handle``.

        Database connections are closed and profiles exported afterwards,
        even on failure.
        """
        parser = self.create_parser()
        args = argv[(len(self.get_full_name().split(' ')) + 1):]

        if not self.parse_passthrough():
            if '--version' in argv:
                return self.manager.index.find_command(
                    'version', main=True).run_from_argv([])

            elif '-h' in argv or '--help' in argv:
                return self.print_help()

            options = vars(parser.parse_args(args))
        else:
            options = {'args': args}

        try:
            self.bootstrap(options)
            self.handle(options, primary=True)
        finally:
            try:
                connections.close_all()
            except ImproperlyConfigured:
                pass

            self.export_profiler_data()
示例#15
0
文件: schedule.py 项目: zimagi/zimagi
class ScheduleMixin(CommandMixin('schedule')):
    """Command mixin for scheduled and background (queued) execution.

    It provides four groups of helpers:

    - converting schedule strings into interval, clocked (datetime) or
      crontab schedule records;
    - registering periodic tasks that re-run this command;
    - pushing the command onto the worker queue;
    - publishing task messages, exit status and abort signals through the
      command manager.
    """

    def get_schedule_from_representation(self, representation):
        """Resolve ``representation`` to a stored schedule record.

        Formats are tried in order: interval (``<n>D|H|M|S``), datetime
        (``YYYY-MM-DD HH:MM:SS``), then a five-field crontab string.
        When none matches, ``self.error`` is called with a format message.

        :param representation: schedule string supplied by the user
        :returns: the matched schedule record, or whatever falsy value was
            found if no format matched and ``self.error`` returned
        """
        schedule = self.get_interval_schedule(representation)

        if not schedule:
            schedule = self.get_datetime_schedule(representation)
        if not schedule:
            schedule = self.get_crontab_schedule(representation)

        if not schedule:
            self.error(
                "'{}' is not a valid schedule format.  See --help for more information"
                .format(representation))

        return schedule

    def normalize_schedule_time(self, time_string):
        """Parse a ``YYYY-MM-DD HH:MM:SS`` string into a timezone-aware datetime.

        :raises ValueError: if ``time_string`` does not match the format or
            is not a valid calendar date/time
        """
        return make_aware(datetime.strptime(time_string, "%Y-%m-%d %H:%M:%S"))

    def set_periodic_task(self):
        """Register this command as a periodic task if a schedule was given.

        The stored task runs ``zimagi.command.exec``. Its only positional
        argument is the command's full name, and its keyword arguments are
        the exported options plus ``_user``. Truthy ``schedule_begin`` and
        ``schedule_end`` values become the task's ``start_time`` and
        ``expires``.

        :returns: True if a periodic task was stored, False when no schedule
            is set
        """
        schedule = self.schedule
        if not schedule:
            return False

        begin = self.schedule_begin
        end = self.schedule_end

        # Task field that references the schedule, keyed by the facade that
        # produced the schedule record.
        schedule_map = {
            'task_interval': 'interval',
            'task_crontab': 'crontab',
            'task_datetime': 'clocked'
        }
        options = self.options.export()
        options['_user'] = self.active_user.name
        task = {
            schedule_map[schedule.facade.name]: schedule,
            'task': 'zimagi.command.exec',
            'user': self.active_user,
            'args': dump_json([self.get_full_name()]),
            'kwargs': dump_json(options)
        }
        if begin:
            task['start_time'] = begin
        if end:
            task['expires'] = end

        self._scheduled_task.store(self.get_schedule_name(), **task)

        self.success(
            "Task '{}' has been scheduled to execute periodically".format(
                self.get_full_name()))
        return True

    def set_queue_task(self, log_key):
        """Push this command onto the background worker queue.

        Only acts when ``self.background_process`` is set. The steps are:

        1. Extend the exported options with ``_user`` and ``_log_key``.
        2. When ``self.worker_type`` is unset, add ``worker_type`` from the
           command spec (default ``'default'``).
        3. Mark the log entry queued and send the task to the queue named
           by ``worker_type``.
        4. Unless ``self.async_exec`` is set, follow the task's messages via
           ``self.manager.follow_task``.

        :param log_key: log entry key the queued task reports to
        :returns: False when not running as a background process. Otherwise
            True, except when following the task and it ended with the
            failed status, in which case False.
        """
        def follow_progress(verbosity):
            def follow(data):
                self.message(self.create_message(data, decrypt=False),
                             verbosity=verbosity,
                             log=False)

            status = self.manager.follow_task(log_key, follow)
            return status != self._log.model.STATUS_FAILED

        if not self.background_process:
            return False

        options = self.options.export()
        options['_user'] = self.active_user.name
        options['_log_key'] = log_key

        if not self.worker_type:
            options['worker_type'] = self.spec.get('worker_type', 'default')
        try:
            self.log_status(self._log.model.STATUS_QUEUED)
            exec_command.apply_async(args=[self.get_full_name()],
                                     kwargs=options,
                                     queue=options['worker_type'])
        except OperationalError as error:
            self.error(
                "Connection to scheduling queue could not be made.  Check service and try again: {}"
                .format(error))

        if not self.async_exec:
            return follow_progress(options.get('verbosity', None))

        self.success(
            "Task '{}' has been pushed to the queue to execute in the background: {}"
            .format(self.get_full_name(), options))
        return True

    def wait_for_tasks(self, log_keys):
        """Delegate to ``self.manager.wait_for_tasks`` for the given log keys."""
        self.manager.wait_for_tasks(log_keys)

    def publish_message(self, data, include=True):
        """Publish ``data`` on the task channels of this command and its parents.

        Walks the ``exec_parent`` chain. Every command in the chain that has
        a ``log_entry`` gets the message, except this command itself when
        ``include`` is False.
        """
        def _publish_message(command, data, _include):
            if getattr(command, 'log_entry', None) and _include:
                self.manager.publish_task_message(command.log_entry.name, data)

            if command.exec_parent:
                _publish_message(command.exec_parent, data, True)

        _publish_message(self, data, include)

    def publish_exit(self):
        """Publish this command's status as the task exit, if results are logged."""
        if self.log_result and getattr(self, 'log_entry', None):
            self.manager.publish_task_exit(self.log_entry.name,
                                           self.get_status())

    def check_abort(self):
        """Return the manager's abort check for this task, or None if not logged."""
        if self.log_result and getattr(self, 'log_entry', None):
            return self.manager.check_task_abort(self.log_entry.name)
        return None

    def publish_abort(self, log_key):
        """Ask the manager to abort the task identified by ``log_key``."""
        if self.log_result:
            self.manager.publish_task_abort(log_key)

    def get_schedule_name(self, suffix_length=8):
        """Return a name for a new periodic task.

        Format: ``<command-name-with-dashes>:<YYYYmmddHHMMSS><suffix>``, where
        the suffix is ``suffix_length`` random lowercase letters. The suffix
        used to be a single letter, allowing only 26 distinct names per
        command per second, so schedules created in quick succession could
        receive the same name.
        """
        rng = random.SystemRandom()
        suffix = ''.join(
            rng.choice(string.ascii_lowercase) for _ in range(suffix_length))
        return "{}:{}{}".format(
            self.get_full_name().replace(' ', '-'),
            datetime.now().strftime("%Y%m%d%H%M%S"),
            suffix)

    def get_interval_schedule(self, representation):
        """Store and return an interval schedule for ``<n>D|H|M|S`` strings.

        The unit letter is case-insensitive.

        :returns: the schedule record, or None when ``representation`` is
            not an interval or the interval count is zero
        """
        interval = self._task_interval.model
        period_map = {
            'D': interval.DAYS,
            'H': interval.HOURS,
            'M': interval.MINUTES,
            'S': interval.SECONDS
        }

        match = re.match(r'^(\d+)([DHMS])$',
                         representation,
                         flags=re.IGNORECASE)
        if not match:
            return None

        every = int(match.group(1))
        if every <= 0:
            # A zero-length interval cannot repeat meaningfully; report it as
            # an invalid format rather than storing it.
            return None

        schedule, _ = self._task_interval.store(
            representation,
            every=every,
            period=period_map[match.group(2).upper()],
        )
        return schedule

    def get_crontab_schedule(self, representation):
        """Store and return a crontab schedule for five space-separated fields.

        :returns: the schedule record, or None when ``representation`` is
            not a five-field crontab string
        """
        match = re.match(
            r'^([\*\d\-\/\,]+) ([\*\d\-\/\,]+) ([\*\d\-\/\,]+) ([\*\d\-\/\,]+) ([\*\d\-\/\,]+)$',
            representation)
        if not match:
            return None

        # NOTE(review): fields 3-5 are mapped as day_of_week, day_of_month,
        # month_of_year (apparently the CrontabSchedule model field order).
        # Standard crontab syntax is "minute hour day-of-month month
        # day-of-week". Confirm this matches the documented --schedule
        # format before relying on it.
        schedule, _ = self._task_crontab.store(
            representation,
            minute=match.group(1),
            hour=match.group(2),
            day_of_week=match.group(3),
            day_of_month=match.group(4),
            month_of_year=match.group(5))
        return schedule

    def get_datetime_schedule(self, representation):
        """Store and return a one-off clocked schedule for ``YYYY-MM-DD HH:MM:SS``.

        :returns: the schedule record, or None when ``representation`` is
            not in that shape or is not a valid date/time
        """
        match = re.match(r'^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}$',
                         representation)
        if not match:
            return None

        try:
            clocked_time = self.normalize_schedule_time(representation)
        except ValueError:
            # Right shape but not a real date/time (e.g. month 13); treat it
            # as non-matching so the caller reports a clean format error
            # instead of crashing with a traceback.
            return None

        schedule, _ = self._task_datetime.store(
            representation,
            clocked_time=clocked_time,
        )
        return schedule
示例#16
0
class EnvironmentMixin(CommandMixin('environment')):

    @property
    def curr_env_name(self):
        """Name of the active environment, as reported by ``Environment.get_active_env()``."""
        active_env_name = Environment.get_active_env()
        return active_env_name


    def init_environment(self):
        """Save the current environment when it is missing or ``self.data`` is falsy.

        In that case ``save_env`` is called with ``runtime_image = None``.
        Otherwise nothing happens.
        """
        env = self.get_env()
        if env and self.data:
            return
        self.save_env(runtime_image = None)


    def environment_field_help(self):
        """Return help lines describing the configurable environment fields.

        The list starts with a blank line. The ``repo`` and ``base_image``
        entries show the defaults from ``Environment.get_env_defaults()``;
        ``runtime_image`` is always shown as autogenerated.
        """
        defaults = Environment.get_env_defaults()
        field_lines = (
            ("{} - environment runtime repository <{}>", "repo", defaults["repo"]),
            ("{} - environment base image <{}>", "base_image", defaults["base_image"]),
            ("{} - environment runtime image <{}>", "runtime_image", "autogenerated"),
        )
        return [""] + [
            template.format(self.key_color(field), self.value_color(value))
            for template, field, value in field_lines
        ]


    def get_all_env(self):
        """Return every known environment via ``Environment.get_all_env()``."""
        environments = Environment.get_all_env()
        return environments

    def get_env(self, name = None):
        """Return the environment ``name`` via ``Environment.get_env`` (None is passed through)."""
        environment = Environment.get_env(name)
        return environment

    def set_env(self, name = None, **fields):
        """Save ``fields`` to an environment, activate it and save its variables.

        :param name: environment name; defaults to the active environment
        :param fields: environment fields to store
        """
        env_name = name if name is not None else self.curr_env_name
        Environment.save_env(env_name, **fields)
        Environment.set_active_env(env_name)
        Environment.save_env_vars(env_name)
        self.success("Successfully set environment {}".format(env_name))

    def save_env(self, name = None, **fields):
        """Store ``fields`` on an environment without changing the active one.

        :param name: environment name; defaults to the active environment
        :param fields: environment fields to store
        """
        env_name = name if name is not None else self.curr_env_name
        Environment.save_env(env_name, **fields)
        self.success("Successfully updated environment {}".format(env_name))

    def delete_env(self, name = None, remove_module_path = False):
        """Remove an environment and re-save environment variables.

        :param name: environment name; defaults to the active environment
        :param remove_module_path: forwarded to ``Environment.delete_env``

        Note that ``Environment.save_env_vars()`` is called without a name
        here, unlike in ``set_env``.
        """
        env_name = name if name is not None else self.curr_env_name
        Environment.delete_env(env_name, remove_module_path = remove_module_path)
        Environment.save_env_vars()
        self.success("Successfully removed environment {}".format(env_name))


    def get_host(self, name = None):
        """Look up a host record without requiring it to exist.

        :param name: host name; any falsy value falls back to
            ``self.environment_host``
        :returns: whatever ``get_instance(..., required = False)`` returns
        """
        host_name = name or self.environment_host
        return self.get_instance(self._host, host_name, required = False)

    def create_host(self, **fields):
        """Create a host via the host facade and assign the given attributes.

        The ``name`` field (default ``'temp'``) is passed to ``create``; the
        remaining fields are set as attributes. ``save()`` is not called.
        """
        host = self._host.create(fields.pop('name', 'temp'))
        for attribute, value in fields.items():
            setattr(host, attribute, value)
        return host

    def save_host(self, **fields):
        """Update or create a host with ``fields`` and save it.

        The ``name`` field defaults to ``self.environment_host``. An existing
        host gets the remaining fields assigned as attributes; otherwise a
        new host is built with ``create_host``.

        :returns: the saved host
        """
        name = fields.pop('name', self.environment_host)
        host = self.get_host(name)

        if host:
            for attribute, value in fields.items():
                setattr(host, attribute, value)
        else:
            host = self.create_host(name = name, **fields)

        host.save()
        return host


    def get_state(self, name, default = None):
        """Return the stored ``value`` of state ``name``, or ``default`` if absent."""
        state_record = self.get_instance(self._state, name, required = False)
        return state_record.value if state_record else default

    def set_state(self, name, value = None):
        """Store ``value`` under state key ``name`` via the state facade."""
        state_facade = self._state
        state_facade.store(name, value = value)

    def delete_state(self, name = None, default = None):
        """Delete state ``name`` and return the value it held (or ``default``).

        NOTE(review): ``name`` defaults to None, which is passed straight to
        the facade's ``delete``; confirm that call is intended.
        """
        previous_value = self.get_state(name, default)
        self._state.delete(name)
        return previous_value

    def clear_state(self):
        """Clear all stored state through the state facade's ``clear()``."""
        state_facade = self._state
        state_facade.clear()