예제 #1
0
    def import_from_zip(
        cls,
        filename: str,
        artifacts_path: str,
        company_id: Optional[str] = None,
        user_id: str = "",
        user_name: str = "",
    ):
        """
        Import previously exported entities from a zip archive.

        The archive may contain a JSON metadata member (``cls.metadata_filename``)
        that supplies defaults the caller did not provide: a ``public`` flag used
        to pick the target company, and the owning ``user_id``/``user_name``.

        :param filename: path of the zip archive to import.
        :param artifacts_path: directory into which the companion artifacts
            archive (same name as ``filename`` with ``cls.artifacts_ext`` suffix)
            is extracted. Skipped when empty or not an existing directory.
        :param company_id: target company. When ``None`` it is derived from the
            metadata ``public`` flag (public -> ``""``, otherwise the default
            company), and falls back to ``""`` if it still cannot be determined.
        :param user_id: owner of the imported entities. When empty it is taken
            from the metadata, and finally defaults to ``"__allegroai__"``.
        :param user_name: name used if the owner user has to be created.
        """
        cls._init_entity_types()

        metadata = None

        with ZipFile(filename) as zfile:
            try:
                with zfile.open(cls.metadata_filename) as f:
                    metadata = json.loads(f.read())
            except (KeyError, ValueError):
                # The metadata member is optional: ZipFile.open raises KeyError
                # when it is absent, and json.loads raises ValueError
                # (JSONDecodeError / UnicodeDecodeError) on malformed content.
                # Only these are tolerated. Other errors (e.g. failing to
                # resolve the default company below) must surface instead of
                # silently falling through to company_id = "" (public data).
                metadata = None

            if isinstance(metadata, dict):
                meta_public = metadata.get("public")
                if company_id is None and meta_public is not None:
                    company_id = "" if meta_public else get_default_company()

                if not user_id:
                    user_id = metadata.get("user_id", "")
                    user_name = metadata.get("user_name", "")

            if not user_id:
                user_id, user_name = "__allegroai__", "Allegro.ai"

            # Make sure we won't end up with an invalid company ID
            if company_id is None:
                company_id = ""

            # Create the owning user only if it does not exist yet
            existing_user = cls.user_cls.objects(id=user_id).only("id").first()
            if not existing_user:
                cls.user_cls(id=user_id, name=user_name,
                             company=company_id).save()

            cls._import(zfile, company_id, user_id, metadata)

        if artifacts_path and os.path.isdir(artifacts_path):
            artifacts_file = Path(filename).with_suffix(cls.artifacts_ext)
            if artifacts_file.is_file():
                print(f"Unzipping artifacts into {artifacts_path}")
                with ZipFile(artifacts_file) as zfile:
                    zfile.extractall(artifacts_path)
예제 #2
0
def init_mongo_data():
    """
    Initialize the MongoDB data the server needs at startup.

    Steps, in order: apply migrations (``_apply_migrations``), ``_ensure_uuid``,
    ensure the default company (named "clearml") and its default queue, then
    ensure an auth user for every entry under ``secure.credentials`` plus a
    backend user for entries whose role is ``Role.user``. When fixed-users mode
    is enabled, the fixed users configuration is validated, the guest company
    is ensured if the guest user is enabled, and each fixed user is ensured.

    Errors are not propagated. A failure for a single fixed user is logged and
    the remaining fixed users are still processed. Any other failure aborts
    the remaining steps and is reported via ``log.exception``.
    """
    try:
        _apply_migrations(log)

        _ensure_uuid()

        company_id = _ensure_company(get_default_company(), "clearml", log)

        _ensure_default_queue(company_id)

        fixed_mode = FixedUser.enabled()

        for user, credentials in config.get("secure.credentials", {}).items():
            user_data = {
                "name": user,
                "role": credentials.role,
                "email": f"{user}@example.com",
                "key": credentials.user_key,
                "secret": credentials.user_secret,
            }
            # Revocation applies only in fixed-users mode, and only for
            # credentials that opt in via "revoke_in_fixed_mode".
            revoke = fixed_mode and credentials.get("revoke_in_fixed_mode",
                                                    False)
            user_id = _ensure_auth_user(user_data,
                                        company_id,
                                        log=log,
                                        revoke=revoke)
            if credentials.role == Role.user:
                _ensure_backend_user(user_id, company_id,
                                     credentials.display_name)

        if fixed_mode:
            log.info("Fixed users mode is enabled")
            FixedUser.validate()

            if FixedUser.guest_enabled():
                _ensure_company(FixedUser.get_guest_user().company, "guests",
                                log)

            for user in FixedUser.from_config():
                try:
                    ensure_fixed_user(user, log=log)
                except Exception as ex:
                    # Best effort per user: keep creating the remaining ones.
                    log.error(f"Failed creating fixed user {user.name}: {ex}")
    except Exception:
        # Startup boundary: log with traceback rather than crash the server.
        log.exception("Failed initializing mongodb")
예제 #3
0
def pre_populate_data():
    """
    Run pre-population for every zip archive configured under
    ``apiserver.pre_populate.zip_files``.

    Each resolved archive is passed to ``_pre_populate`` together with the
    default company. Afterwards the featured projects order is updated.
    """
    configured_archives = config.get("apiserver.pre_populate.zip_files")
    resolved_archives = _resolve_zip_files(configured_archives)
    for archive in resolved_archives:
        default_company = get_default_company()
        _pre_populate(company_id=default_company, zip_file=archive)

    PrePopulate.update_featured_projects_order()
예제 #4
0
class FixedUser:
    """
    A user defined statically in the server configuration (fixed-users mode).

    Regular fixed users come from ``apiserver.auth.fixed_users.users``; an
    optional guest user is built from ``services.auth.fixed_users.guest``.
    NOTE(review): the class relies on ``__attrs_post_init__``, so it is
    presumably decorated with ``attrs`` elsewhere — confirm.
    """

    username: str
    password: str
    name: str
    # NOTE: this default is evaluated once, when the class body executes, so
    # all instances created without an explicit company share that value.
    company: str = get_default_company()

    is_guest: bool = False

    def __attrs_post_init__(self):
        # Deterministic id derived from company and username (not a secret).
        self.user_id = hashlib.md5(
            f"{self.company}:{self.username}".encode()).hexdigest()

    @classmethod
    def enabled(cls):
        """Return whether fixed-users mode is enabled in the configuration."""
        return config.get("apiserver.auth.fixed_users.enabled", False)

    @classmethod
    def guest_enabled(cls):
        """Return whether fixed-users mode and the guest user are both enabled."""
        return cls.enabled() and config.get(
            "services.auth.fixed_users.guest.enabled", False)

    @classmethod
    def validate(cls):
        """
        Validate the fixed users configuration (no-op when disabled).

        :raises FixedUsersError: if two configured users (guest included)
            share the same username.
        """
        if not cls.enabled():
            return
        users = cls.from_config()
        if len({user.username for user in users}) < len(users):
            raise FixedUsersError(
                "Duplicate user names found in fixed users configuration")

    @classmethod
    def from_config(cls) -> Sequence["FixedUser"]:
        """
        Build the list of configured fixed users.

        The guest user, when enabled, is placed first. Not cached: the
        configuration is read on every call.
        """
        users = [
            cls(**user)
            for user in config.get("apiserver.auth.fixed_users.users", [])
        ]

        if cls.guest_enabled():
            users.insert(0, cls.get_guest_user())

        return users

    @classmethod
    @lru_cache()
    def get_by_username(cls, username) -> Optional["FixedUser"]:
        """
        Return the configured fixed user with ``username``, or None if none
        matches. Results (including misses) are memoized by ``lru_cache``.
        """
        return next(
            (user for user in cls.from_config() if user.username == username),
            None)

    @classmethod
    @lru_cache()
    def is_guest_endpoint(cls, service, action):
        """
        Return True if ``"<service>.<action>"`` is listed in
        ``services.auth.fixed_users.guest.allow_endpoints``.

        Only the endpoint name is checked; this method does not verify the
        identity of the calling user. Results are memoized by ``lru_cache``.
        """
        endpoint = ".".join((service, action))
        return any(
            ep == endpoint
            for ep in config.get(
                "services.auth.fixed_users.guest.allow_endpoints", []))

    @classmethod
    def get_guest_user(cls) -> Optional["FixedUser"]:
        """Return the configured guest user, or None if guest access is disabled."""
        if cls.guest_enabled():
            return cls(
                is_guest=True,
                username=config.get(
                    "services.auth.fixed_users.guest.username"),
                password=config.get(
                    "services.auth.fixed_users.guest.password"),
                name=config.get("services.auth.fixed_users.guest.name"),
                company=config.get(
                    "services.auth.fixed_users.guest.default_company"),
            )

    def __hash__(self):
        # Hash by the derived id so instances can be used in sets/dicts.
        return hash(self.user_id)