Esempio n. 1
0
class CLI(object):
    def __init__(self, sm):
        """
        CLI instantiator

        :param sm: SessionManager instance to create a CLI from
        :type sm: SessionManager
        """
        self._session_manager = sm

    def _check_for_session(self):
        """Verifies that there is a current session set"""
        if self._session_manager.current_session is None:
            raise Exception(
                "You must be logged in. Register an Aquarium session.")

    def _save(self):
        """Saves the environment"""
        self._session_manager.save()

    def session(self):
        """Returns the current session"""
        return self._session_manager.current_session

    def _sessions_json(self):
        """Returns a dictionary of sessions"""
        sess_dict = {}
        for session_name, session in self._session_manager.sessions.items():
            val = str(session)
            sess_dict[session_name] = val
        return sess_dict

    def _get_categories(self):
        """Returns dictionary of category names and Library/OperationType"""
        categories = {}
        operation_types = self.session.OperationType.all()
        libraries = self._session_manager.current_session.Library.all()
        for ot in operation_types:
            category_list = categories.get(ot.category, [])
            category_list.append(ot)
            categories[ot.category] = category_list
        for lib in libraries:
            category_list = categories.get(lib.category, [])
            category_list.append(lib)
            categories[lib.category] = category_list
        return categories

    def _load(self, path_to_metadata=None):
        """
        Load the environment, save it back and print it.

        :param path_to_metadata: accepted but not used by this method
        :return: the session manager
        :raises Exception: a generic error if any step fails
        """
        manager = self._session_manager
        try:
            manager.load()
            env_file = manager.metadata.env_settings.abspath
            logger.cli("Environment file loaded from \"{}\"".format(env_file))
            logger.cli("environment loaded")
            self._save()
            self.print_env()
            return self._session_manager
        except Exception:
            raise Exception("There was an error loading session_manager")

    def print_env(self):
        """Log the string form of the session manager."""
        summary = str(self._session_manager)
        logger.cli(summary)

    def push_all(self):
        """Save protocol"""
        self._check_for_session()
        current_env = self._session_manager.current_env
        categories = current_env.categories
        for cat in categories:
            self.push_category(cat)

    def push_category(self, category_name):
        """
        Push all :class:`OperationType` and :class:`Library` in a category.
        """
        current_env = self._session_manager.current_env
        category = current_env.get_category_dir(category_name)
        for protocol in category.list_dirs():
            self.push_one(category.name, protocol.name)

    def push_one(self, category_name, protocol_name, force=False):
        current_env = self._session_manager.current_env
        category = current_env.get_category_dir(category_name)
        protocol = current_env.get_protocol_dir(category_name, protocol_name)
        if protocol.has("source"):
            # then its a Library
            local_lib = current_env.read_library_type(category.name,
                                                      protocol.name)
            local_lib.code("source").update()
        if protocol.has("protocol"):
            # then its an OperationType
            local_ot = current_env.read_operation_type(category.name,
                                                       protocol.name)
            # TODO: unify accessing this list - used elsewhere
            for accessor in [
                    'protocol', 'precondition', 'documentation', 'cost_model'
            ]:
                code = getattr(local_ot, accessor)
                server_code = local_ot.code(accessor)

                diff_str = compare_content(server_code.content,
                                           code.content).strip()
                if diff_str != '':
                    # Local change, so fetch remote and compare
                    remote_ot = self.session.OperationType.where({
                        "category":
                        category.name,
                        "name":
                        protocol.name
                    })[0]
                    remote_code = getattr(remote_ot, accessor)
                    if code.id != remote_code.id and force is False:
                        msg = "Local version of {}/{} ({}) out of date. "
                        "Please fetch before pushing again"
                        logger.cli(
                            Fore.RED +
                            msg.format(category.name, protocol.name, accessor))
                    else:
                        logger.cli("++ Updating {}/{} ({})".format(
                            category.name, local_ot.name, accessor))
                        print(diff_str)
                        code.update()
                else:
                    logger.cli("-- No changes for {}/{} ({})".format(
                        category.name, local_ot.name, accessor))
            self._save()

    def _get_operation_types_from_sever(self, category):
        return self._session_manager.current_session.OperationType.where(
            {"category": category})

    def _get_library_types_from_server(self, category):
        return self._session_manager.current_session.Library.where(
            {"category": category})

    def fetch(self, category):
        """
        Pull the operation types and libraries of one category from the
        current session and write them to the current environment.

        :param category: category to query the server for
        :raises Exception: if there is no current session
        """
        self._check_for_session()
        operation_types = self._get_operation_types_from_sever(category)
        libraries = self._get_library_types_from_server(category)
        logger.cli("{} operation_types found".format(len(operation_types)))
        logger.cli("This may take awhile...")
        for operation_type in operation_types:
            logger.cli("Saving {}".format(operation_type.name))
            env = self._session_manager.current_env
            env.write_operation_type(operation_type)
        for library in libraries:
            logger.cli("Saving {}".format(library.name))
            env = self._session_manager.current_env
            env.write_library(library)
        self._save()

    def test(self, category_name, protocol_name, reset=False):
        """ Test a single protocol on an Aquarium Docker container """
        session = self._session_manager.current_session
        session_name = session.name

        try:
            # Start container
            self.start_container(reset)

            # Init container session
            logger.cli("Setting Docker session")
            self.register("neptune", "aquarium", "http://localhost:3001",
                          "docker")
            self.set_session("docker")

            # Copy OT from last session to container session
            logger.cli("Copying protocol from {} to docker session".format(
                session_name))
            ot = session.OperationType.find_by_name(protocol_name)
            self._copy_operation_type(session_name, "docker", ot)

            # Load data to container
            logger.cli("Loading test data into container")
            current_env = self._session_manager.current_env
            protocol = current_env.get_protocol_dir(category_name,
                                                    protocol_name)
            testing_data = docker_testing.load_data(protocol)

            # Push OT from container session to container
            logger.cli("Pushing protocol: {}".format(protocol_name))
            current_env.write_operation_type(testing_data['ot'], no_code=True)
            self.push_one(category_name, protocol_name, force=True)

            # Test protocol on container
            logger.cli("Testing protocol: {}".format(protocol_name))
            result = docker_testing.test_protocol(protocol, testing_data)

            # Remove container session
            logger.cli("Unregistering Docker session")
            self.unregister("docker")
            self.set_session(session_name)

            logger.cli("Plan success: {}".format(result['success']))
            logger.cli("View plan: {}".format(result['plan_url']))
        except Exception:
            self.unregister("docker")
            self.set_session(session_name)
            raise

    def start_container(self, reset=False):
        """
        Start an Aquarium Docker container and remember its id.

        :param reset: forwarded to ``docker_testing.start_container``
        """
        known_id = self._session_manager.get_container_id()
        outcome = docker_testing.start_container(reset, known_id)

        if not outcome['success']:
            logger.cli("Container already running: use --reset to restart")
        else:
            logger.cli("Container started!")
        self._session_manager.set_container_id(outcome['id'])

    def stop_container(self):
        """
        Stop the Aquarium Docker container recorded by the session manager.

        An empty-string container id is treated as "nothing running".
        """
        running_id = self._session_manager.get_container_id()
        if running_id == '':
            logger.cli("No container is currently running")
            return
        docker_testing.stop_container(running_id)
        self._session_manager.set_container_id('')
        logger.cli("Container killed")

    def _copy_operation_type(self, sess1_name, sess2_name, ot):
        """Copy Operation Type files from one session to another"""
        sess1 = self._session_manager.get(sess1_name)
        sess2 = self._session_manager.get(sess2_name)

        # Copy all files from current_session to docker
        rmtree(sess2.abspath)
        copytree(sess1.abspath, sess2.abspath)

        # Make this operation type visible to ODir
        sess2.get_operation_type_dir(ot.category, ot.name)

    @property
    def _sessions(self):
        return self._session_manager.sessions

    @property
    def sessions(self):
        """
        Log the available sessions and return them.

        Logs a registration hint when there are none, otherwise the formatted
        name-to-string mapping of the sessions.
        """
        registered = self._sessions
        if len(registered) > 0:
            logger.cli(format_json(self._sessions_json()))
        else:
            logger.cli("There are no sessions. "
                       "Use 'pfish register' to register a session. "
                       "Use 'pfish register --h' for help.")
        return registered

    def protocols(self):
        """Log ``<category>/<protocol>`` for each directory in each category
        of the current environment."""
        current_env = self._session_manager.current_env
        for category_dir in current_env.categories:
            for protocol_dir in category_dir.dirs:
                logger.cli(category_dir.name + "/" + protocol_dir.name)

    def categories(self):
        """
        Log how many items each category contains.

        :raises Exception: if there is no current session
        """
        self._check_for_session()
        logger.cli("Getting category counts:")
        grouped = self._get_categories()
        counts = {}
        for category, members in grouped.items():
            counts[category] = len(members)
        logger.cli(format_json(counts))

    def set_session(self, session_name):
        """
        Set the session by name.
        Use "sessions" to find all available sessions.

        An unknown name is reported through ``logger.error`` and leaves the
        current session untouched.

        :param session_name: name of a registered session
        """
        sessions = self._session_manager.sessions
        if session_name not in sessions:
            msg = "Session \"{}\" not in available sessions ({})"
            logger.error(msg.format(session_name, ', '.join(sessions.keys())))
            # Stop here: falling through would still try to select (and then
            # save) the session that was just reported as unknown.
            return

        logger.cli("Setting session to \"{}\"".format(session_name))
        self._session_manager.set_current(session_name)
        self._save()

    def set_repo(self, path):
        """
        Point the CLI at a repo and load its environment.

        A new session manager is created from the directory part and the
        final component of ``path``, then saved and loaded.

        :param path: location of the repo
        """
        # os.path.split returns the same (dirname, basename) pair.
        parent_dir, folder_name = os.path.split(path)
        logger.cli("Setting repo to \"{}\"".format(path))
        self._session_manager = SessionManager(parent_dir, name=folder_name)

        logger.cli("Loading environment...")
        self._save()
        self._load()
        logger.cli("{} session(s) loaded".format(len(self._sessions)))

    def move_repo(self, path):
        """Moves the current repo to another location."""
        path = Path(path).absolute()
        if not path.is_dir():
            raise Exception("Path {} does not exist".format(str(path)))
        logger.cli("Moving repo location from {} to {}".format(
            self._session_manager.abspath, path))
        self._session_manager.move_repo(path)
        self._save()

    def reset(self):
        self._session_manager.metadata.rmdirs()
        self._session_manager.rmdirs()
        self._session_manager.save(force_new_key=True)

    def generate_encryption_key(self):
        """Generate a fresh Fernet key, log it once with a warning, and
        switch the repo over to it."""
        fresh_key = Fernet.generate_key().decode()
        logger.warning(
            "SAVE KEY IN A SECURE PLACE TO USE YOUR REPO ON ANOTHER COMPUTER. "
            "IT WILL NOT APPEAR AGAIN.")
        logger.warning("NEW KEY: {}".format(fresh_key))
        self.__update_encryption_key(fresh_key)

    def __update_encryption_key(self, new_key):
        """
        Have the session manager update to ``new_key``, then set it as the
        active key.

        :param new_key: key to switch to
        """
        logger.cli("UPDATING KEYS AND PASSWORD HASHES")
        manager = self._session_manager
        manager.update_encryption_key(new_key)
        self.set_encryption_key(new_key)

    def set_encryption_key(self, key):
        """
        Set the encryption key to use for the managed folder.

        The session manager is saved with the key, replaced by a freshly
        loaded ``SessionManager('')``, and saved again.

        :param key: encryption key to save with
        :return: the value of the :attr:`sessions` property
        """
        logger.cli("Encryption key set")
        self._session_manager.save(force_new_key=True, key=key)
        reloaded = SessionManager('').load()
        self._session_manager = reloaded
        self._save()
        logger.cli("Loaded sessions with new key")
        return self.sessions

    def register(self, login, password, aquarium_url, name):
        """
        Register a new session and make it the current one.

        :param login: aquarium login
        :type login: str
        :param password: aquarium password
        :type password: str
        :param aquarium_url: aquarium url
        :type aquarium_url: str
        :param name: name to give to the new session
        :type name: str
        :return: None
        :rtype: None
        :raises InvalidSchema: re-raised with a hint about the missing
            "http://" when registering fails with ``InvalidSchema``
        """
        manager = self._session_manager
        try:
            manager.register_session(login, password, aquarium_url, name)
        except InvalidSchema:
            hint = "Missing schema for {}. Did you forget the \"http://\"?"
            raise InvalidSchema(hint.format(aquarium_url))
        logger.cli("registering session: {}".format(name))
        self.set_session(name)
        self._save()

    def unregister(self, name):
        """
        Unregister a session by name, then save.

        A name that is not registered is only logged.

        :param name: name of session to unregister
        :type name: str
        :return: None
        :rtype: None
        """
        manager = self._session_manager
        if name not in manager.sessions:
            logger.cli("Session {} does not exist".format(name))
        else:
            described = str(manager.get_session(name))
            logger.cli("Unregistering {}: {}".format(name, described))
            manager.remove_session(name)
        self._save()

    def ls(self):
        """Log the session manager's path, then the output of its
        ``show()`` on a new line."""
        manager = self._session_manager
        logger.cli(str(manager.abspath))
        logger.cli('\n' + manager.show())

    def repo(self):
        """Log the location of the managed folder."""
        location = str(self._session_manager.abspath)
        logger.cli(location)

    def shell(self):
        """Open an interactive shell bound to this CLI."""
        logger.cli("Opening new shell")
        interactive = Shell(self)
        interactive.run()

    def __str__(self):
        return str(self._session_manager)
Esempio n. 2
0
class ParrotFish(object):
    """Contains command and methods for updating and pushing protocols from Aquarium. Automatically
    exposes commands to the command line interface (CLI) using Fire.
    """

    controllers = [(OperationType, "protocol"), (Library, "source")]

    def __init__(self, session_manager=None):
        """
        Create a ParrotFish bound to a session manager.

        :param session_manager: session manager to use; stored as given
            (``None`` by default)
        :type session_manager: SessionManager
        """
        self.session_manager = session_manager

    def check_for_session(self):
        """
        Ensure the session manager has a current session.

        :raises Exception: if ``session_manager.current`` is ``None``
        """
        active = self.session_manager.current
        if active is not None:
            return
        raise Exception(
            "You must be logged in. Register an Aquarium session.")

    def save(self):
        """Ask the session manager to save the environment."""
        manager = self.session_manager
        manager.save()

    def print_env(self):
        """Log the string form of the session manager."""
        summary = str(self.session_manager)
        logger.cli(summary)

    def load(self):
        """
        Load the environment and print it.

        A new ``SessionManager('.')`` is created when no session manager was
        given to the constructor or when its metadata does not exist;
        otherwise the manager is replaced by ``SessionManager.load()``.

        :return: the session manager now in use
        """
        manager = self.session_manager
        # The constructor defaults ``session_manager`` to None; treat that
        # like missing metadata instead of failing on ``None.metadata``.
        if manager is None or not manager.metadata.exists():
            self.session_manager = SessionManager('.')
        else:
            self.session_manager = SessionManager.load()
            logger.cli("Environment file loaded ({})".format(
                self.session_manager.metadata.abspath))
        logger.cli("environment loaded")
        self.print_env()
        return self.session_manager

    def push_category(self, category):
        """
        Push all :class:`OperationType` and :class:`Library` in a category.

        Directories that have "source" are handled as libraries; directories
        that have "protocol" are handled as operation types, and each of
        their four code accessors is updated. The environment is saved at the
        end.

        :param category: category directory providing ``name`` and
            ``list_dirs()``
        """
        env = self.session_manager.current_env
        code_accessors = ('protocol', 'precondition', 'documentation',
                          'cost_model')
        for protocol_dir in category.list_dirs():
            if protocol_dir.has("source"):
                # then its a Library
                library = env.read_library_type(category.name,
                                                protocol_dir.name)
                library.code("source").update()
            if protocol_dir.has("protocol"):
                # then its an OperationType
                operation_type = env.read_operation_type(category.name,
                                                         protocol_dir.name)
                for accessor in code_accessors:
                    logger.cli("Updating {}/{} ({})".format(
                        category.name, operation_type.name, accessor))
                    getattr(operation_type, accessor).update()
        self.save()

    def push_all(self):
        """
        Push every category of the current environment.

        :raises Exception: if there is no current session
        """
        self.check_for_session()
        env = self.session_manager.current_env
        for category in env.categories:
            self.push_category(category)

    #
    #
    # def fetch_parent_from_server(self, code_file):
    #     code_parent = code_file.code_parent
    #
    #     id_ = {"id": code_parent.id}
    #     aq_code_parents = code_parent.where_callback(code_parent.__class__.__name__, id_)
    #     if aq_code_parents is None or aq_code_parents == []:
    #         logger.warning("Could not find {} with {}".format(
    #             code_parent.__name__, id_))
    #     elif len(aq_code_parents) > 1:
    #         logger.warning("More than one {} found with {}".format(
    #             code_parent.__name__, id_))
    #     else:
    #         return aq_code_parents[0]
    #
    # @staticmethod
    # def get_controller_interface(model_name):
    #     return self.session_manager.current.model_interface(model_name)
    #
    #
    # def fetch_content(self, controller_instance):
    #     """
    #     Fetches code content from Aquarium model (e.g. ot.code('protocol')
    #
    #     :param code_container: Model that contains code
    #     :type code_container: OperationType or Library
    #     :return:
    #     :rtype:
    #     """
    #     for controller, accessor in self.controllers:
    #         if type(controller_instance) is controller:
    #             return controller_instance.code(accessor).content
    #
    #
    # def ok_to_push(self, code_file):
    #     # Check last modified
    #     # modified_at = code_file.abspath.stat().st_mtime
    #     # if code_file.created_at == modified_at:
    #     #     logger.verbose("File has not been modified")
    #     #     return False
    #     fetched_content = code_file.code.content
    #
    #     # Check local changes
    #     local_content = code_file.read('r')
    #     local_changes = compare_content(local_content, fetched_content)
    #     if not local_changes:
    #         logger.verbose("<{}/{}> there are no local changes.".format(
    #             code_file.code_parent.category,
    #             code_file.code_parent.name))
    #         return False
    #
    #     # Check server changes
    #     server_content = self.fetch_content(
    #         self.fetch_parent_from_server(code_file))
    #     server_changes = compare_content(local_content, server_content)
    #     if not server_changes:
    #         logger.verbose("<{}/{}> there are not any differences between local and server.".format(
    #             code_file.code_parent.category,
    #             code_file.code_parent.name))
    #         return False
    #     return True
    #
    #
    # def get_code_files(self):
    #     if not self.session_managerironment().session_dir().exists():
    #         logger.cli("There are no protocols in this repo.")
    #         return []
    #     files = self.session_managerironment().session_dir().files
    #     code_files = [f for f in files if hasattr(f, 'code_parent')]
    #     return code_files

    def session(self):
        """Return the session manager's current session."""
        manager = self.session_manager
        return manager.current

    def sessions_json(self):
        """Return a dict mapping each session name to ``str(session)``."""
        registered = self.session_manager.sessions
        return {name: str(sess) for name, sess in registered.items()}

    def get_categories(self):
        """
        Group every OperationType and Library of the current session by
        category.

        :return: dict mapping a category to a list holding the operation
            types first, then the libraries, found in that category
        """
        operation_types = self.session_manager.current.OperationType.all()
        libraries = self.session_manager.current.Library.all()
        grouped = {}
        for operation_type in operation_types:
            grouped.setdefault(operation_type.category, []).append(
                operation_type)
        for library in libraries:
            grouped.setdefault(library.category, []).append(library)
        return grouped

    def fetch(self, category):
        """
        Pull the operation types and libraries of one category from the
        current session and write them to the current environment.

        :param category: category to query the server for
        :raises Exception: if there is no current session
        """
        self.check_for_session()
        query = {"category": category}
        operation_types = self.session_manager.current.OperationType.where(
            query)
        # A separate but equal dict, as each query originally got its own.
        libraries = self.session_manager.current.Library.where(dict(query))
        logger.cli("{} operation_types found".format(len(operation_types)))
        logger.cli("This may take awhile...")
        for operation_type in operation_types:
            logger.cli("Saving {}".format(operation_type.name))
            env = self.session_manager.current_env
            env.write_operation_type(operation_type)
        for library in libraries:
            logger.cli("Saving {}".format(library.name))
            env = self.session_manager.current_env
            env.write_library(library)
        self.save()

    def sessions(self):
        """Log the formatted name-to-string mapping of the sessions and
        return the session manager's sessions."""
        registered = self.session_manager.sessions
        listing = format_json(self.sessions_json())
        logger.cli(listing)
        return registered

    # def state(self):
    #     """ Get the current environment state. """
    #     logger.cli(format_json({
    #         "session": "{}: {}".format(
    #             self.session_manager.session_name,
    #             str(self.session_manager.current)),
    #         "sessions": self.sessions_json(),
    #         "category": self.session_manager.category,
    #         "repo": str(self.session_manager.repo.abspath)
    #     }))
    #

    def protocols(self):
        """Log ``<category>/<protocol>`` for each directory inside each
        directory of the current environment's ``protocols``."""
        env = self.session_manager.current_env
        # The category list was previously bound to an unused local and then
        # looked up a second time for the loop; read it once and iterate it.
        cats = env.protocols.dirs
        for cat in cats:
            for code in cat.dirs:
                logger.cli(cat.name + "/" + code.name)

    #

    #
    #
    # def protocols(self):
    #     """ Get and count the number of protocols on the local machine """
    #     self.check_for_session()
    #     files = self.get_code_files()
    #     logger.cli(format_json(
    #         ['/'.join([f.code_parent.category, f.code_parent.name]) for f in files]))
    #
    #
    #
    # def category(self):
    #     """ Get the current category of the environment. """
    #     logger.cli(self.session_manager.category)
    #     return self.session_manager.category
    #
    #
    #
    # def set_category(self, category):
    #     """ Set the category of the environment. Set "all" to use all categories. Use "categories" to find all
    #     categories. """
    #     self.check_for_session()
    #     logger.cli("Setting category to \"{}\"".format(category))
    #     self.session_manager.category = category
    #     self.save()

    def categories(self):
        """
        Log how many items each category contains.

        :raises Exception: if there is no current session
        """
        self.check_for_session()
        logger.cli("Getting category counts:")
        grouped = self.get_categories()
        counts = {}
        for category, members in grouped.items():
            counts[category] = len(members)
        logger.cli(format_json(counts))

    def set_session(self, session_name):
        """
        Set the session by name. Use "sessions" to find all available
        sessions.

        An unknown name is reported through ``logger.error`` and leaves the
        current session untouched.

        :param session_name: name of a registered session
        """
        sessions = self.session_manager.sessions
        if session_name not in sessions:
            msg = "Session \"{}\" not in available sessions ({})"
            logger.error(msg.format(session_name, ', '.join(sessions.keys())))
            # Stop here: falling through would still try to select (and then
            # save) the session that was just reported as unknown.
            return

        logger.cli("Setting session to \"{}\"".format(session_name))
        self.session_manager.set_current(session_name)
        self.save()

    def set(self, session_name, category):
        """
        Sets the session and category

        :param session_name: name passed to :meth:`set_session`
        :param category: value passed to ``set_category``
        """
        self.set_session(session_name)
        # NOTE(review): the only ``set_category`` in the visible part of this
        # class is commented out, so this call looks like it would raise
        # AttributeError - confirm whether it is defined elsewhere.
        self.set_category(category)

    def move_repo(self, path):
        """
        Move the current repo to another location.

        :param path: destination; must be an existing directory
        :raises Exception: if ``path`` is not a directory
        """
        path = Path(path).absolute()
        destination_ok = path.is_dir()
        if not destination_ok:
            raise Exception("Path {} does not exist".format(str(path)))
        origin = self.session_manager.abspath
        logger.cli("Moving repo location from {} to {}".format(origin, path))
        self.session_manager.move_repo(path)
        self.save()

    def register(self, login, password, aquarium_url, session_name):
        """
        Register a new session and save.

        :param login: aquarium login
        :param password: aquarium password
        :param aquarium_url: aquarium url
        :param session_name: name to give to the new session
        :return: this instance
        :raises InvalidSchema: re-raised with a hint about the missing
            "http://" when registering fails with ``InvalidSchema``
        """
        manager = self.session_manager
        try:
            manager.register_session(login, password, aquarium_url,
                                     session_name)
        except InvalidSchema:
            hint = "Missing schema for {}. Did you forget the \"http://\"?"
            raise InvalidSchema(hint.format(aquarium_url))
        logger.cli("registering session: {}".format(session_name))
        self.save()
        return self

    def unregister(self, name):
        """
        Unregister a session by name, then save.

        A name that is not registered is only logged.

        :param name: name of session to unregister
        """
        manager = self.session_manager
        if name not in manager.sessions:
            logger.cli("Session {} does not exist".format(name))
        else:
            described = str(manager.get_session(name))
            logger.cli("Unregistering {}: {}".format(name, described))
            manager.remove_session(name)
        self.save()

    #
    # def clear_history(self):
    #     if env_data.exists():
    #         os.remove(env_data.abspath)
    #     logger.cli("Cleared history")

    def ls(self):
        """Log the session manager's path, then the output of its
        ``show()`` on a new line."""
        manager = self.session_manager
        logger.cli(str(manager.abspath))
        logger.cli('\n' + manager.show())