Example #1
0
    def __init__(self, tmpdir):
        # type: (LocalPath) -> None
        """Build a test environment whose database lives under tmpdir.

        Creates the settings, an empty plugin proxy, the database schema, one session, a group
        graph, and the repository/service/usecase factories layered on top of them.
        """
        settings = Settings()
        settings.database = db_url(tmpdir)
        self.settings = settings

        plugins = PluginProxy([])
        self.plugins = plugins

        # A previous test may have installed real plugins in the global proxy, so replace it with
        # this empty one.  Drop this once the proxy is injected everywhere instead of being global.
        set_global_plugin_proxy(plugins)

        self.initialize_database()

        session = SessionFactory(settings).create_session()
        self.session = session
        self.graph = GroupGraph()

        # Both repository factories receive the same wrapper around this one session.
        shared_sessions = SingletonSessionFactory(session)
        self.repository_factory = GraphRepositoryFactory(
            settings, plugins, shared_sessions, self.graph
        )
        self.sql_repository_factory = SQLRepositoryFactory(settings, plugins, shared_sessions)
        self.service_factory = ServiceFactory(settings, plugins, self.repository_factory)
        self.usecase_factory = UseCaseFactory(settings, plugins, self.service_factory)
        self._transaction_service = self.service_factory.create_transaction_service()
Example #2
0
 def open_database(self) -> None:
     """Open a new session and rebuild every factory that depends on it.

     Reuses the existing self.settings, self.plugins and self.graph.
     """
     settings = self.settings
     plugins = self.plugins
     session = SessionFactory(settings).create_session()
     self.session = session
     # Both repository factories receive the same wrapper around the new session.
     shared_sessions = SingletonSessionFactory(session)
     self.repository_factory = GraphRepositoryFactory(
         settings, plugins, shared_sessions, self.graph
     )
     self.sql_repository_factory = SQLRepositoryFactory(settings, plugins, shared_sessions)
     self.service_factory = ServiceFactory(settings, plugins, self.repository_factory)
     self.usecase_factory = UseCaseFactory(settings, plugins, self.service_factory)
     self._transaction_service = self.service_factory.create_transaction_service()
Example #3
0
def create_sql_usecase_factory(settings, plugins, session_factory=None):
    # type: (Settings, PluginProxy, Optional[SessionFactory]) -> UseCaseFactory
    """Create a SQL-backed UseCaseFactory, with optional injection of a SessionFactory.

    Args:
        settings: Settings passed to every factory in the chain.
        plugins: Plugin proxy passed to every factory in the chain.
        session_factory: Factory the repositories use to obtain sessions.  Injection is supported
            primarily for tests.  If omitted (None), a SessionFactory is built from settings.

    Returns:
        A UseCaseFactory on top of a ServiceFactory on top of a SQLRepositoryFactory.
    """
    # Compare with None explicitly so an injected factory is never discarded just because it
    # happens to evaluate as falsy.
    if session_factory is None:
        session_factory = SessionFactory(settings)
    repository_factory = SQLRepositoryFactory(settings, plugins, session_factory)
    service_factory = ServiceFactory(settings, plugins, repository_factory)
    return UseCaseFactory(settings, plugins, service_factory)
Example #4
0
def create_graph_usecase_factory(
    settings,  # type: Settings
    plugins,  # type: PluginProxy
    session_factory=None,  # type: Optional[SessionFactory]
    graph=None,  # type: Optional[GroupGraph]
):
    # type: (...) -> UseCaseFactory
    """Create a graph-backed UseCaseFactory, with optional injection of a SessionFactory and graph.

    Injection of the session factory and graph is supported primarily for tests.

    Args:
        settings: Settings passed to every factory in the chain.
        plugins: Plugin proxy passed to every factory in the chain.
        session_factory: Factory the repositories use to obtain sessions.  If omitted (None), a
            SessionFactory is built from settings.
        graph: GroupGraph to back the graph repositories.  It is handed to
            GraphRepositoryFactory unchanged (None included).  Presumably that factory builds
            its own graph when given None; verify in GraphRepositoryFactory.

    Returns:
        A UseCaseFactory on top of a ServiceFactory on top of a GraphRepositoryFactory.
    """
    # Compare with None explicitly so an injected factory is never discarded just because it
    # happens to evaluate as falsy.
    if session_factory is None:
        session_factory = SessionFactory(settings)
    repository_factory = GraphRepositoryFactory(settings, plugins, session_factory, graph)
    service_factory = ServiceFactory(settings, plugins, repository_factory)
    return UseCaseFactory(settings, plugins, service_factory)
Example #5
0
    def __init__(self, tmpdir):
        # type: (LocalPath) -> None
        """Build a test environment whose database lives under tmpdir.

        Creates the settings, an empty plugin proxy, the database schema, one session, a group
        graph, and the repository/service/usecase factories layered on top of them.
        """
        settings = Settings()
        settings.database = db_url(tmpdir)
        self.settings = settings

        proxy = PluginProxy([])
        self.plugin_proxy = proxy

        # A previous test may have installed real plugins in the global proxy, so replace it with
        # this empty one.  Drop this once the proxy is injected everywhere instead of being global.
        set_global_plugin_proxy(proxy)

        self.initialize_database()

        session = SessionFactory(settings).create_session()
        self.session = session
        self.graph = GroupGraph()

        shared_sessions = SingletonSessionFactory(session)
        self.repository_factory = GraphRepositoryFactory(
            settings, proxy, shared_sessions, self.graph
        )
        self.service_factory = ServiceFactory(self.repository_factory)
        self.usecase_factory = UseCaseFactory(settings, self.service_factory)
        self._transaction_service = self.service_factory.create_transaction_service()
Example #6
0
def group_command(args: Namespace, settings: CtlSettings,
                  session_factory: SessionFactory) -> None:
    """Run a legacy ``group`` subcommand against the group named by ``args.groupname``.

    Logs an error and returns if no such group exists.  ``add_member`` and ``remove_member``
    are routed through ``ensure_valid_username``; ``log_dump`` runs directly.  Any other
    subcommand value does nothing.
    """
    session = session_factory.create_session()
    target = session.query(Group).filter_by(groupname=args.groupname).scalar()
    if not target:
        logging.error("No such group %s", args.groupname)
        return

    subcommand = args.subcommand
    if subcommand == "log_dump":
        logdump_group_command(session, target, args)
        return

    if subcommand not in ("add_member", "remove_member"):
        return

    # Somewhat hacky: a throwaway function is decorated so that ensure_valid_username is only
    # applied to the mutating subcommands.
    @ensure_valid_username
    def call_mutate(args: Namespace, settings: CtlSettings,
                    session_factory: SessionFactory) -> None:
        mutate_group_command(session, target, args)

    call_mutate(args, settings, session_factory)
Example #7
0
class SetupTest(object):
    """Set up the environment for a test.

    Most actions should be done inside of a transaction, created via the transaction() method and
    used as a context handler.  This will ensure that the test setup is committed to the database
    before the test starts running.

    Attributes:
        settings: Settings object for tests (only the database is configured)
        graph: Underlying graph.  It is refreshed from the database only at the end of a
            transaction() block, not after changes made outside of one.
        session: The underlying database session
        plugin_proxy: The plugin proxy used for the tests
        repository_factory: Factory for repository objects
        service_factory: Factory for service objects
        usecase_factory: Factory for usecase objects
    """

    def __init__(self, tmpdir):
        # type: (LocalPath) -> None
        """Create a fresh test database under tmpdir and build the factories on top of it."""
        self.settings = Settings()
        self.settings.database = db_url(tmpdir)
        self.plugin_proxy = PluginProxy([])

        # Reinitialize the global plugin proxy with an empty set of plugins in case a previous test
        # initialized plugins.  This can go away once a plugin proxy is injected into everything
        # that needs it instead of maintained as a global.
        set_global_plugin_proxy(self.plugin_proxy)

        self.initialize_database()
        self.session = SessionFactory(self.settings).create_session()
        self.graph = GroupGraph()
        self.repository_factory = GraphRepositoryFactory(
            self.settings, self.plugin_proxy, SingletonSessionFactory(self.session), self.graph
        )
        self.service_factory = ServiceFactory(self.repository_factory)
        self.usecase_factory = UseCaseFactory(self.settings, self.service_factory)
        self._transaction_service = self.service_factory.create_transaction_service()

    def initialize_database(self):
        # type: () -> None
        """Create the database schema, first dropping any existing one for persistent databases.

        A persistent test database is in use when MEROU_TEST_DATABASE is set in the environment.
        """
        schema_repository = SchemaRepository(self.settings)

        # If using a persistent database, clear the database first.
        if "MEROU_TEST_DATABASE" in os.environ:
            schema_repository.drop_schema()

        # Create the database schema.
        schema_repository.initialize_schema()

    def close(self):
        # type: () -> None
        """Close the database session used by this test setup."""
        self.session.close()

    @contextmanager
    def transaction(self):
        # type: () -> Iterator[None]
        """Run the body inside a transaction from the transaction service.

        After the transaction block has exited, self.graph is refreshed from the database.  If the
        body raises, the exception propagates before the refresh is reached.
        """
        with self._transaction_service.transaction():
            yield
        self.graph.update_from_db(self.session)

    def create_group(self, name, description="", join_policy=GroupJoinPolicy.CAN_ASK):
        # type: (str, str, GroupJoinPolicy) -> None
        """Create a group, does nothing if it already exists."""
        group_service = self.service_factory.create_group_service()
        if not group_service.group_exists(name):
            group_service.create_group(name, description, join_policy)

    def create_permission(
        self, name, description="", audited=False, enabled=True, created_on=None
    ):
        # type: (str, str, bool, bool, Optional[datetime]) -> None
        """Create a permission, does nothing if it already exists."""
        permission_repository = self.repository_factory.create_permission_repository()
        if not permission_repository.get_permission(name):
            permission_repository.create_permission(
                name, description, audited, enabled, created_on
            )

    def create_user(self, name):
        # type: (str) -> None
        """Create a user, does nothing if it already exists."""
        if User.get(self.session, name=name):
            return
        user = User(username=name)
        user.add(self.session)

    def add_group_to_group(self, member, group, expiration=None):
        # type: (str, str, Optional[datetime]) -> None
        """Add group ``member`` to ``group`` with the "member" role, creating either if needed."""
        self.create_group(member)
        self.create_group(group)
        member_obj = Group.get(self.session, name=member)
        assert member_obj
        group_obj = Group.get(self.session, name=group)
        assert group_obj
        edge = GroupEdge(
            group_id=group_obj.id,
            member_type=OBJ_TYPES["Group"],
            member_pk=member_obj.id,
            expiration=expiration,
            active=True,
            _role=GROUP_EDGE_ROLES.index("member"),
        )
        edge.add(self.session)

    def add_user_to_group(self, user, group, role="member", expiration=None):
        # type: (str, str, str, Optional[datetime]) -> None
        """Add ``user`` to ``group`` with the given role, creating the user and group if needed.

        ``role`` must be an entry of GROUP_EDGE_ROLES; its index is stored on the edge.
        """
        self.create_user(user)
        self.create_group(group)
        user_obj = User.get(self.session, name=user)
        assert user_obj
        group_obj = Group.get(self.session, name=group)
        assert group_obj
        edge = GroupEdge(
            group_id=group_obj.id,
            member_type=OBJ_TYPES["User"],
            member_pk=user_obj.id,
            expiration=expiration,
            active=True,
            _role=GROUP_EDGE_ROLES.index(role),
        )
        edge.add(self.session)

    def grant_permission_to_group(self, permission, argument, group):
        # type: (str, str, str) -> None
        """Grant ``permission`` with ``argument`` to ``group``, creating both if needed."""
        self.create_group(group)
        self.create_permission(permission)
        group_service = self.service_factory.create_group_service()
        group_service.grant_permission_to_group(permission, argument, group)

    def revoke_permission_from_group(self, permission, argument, group):
        # type: (str, str, str) -> None
        """Delete the grant of ``permission`` with ``argument`` to ``group``.

        The permission and group must both already exist.
        """
        permission_obj = Permission.get(self.session, name=permission)
        assert permission_obj
        group_obj = Group.get(self.session, name=group)
        assert group_obj
        self.session.query(PermissionMap).filter(
            PermissionMap.permission_id == permission_obj.id,
            PermissionMap.group_id == group_obj.id,
            PermissionMap.argument == argument,
        ).delete()

    def create_group_request(self, user, group, role="member"):
        # type: (str, str, str) -> None
        """Create a pending request from ``user`` to join ``group`` with ``role``."""
        self.create_user(user)
        self.create_group(group)

        user_obj = User.get(self.session, name=user)
        assert user_obj
        group_obj = Group.get(self.session, name=group)
        assert group_obj

        # Note: despite the function name, this only creates the request. The flow here is
        # convoluted enough that it seems best to preserve exact behavior for testing.
        group_obj.add_member(
            requester=user_obj, user_or_group=user_obj, reason="", status="pending", role=role
        )

    def create_service_account(self, service_account, owner, description="", machine_set=""):
        # type: (str, str, str, str) -> None
        """Create a service account owned by group ``owner``.

        The owner group is always created if needed.  Nothing else happens if a user named
        ``service_account`` already exists.
        """
        self.create_group(owner)
        group_obj = Group.get(self.session, name=owner)
        assert group_obj

        if User.get(self.session, name=service_account):
            return
        user = User(username=service_account)
        user.add(self.session)
        service_account_obj = ServiceAccount(
            user_id=user.id, description=description, machine_set=machine_set
        )
        service_account_obj.add(self.session)
        user.is_service_account = True

        # Flush before service_account_obj.id is used to build the ownership row.
        self.session.flush()
        owner_map = GroupServiceAccount(
            group_id=group_obj.id, service_account_id=service_account_obj.id
        )
        owner_map.add(self.session)

    def grant_permission_to_service_account(self, permission, argument, service_account):
        # type: (str, str, str) -> None
        """Grant ``permission`` with ``argument`` to an existing service account.

        The permission is created if needed; the service account must already exist.
        """
        self.create_permission(permission)
        permission_obj = Permission.get(self.session, name=permission)
        assert permission_obj
        user_obj = User.get(self.session, name=service_account)
        assert user_obj, "Must create the service account first"
        assert user_obj.is_service_account
        grant = ServiceAccountPermissionMap(
            permission_id=permission_obj.id,
            service_account_id=user_obj.service_account.id,
            argument=argument,
        )
        grant.add(self.session)

    def disable_user(self, user):
        # type: (str) -> None
        """Disable ``user`` through the user repository."""
        user_repository = self.repository_factory.create_user_repository()
        user_repository.disable_user(user)

    def disable_group(self, group):
        # type: (str) -> None
        """Mark an existing group as disabled."""
        group_obj = Group.get(self.session, name=group)
        assert group_obj
        group_obj.enabled = False

    def disable_service_account(self, service_account):
        # type: (str) -> None
        """Disable a service account.

        Disables its underlying user, deletes its ownership record, and deletes all of its
        permission grants.
        """
        service_obj = ServiceAccount.get(self.session, name=service_account)
        assert service_obj
        service_obj.user.enabled = False
        service_obj.owner.delete(self.session)
        permissions = self.session.query(ServiceAccountPermissionMap).filter_by(
            service_account_id=service_obj.id
        )
        for permission in permissions:
            permission.delete(self.session)

    def create_role_user(self, role_user, description="", join_policy=GroupJoinPolicy.CAN_ASK):
        # type: (str, str, GroupJoinPolicy) -> None
        """Create an old-style role user.

        This concept is obsolete and all code related to it will be deleted once all remaining
        legacy role users have been converted to service accounts.  This method should be used only
        for tests to maintain backward compatibility until that happens.
        """
        user = User(username=role_user, role_user=True)
        user.add(self.session)
        self.create_group(role_user, description, join_policy)
        self.add_user_to_group(role_user, role_user)
Example #8
0
class SetupTest(object):
    """Set up the environment for a test.

    Most actions should be done inside of a transaction, created via the transaction() method and
    used as a context handler.  This will ensure that the test setup is committed to the database
    before the test starts running.

    Attributes:
        settings: Settings object for tests (only the database is configured)
        graph: Underlying graph.  It is refreshed from the database only at the end of a
            transaction() block, not after changes made outside of one.
        session: The underlying database session
        plugins: The plugin proxy used for the tests
        repository_factory: Factory for repository objects
        sql_repository_factory: Factory that returns only SQL repository objects (no graph)
        service_factory: Factory for service objects
        usecase_factory: Factory for usecase objects
    """

    def __init__(self, tmpdir):
        # type: (LocalPath) -> None
        """Create a fresh test database under tmpdir and build the factories on top of it."""
        self.settings = Settings()
        self.settings.database = db_url(tmpdir)
        self.plugins = PluginProxy([])

        # Reinitialize the global plugin proxy with an empty set of plugins in case a previous test
        # initialized plugins.  This can go away once a plugin proxy is injected into everything
        # that needs it instead of maintained as a global.
        set_global_plugin_proxy(self.plugins)

        self.initialize_database()
        self.session = SessionFactory(self.settings).create_session()
        self.graph = GroupGraph()
        # Both repository factories receive the same wrapper around this one session.
        session_factory = SingletonSessionFactory(self.session)
        self.repository_factory = GraphRepositoryFactory(
            self.settings, self.plugins, session_factory, self.graph
        )
        self.sql_repository_factory = SQLRepositoryFactory(
            self.settings, self.plugins, session_factory
        )
        self.service_factory = ServiceFactory(
            self.settings, self.plugins, self.repository_factory
        )
        self.usecase_factory = UseCaseFactory(self.settings, self.plugins, self.service_factory)
        self._transaction_service = self.service_factory.create_transaction_service()

    def initialize_database(self):
        # type: () -> None
        """Create the database schema, first dropping any existing one for persistent databases.

        A persistent test database is in use when MEROU_TEST_DATABASE is set in the environment.
        """
        schema_repository = SchemaRepository(self.settings)

        # If using a persistent database, clear the database first.
        if "MEROU_TEST_DATABASE" in os.environ:
            schema_repository.drop_schema()

        # Create the database schema.
        schema_repository.initialize_schema()

    def close(self):
        # type: () -> None
        """Close the database session used by this test setup."""
        self.session.close()

    @contextmanager
    def transaction(self):
        # type: () -> Iterator[None]
        """Run the body inside a transaction from the transaction service.

        After the transaction block has exited, self.graph is refreshed from the database.  If the
        body raises, the exception propagates before the refresh is reached.
        """
        with self._transaction_service.transaction():
            yield
        self.graph.update_from_db(self.session)

    def create_group(self, name, description="", join_policy=GroupJoinPolicy.CAN_ASK):
        # type: (str, str, GroupJoinPolicy) -> None
        """Create a group, does nothing if it already exists."""
        group_service = self.service_factory.create_group_service()
        if not group_service.group_exists(name):
            group_service.create_group(name, description, join_policy)

    def create_permission(
        self, name, description="", audited=False, enabled=True, created_on=None
    ):
        # type: (str, str, bool, bool, Optional[datetime]) -> None
        """Create a permission, does nothing if it already exists."""
        permission_repository = self.repository_factory.create_permission_repository()
        if not permission_repository.get_permission(name):
            permission_repository.create_permission(
                name, description, audited, enabled, created_on
            )

    def create_user(self, name):
        # type: (str) -> None
        """Create a user, does nothing if it already exists."""
        if User.get(self.session, name=name):
            return
        user = User(username=name)
        user.add(self.session)

    def add_group_to_group(self, member, group, expiration=None):
        # type: (str, str, Optional[datetime]) -> None
        """Add group ``member`` to ``group`` with the "member" role, creating either if needed."""
        self.create_group(member)
        self.create_group(group)
        member_obj = Group.get(self.session, name=member)
        assert member_obj
        group_obj = Group.get(self.session, name=group)
        assert group_obj
        edge = GroupEdge(
            group_id=group_obj.id,
            member_type=OBJ_TYPES["Group"],
            member_pk=member_obj.id,
            expiration=expiration,
            active=True,
            _role=GROUP_EDGE_ROLES.index("member"),
        )
        edge.add(self.session)

    def add_user_to_group(self, user, group, role="member", expiration=None):
        # type: (str, str, str, Optional[datetime]) -> None
        """Add ``user`` to ``group`` with the given role, creating the user and group if needed.

        ``role`` must be an entry of GROUP_EDGE_ROLES; its index is stored on the edge.
        """
        self.create_user(user)
        self.create_group(group)
        user_obj = User.get(self.session, name=user)
        assert user_obj
        group_obj = Group.get(self.session, name=group)
        assert group_obj
        edge = GroupEdge(
            group_id=group_obj.id,
            member_type=OBJ_TYPES["User"],
            member_pk=user_obj.id,
            expiration=expiration,
            active=True,
            _role=GROUP_EDGE_ROLES.index(role),
        )
        edge.add(self.session)

    def grant_permission_to_group(self, permission, argument, group):
        # type: (str, str, str) -> None
        """Grant ``permission`` with ``argument`` to ``group``, creating both if needed."""
        self.create_group(group)
        self.create_permission(permission)
        group_service = self.service_factory.create_group_service()
        group_service.grant_permission_to_group(permission, argument, group)

    def revoke_permission_from_group(self, permission, argument, group):
        # type: (str, str, str) -> None
        """Delete the grant of ``permission`` with ``argument`` to ``group``.

        The permission and group must both already exist.
        """
        permission_obj = Permission.get(self.session, name=permission)
        assert permission_obj
        group_obj = Group.get(self.session, name=group)
        assert group_obj
        self.session.query(PermissionMap).filter(
            PermissionMap.permission_id == permission_obj.id,
            PermissionMap.group_id == group_obj.id,
            PermissionMap.argument == argument,
        ).delete()

    def create_group_request(self, user, group, role="member"):
        # type: (str, str, str) -> None
        """Create a pending request from ``user`` to join ``group`` with ``role``."""
        self.create_user(user)
        self.create_group(group)

        user_obj = User.get(self.session, name=user)
        assert user_obj
        group_obj = Group.get(self.session, name=group)
        assert group_obj

        # Note: despite the function name, this only creates the request. The flow here is
        # convoluted enough that it seems best to preserve exact behavior for testing.
        group_obj.add_member(
            requester=user_obj, user_or_group=user_obj, reason="", status="pending", role=role
        )

    def create_service_account(self, service_account, owner, machine_set="", description=""):
        # type: (str, str, str, str) -> None
        """Create a service account owned by group ``owner``, creating the group if needed.

        The account itself is created through the service account repository.
        """
        self.create_group(owner)
        service_account_repository = self.repository_factory.create_service_account_repository()
        service_account_repository.create_service_account(
            service_account, owner, machine_set, description
        )

    def grant_permission_to_service_account(self, permission, argument, service_account):
        # type: (str, str, str) -> None
        """Grant ``permission`` with ``argument`` to an existing service account.

        The permission is created if needed; the service account must already exist.
        """
        self.create_permission(permission)
        permission_obj = Permission.get(self.session, name=permission)
        assert permission_obj
        user_obj = User.get(self.session, name=service_account)
        assert user_obj, "Must create the service account first"
        assert user_obj.is_service_account
        grant = ServiceAccountPermissionMap(
            permission_id=permission_obj.id,
            service_account_id=user_obj.service_account.id,
            argument=argument,
        )
        grant.add(self.session)

    def add_metadata_to_user(self, key, value, user):
        # type: (str, str, str) -> None
        """Attach a ``key`` -> ``value`` metadata row to an existing user."""
        sql_user = User.get(self.session, name=user)
        assert sql_user
        metadata = UserMetadata(user_id=sql_user.id, data_key=key, data_value=value)
        metadata.add(self.session)

    def add_public_key_to_user(self, key, user):
        # type: (str, str) -> None
        """Parse the SSH public key ``key`` (strict mode) and store it for an existing user.

        The stored MD5/SHA256 fingerprints have their "MD5:"/"SHA256:" prefixes removed.
        """
        sql_user = User.get(self.session, name=user)
        assert sql_user
        public_key = SSHKey(key, strict=True)
        public_key.parse()
        sql_public_key = PublicKey(
            user_id=sql_user.id,
            public_key=public_key.keydata.strip(),
            fingerprint=public_key.hash_md5().replace("MD5:", ""),
            fingerprint_sha256=public_key.hash_sha256().replace("SHA256:", ""),
            key_size=public_key.bits,
            key_type=public_key.key_type,
            comment=public_key.comment,
        )
        sql_public_key.add(self.session)

    def disable_user(self, user):
        # type: (str) -> None
        """Disable ``user`` through the user repository."""
        user_repository = self.repository_factory.create_user_repository()
        user_repository.disable_user(user)

    def disable_group(self, group):
        # type: (str) -> None
        """Mark an existing group as disabled."""
        group_obj = Group.get(self.session, name=group)
        assert group_obj
        group_obj.enabled = False

    def disable_service_account(self, service_account):
        # type: (str) -> None
        """Disable a service account.

        Disables its underlying user, deletes its ownership record, and deletes all of its
        permission grants.
        """
        service_obj = ServiceAccount.get(self.session, name=service_account)
        assert service_obj
        service_obj.user.enabled = False
        service_obj.owner.delete(self.session)
        permissions = self.session.query(ServiceAccountPermissionMap).filter_by(
            service_account_id=service_obj.id
        )
        for permission in permissions:
            permission.delete(self.session)

    def create_role_user(self, role_user, description="", join_policy=GroupJoinPolicy.CAN_ASK):
        # type: (str, str, GroupJoinPolicy) -> None
        """Create an old-style role user.

        This concept is obsolete and all code related to it will be deleted once all remaining
        legacy role users have been converted to service accounts.  This method should be used only
        for tests to maintain backward compatibility until that happens.
        """
        user = User(username=role_user, role_user=True)
        user.add(self.session)
        self.create_group(role_user, description, join_policy)
        self.add_user_to_group(role_user, role_user)
Example #9
0
def main(sys_argv=None, session=None):
    # type: (Optional[List[str]], Optional[Session]) -> None
    """Entry point for Grouper Control (grouper-ctl).

    Args:
        sys_argv: Full argument vector including the program name; only sys_argv[1:] is parsed.
            Defaults to the value of sys.argv at call time.
        session: Optional existing database session.  When given, legacy commands receive a
            SingletonSessionFactory wrapping it instead of a SessionFactory built from settings
            (presumably for tests).
    """
    if sys_argv is None:
        # Resolve at call time: a `sys_argv=sys.argv` default would be bound once at import time
        # and would ignore any later replacement of sys.argv.
        sys_argv = sys.argv

    description_msg = "Grouper Control"
    parser = argparse.ArgumentParser(description=description_msg)

    parser.add_argument(
        "-c", "--config", default=default_settings_path(), help="Path to config file."
    )
    parser.add_argument(
        "-d", "--database-url", type=str, default=None, help="Override database URL in config."
    )
    parser.add_argument(
        "-q", "--quiet", action="count", default=0, help="Decrease logging verbosity."
    )
    parser.add_argument(
        "-v", "--verbose", action="count", default=0, help="Increase logging verbosity."
    )
    parser.add_argument(
        "-V",
        "--version",
        action="version",
        version="%%(prog)s %s" % __version__,
        help="Display version information.",
    )

    subparsers = parser.add_subparsers(dest="command")
    CtlCommandFactory.add_all_parsers(subparsers)

    # Add parsers for legacy commands that have not been refactored.
    for subcommand_module in [group, oneoff, service_account, shell]:
        subcommand_module.add_parser(subparsers)  # type: ignore

    args = parser.parse_args(sys_argv[1:])

    # Construct the CtlSettings object used for all commands, and set it as the global Settings
    # object.  All code in grouper.ctl.* takes the CtlSettings object as an argument if needed, but
    # it may call other legacy code that requires the global Settings object be present.
    settings = CtlSettings.global_settings_from_config(args.config)
    if args.database_url:
        settings.database = args.database_url

    # Construct a session factory, which is passed into all the legacy commands that haven't been
    # converted to usecases yet.
    if session is not None:
        session_factory = SingletonSessionFactory(session)  # type: SessionFactory
    else:
        session_factory = SessionFactory(settings)

    log_level = get_loglevel(args, base=logging.INFO)
    logging.basicConfig(level=log_level, format=settings.log_format)

    if log_level < 0:
        sa_log.setLevel(logging.INFO)

    # Initialize plugins.  The global plugin proxy is used by legacy code.
    try:
        plugins = PluginProxy.load_plugins(settings, "grouper-ctl")
    except PluginsDirectoryDoesNotExist as e:
        logging.fatal("Plugin directory does not exist: %s", e)
        sys.exit(1)
    set_global_plugin_proxy(plugins)

    # Set up factories.
    usecase_factory = create_sql_usecase_factory(settings, plugins, session_factory)
    command_factory = CtlCommandFactory(settings, usecase_factory)

    # Old-style subcommands store a func in callable when setting up their arguments.  New-style
    # subcommands are handled via a factory that constructs and calls the correct object.
    if getattr(args, "func", None):
        args.func(args, settings, session_factory)
    else:
        command = command_factory.construct_command(args.command)
        command.run(args)