Beispiel #1
0
    def get_models_of_type(model_type: str, root: pathlib.Path) -> List[str]:
        """List the model names stored for one model type inside a trestle project.

        Args:
            model_type: Model type to look up; must be listed in const.MODEL_TYPE_LIST.
            root: Any path inside the trestle project; the project root is derived from it.

        Returns:
            The stem of every directory found directly under the model type's directory,
            excluding names rejected by ModelUtils._should_ignore.

        Raises:
            TrestleError: if the model type is unsupported or root is not inside a trestle project.
        """
        if model_type not in const.MODEL_TYPE_LIST:
            raise err.TrestleError(f'Model type {model_type} is not supported')

        # everything is resolved relative to the project root
        project_root = extract_trestle_project_root(root)
        if not project_root:
            logger.error(
                f'Given directory {root} is not within a trestle project.')
            raise err.TrestleError(
                'Given directory is not within a trestle project.')

        # each model lives in its own sub-directory of the plural model directory
        model_dir_name = ModelUtils.model_type_to_model_dir(model_type)
        found_models = []
        for entry in (project_root / model_dir_name).glob('*/'):
            if ModelUtils._should_ignore(entry.stem):
                continue
            if entry.is_dir():
                found_models.append(entry.stem)
                continue
            # a plain file directly in the model directory is not a model
            logger.warning(
                f'Ignoring validation of misplaced file {entry.name} ' +
                f'found in the model directory, {model_dir_name}.')
        return found_models
Beispiel #2
0
    def get_root_model(module_name: str) -> Tuple[Type[Any], str]:
        """Return the type and alias of the first field of a module's ``Model`` class.

        Args:
            module_name: Dotted name of the module to import.

        Returns:
            Tuple of (type of the first field of ``Model``, alias of that field).

        Raises:
            TrestleError: if the module cannot be found or has no ``Model`` attribute.
        """
        try:
            loaded_module = importlib.import_module(module_name)
        except ModuleNotFoundError as e:
            raise err.TrestleError(str(e))

        if not hasattr(loaded_module, 'Model'):
            raise err.TrestleError('Invalid module')

        # the wrapper Model is expected to expose the root model as its first field
        first_field = next(iter(loaded_module.Model.__fields__.values()))
        return first_field.type_, first_field.alias
Beispiel #3
0
    def parse(self, parser):
        """Parse an ``mdsection_include`` tag and return the nodes of the included section.

        The tag carries two string tokens (markdown source and section title) followed by
        optional ``name=value`` arguments. The named section is extracted from the markdown
        source, optionally re-levelled to ``heading_level``, and parsed as a template.

        Args:
            parser: The Jinja parser positioned on the tag.

        Returns:
            The body nodes produced by parsing the extracted section.

        Raises:
            TrestleError: if the tag cannot be consumed within ``self.max_tag_parse`` iterations
                or the requested section does not exist in the markdown source.
        """
        kwargs = None
        count = 0
        while parser.stream.current.type != lexer.TOKEN_BLOCK_END:
            # the iteration cap is what ends the loop when a token is never consumed
            count += 1
            if count > self.max_tag_parse:
                raise err.TrestleError(
                    'Unexpected Jinja tag structure provided, please review docs.'
                )
            token = parser.stream.current
            if token.test('name:mdsection_include'):
                parser.stream.expect(lexer.TOKEN_NAME)
                markdown_source = parser.stream.expect(lexer.TOKEN_STRING)
                section_title = parser.stream.expect(lexer.TOKEN_STRING)
            elif kwargs is not None:
                # consume one `name = expression` pair
                arg = token.value
                next(parser.stream)
                parser.stream.expect(lexer.TOKEN_ASSIGN)
                exp = self.parse_expression(parser)
                kwargs[arg] = exp.value
            elif parser.stream.look().type == lexer.TOKEN_ASSIGN:
                # next token is `=`: switch to keyword-argument mode without consuming anything
                kwargs = {}

        # Use the established environment to source the file
        md_content, _, _ = self.environment.loader.get_source(
            self.environment, markdown_source.value)
        fm = frontmatter.loads(md_content)
        if not fm.metadata == {}:
            logger.warning(
                'Non zero metadata on MD section include - ignoring')
        full_md = markdown_node.MarkdownNode.build_tree_from_markdown(
            fm.content.split('\n'))
        md_section = full_md.get_node_for_key(section_title.value,
                                              strict_matching=True)
        # validate before touching the section: it was previously dereferenced first,
        # so a missing section surfaced as an AttributeError instead of this error
        if not md_section:
            raise err.TrestleError(
                f'Unable to retrieve section "{section_title.value}"" from {markdown_source.value} jinja template.'
            )

        # adjust the heading level when one was requested
        expected_heading_level = None
        if kwargs is not None:
            expected_heading_level = kwargs.get('heading_level')
        if expected_heading_level is not None:
            level = md_section.get_node_header_lvl()
            delta = int(expected_heading_level) - level
            if delta != 0:
                md_section.change_header_level_by(delta)

        local_parser = Parser(self.environment, md_section.content.raw_text)
        top_level_output = local_parser.parse()

        return top_level_output.body
Beispiel #4
0
    def oscal_read(cls, path: pathlib.Path) -> Optional['OscalBaseModel']:
        """
        Read OSCAL objects.

        Handles the fact OSCAL wraps top level elements and also deals with both yaml and json.

        Args:
            path: The path of the oscal object to read.
        Returns:
            The oscal object read into trestle oscal models, or None if the path does not exist.
        Raises:
            TrestleError: if the file cannot be loaded, is not wrapped in exactly one
                top level key, lacks the expected alias key, or fails to parse.
        """
        # Create the wrapper model.
        alias = classname_to_alias(cls.__name__, AliasMode.JSON)

        content_type = FileContentType.to_content_type(path.suffix)
        logger.debug(
            f'oscal_read content type {content_type} and alias {alias} from {path}'
        )

        if not path.exists():
            logger.warning(f'path does not exist in oscal_read: {path}')
            return None

        obj: Dict[str, Any] = {}
        try:
            if content_type == FileContentType.YAML:
                yaml = YAML(typ='safe')
                # context manager closes the handle even when yaml.load raises
                with path.open('r', encoding=const.FILE_ENCODING) as fh:
                    obj = yaml.load(fh)
            elif content_type == FileContentType.JSON:
                obj = load_file(
                    path,
                    json_loads=cls.__config__.json_loads,
                )
        except Exception as e:
            raise err.TrestleError(f'Error loading file {path} {str(e)}') from e
        try:
            # NOTE: the structure error below is itself caught by the generic handler and
            # re-raised with the 'Error parsing file' prefix; kept for message compatibility.
            if not len(obj) == 1:
                raise err.TrestleError(
                    f'Invalid OSCAL file structure, oscal file '
                    f'does not have a single top level key wrapping it. It has {len(obj)} keys.'
                )
            parsed = cls.parse_obj(obj[alias])
        except KeyError as e:
            raise err.TrestleError(
                f'Provided oscal file does not have top level key key: {alias}'
            ) from e
        except Exception as e:
            raise err.TrestleError(f'Error parsing file {path} {str(e)}') from e

        return parsed
Beispiel #5
0
def get_contextual_file_type(path: pathlib.Path) -> FileContentType:
    """Determine the content type used by files under a trestle project model path.

    The suffix of the first non-hidden file directly in ``path`` decides the type.
    If there is none, the search descends into the first sub-directory of ``path``.

    Raises:
        TrestleError: if ``path`` is not a valid project model path or no file is found.
    """
    if not _is_valid_project_model_path(path):
        raise err.TrestleError(f'Trestle project not found at path {path}')

    visible_entries = iterdir_without_hidden_files(path)
    first_file = next((entry for entry in visible_entries if entry.is_file()), None)
    if first_file is not None:
        return FileContentType.to_content_type(first_file.suffix)

    # no file at this level: recurse into the first directory only
    first_dir = next((entry for entry in path.iterdir() if entry.is_dir()), None)
    if first_dir is not None:
        return get_contextual_file_type(first_dir)

    raise err.TrestleError('No files found in the project.')
Beispiel #6
0
    def copy_to(self,
                new_oscal_type: Type['OscalBaseModel']) -> 'OscalBaseModel':
        """
        Copy this object into another, structurally similar, oscal model class.

        OSCAL repeats similar or identical definitions across its models, but nothing
        guarantees they are the same, so the copy is attempted opportunistically.

        Args:
            new_oscal_type: The desired type of oscal model

        Returns:
            Opportunistic copy of the data into the new model type.

        Raises:
            TrestleError: if the classes neither share a name nor are both ``__root__``-only wrappers.
        """

        def _is_root_only(model) -> bool:
            """Report whether ``__root__`` is the model's single field."""
            declared = model.__fields__
            return '__root__' in declared and len(declared) == 1

        logger.debug('Copy to started')
        if self.__class__.__name__ == new_oscal_type.__name__:
            logger.debug('Json based copy')
            # Dev notes: keep this json based. Due to enums (in particular) json is the closest we can get.
            serialized = self.oscal_serialize_json(pretty=False, wrapped=False)
            return new_oscal_type.parse_raw(serialized)

        if _is_root_only(self) and _is_root_only(new_oscal_type):
            logger.debug('Root element based copy too')
            return new_oscal_type.parse_obj(self.__root__)

        # neither strategy applies
        raise err.TrestleError(
            'Provided inconsistent classes to copy to methodology.')
Beispiel #7
0
def generate_sample_value_by_type(
    type_: type,
    field_name: str,
) -> Union[datetime, bool, int, str, float, Enum, dict]:
    """Given a type, return a sample value.

    Args:
        type_: The type to generate a value for.
        field_name: Name of the field the value is for. Some names (``uuid``,
            ``oscal_version``, ``date_authorized``, ...) select a specific sample value.

    Returns:
        A sample value of the requested type; an empty dict when the type is ``Any``.

    Raises:
        TrestleError: if the type is not one of the handled kinds.
    """
    # FIXME: Should be in separate generator module as it inherits EVERYTHING
    if type_ is datetime:
        return datetime.now().astimezone()
    if type_ is bool:
        return False
    if type_ is int:
        return 0
    if type_ is str:
        if field_name == 'oscal_version':
            return OSCAL_VERSION
        return 'REPLACE_ME'
    if type_ is float:
        return 0.00
    if safe_is_sub(type_, ConstrainedStr) or (hasattr(type_, '__name__') and 'ConstrainedStr' in type_.__name__):
        # Constrained strings must satisfy a pattern; pick a value based on the field name.
        # TODO: handle regex directly
        if field_name == 'uuid':
            return str(uuid.uuid4())
        if field_name == 'date_authorized':
            return str(date.today().isoformat())
        if field_name == 'oscal_version':
            return OSCAL_VERSION
        if 'uuid' in field_name:
            return const.SAMPLE_UUID_STR
        # Only case where are UUID is required but not in name.
        if field_name.rstrip('s') == 'member_of_organization':
            return const.SAMPLE_UUID_STR
        return 'REPLACE_ME'
    if hasattr(type_, '__name__') and 'ConstrainedIntValue' in type_.__name__:
        # create an int value as close to the floor as possible does not test upper bound
        multiple = type_.multiple_of if type_.multiple_of else 1  # default to every integer
        floor = type_.ge if type_.ge else 0
        # compare against None: an exclusive bound of 0 is falsy but still a real constraint
        if type_.gt is not None:
            floor = type_.gt + 1
        if math.remainder(floor, multiple) == 0:
            return floor
        # NOTE(review): for a negative floor this product can fall below the floor - confirm
        # whether negative lower bounds can occur before relying on it.
        return (floor + 1) * multiple
    if safe_is_sub(type_, Enum):
        # keys and values diverge due to hypens in oscal names
        return type_(list(type_.__members__.values())[0])
    if type_ is pydantic.networks.EmailStr:
        return pydantic.networks.EmailStr('*****@*****.**')
    if type_ is pydantic.networks.AnyUrl:
        # TODO: Cleanup: this should be usable from a url.. but it's not inuitive.
        return pydantic.networks.AnyUrl('https://sample.com/replaceme.html', scheme='http', host='sample.com')
    if type_ == Any:
        # Return empty dict - aka users can put whatever they want here.
        return {}
    raise err.TrestleError(f'Fatal: Bad type in model {type_}')
Beispiel #8
0
    def add(element_path: ElementPath, parent_element: Element,
            include_optional: bool) -> 'Tuple[UpdateAction, Element]':
        """For a element_path, add a child model to the parent_element of a given parent_model.

        Args:
            element_path: element path of the item to create within the model
            parent_element: the parent element that will host the created element
            include_optional: whether to create optional attributes in the created element

        Returns:
            Tuple of the update action describing the change and the updated parent element.

        Raises:
            TrestleError: if the path contains a wildcard, or creating or merging the child fails.

        Notes:
            First we find the child model at the specified element path and instantiate it with default values.
            Then we check if there's already existing element at that path, in which case we append the child model
            to the existing list of dict.
            Then we set up an action to update the model in memory.
            We update the parent_element to prepare for next adds in the chain
        """
        if '*' in element_path.get_full_path_parts():
            raise err.TrestleError('trestle add does not support Wildcard element path.')
        try:
            # Get child model
            child_model = element_path.get_type(type(parent_element.get()))

            # Create child element with sample values
            child_object = gens.generate_sample_model(child_model, include_optional=include_optional)

            existing = parent_element.get_at(element_path)
            if existing is not None:
                # The element already exists: merge the new sample into it
                if type(existing) is list:
                    child_object = existing + child_object
                elif type(existing) is dict:
                    child_object = {**existing, **child_object}
                else:
                    raise err.TrestleError('Already exists and is not a list or dictionary.')

        except Exception as e:
            raise err.TrestleError(f'Bad element path. {str(e)}') from e

        update_action = UpdateAction(
            sub_element=child_object, dest_element=parent_element, sub_element_path=element_path
        )
        parent_element = parent_element.set_at(element_path, child_object)

        return update_action, parent_element
Beispiel #9
0
    def remove(cls, element_path: ElementPath,
               parent_element: Element) -> Tuple[RemoveAction, Element]:
        """Build the action that removes the model at element_path from parent_element.

        The path must point at an existing element, otherwise an error is raised.
        Nothing is removed here: the returned action describes the removal and the
        parent element is handed back unchanged.

        LIMITATIONS:
        1. This does not remove elements of a list or dict. Instead, the entire list or dict is removed.
        2. This cannot remove arbitrarily named elements that are not specified in the schema.
        For example, "responsible-parties" contains named elements, e.g., "organisation". The tool will not
        remove the "organisation" as it is not in the schema, but one can remove its elements, e.g., "party-uuids".

        Raises:
            TrestleError: if the path contains a wildcard or no element exists at the path.
        """
        if '*' in element_path.get_full_path_parts():
            raise err.TrestleError(
                'trestle remove does not support Wildcard element path.')

        target = parent_element.get_at(element_path)
        if target is None:
            raise err.TrestleError(f'Bad element path: {str(element_path)}')

        # collections are removed wholesale, so tell the user
        target_kind = type(target)
        if target_kind is list:
            logger.warning(
                'Warning: trestle remove does not support removing elements of a list: '
                'this removes the entire list')
        elif target_kind is dict:
            logger.warning(
                'Warning: trestle remove does not support removing dict elements: '
                'this removes the entire dict element')

        return RemoveAction(parent_element, element_path), parent_element
Beispiel #10
0
    def get_collection_type(cls) -> Optional[type]:
        """
        Return the collection type wrapped by this model's ``__root__`` field.

        Returns:
            The origin of the ``__root__`` field's outer type.

        Raises:
            err.TrestleError: if not a wrapper of the collection type.
        """
        if cls.is_collection_container():
            root_field = cls.__fields__['__root__']
            return get_origin(root_field.outer_type_)
        raise err.TrestleError(
            'OscalBaseModel is not wrapping a collection type')
Beispiel #11
0
def robust_datetime_serialization(input_dt: datetime.datetime) -> str:
    """Serialize a timezone-aware datetime as a UTC isoformat string with millisecond precision.

    Args:
        input_dt: Input datetime to convert to a string.

    Returns:
        Isoformat string, converted to UTC, truncated to milliseconds and carrying the offset.

    Raises:
        TrestleError: Error is raised if datetime object does not contain sufficient timezone information.
    """
    # a naive datetime (no tzinfo, or a tzinfo without an offset) is rejected
    zone = input_dt.tzinfo
    if zone is None:
        raise err.TrestleError('Missing timezone in datetime')
    if zone.utcoffset(input_dt) is None:
        raise err.TrestleError('Missing utcoffset in datetime')

    # the value is deliberately normalized to utc rather than left in its original timezone
    as_utc = input_dt.astimezone(datetime.timezone.utc)
    return as_utc.isoformat(timespec='milliseconds')
Beispiel #12
0
    def parse(self, parser):
        """Parse an ``md_datestamp`` tag and return nodes rendering today's date.

        Optional ``name=value`` arguments are honoured: ``format`` (a strftime format
        string) and ``newline`` (``False`` suppresses the trailing blank line).

        Args:
            parser: The Jinja parser positioned on the tag.

        Returns:
            The body nodes produced by parsing the generated date string.

        Raises:
            TrestleError: if the tag cannot be consumed within ``self.max_tag_parse`` iterations.
        """
        kwargs = None
        count = 0
        while parser.stream.current.type != lexer.TOKEN_BLOCK_END:
            # the iteration cap is what ends the loop when a token is never consumed
            count += 1
            token = parser.stream.current
            if count > self.max_tag_parse:
                raise err.TrestleError(
                    f'Unexpected Jinja tag structure provided at token {token.value}'
                )
            if token.test('name:md_datestamp'):
                parser.stream.expect(lexer.TOKEN_NAME)
            elif kwargs is not None:
                # consume one `name = expression` pair
                arg = token.value
                next(parser.stream)
                parser.stream.expect(lexer.TOKEN_ASSIGN)
                exp = self.parse_expression(parser)
                kwargs[arg] = exp.value
            elif parser.stream.look().type in (lexer.TOKEN_ASSIGN, lexer.TOKEN_STRING):
                # switch to keyword-argument mode without consuming anything
                kwargs = {}

        date_format = markdown_const.JINJA_DATESTAMP_FORMAT
        trailing_newline = True
        if kwargs is not None:
            # only a string can be used as a strftime format; the previous check
            # `type(kwargs['format'] is str)` was always truthy
            custom_format = kwargs.get('format')
            if isinstance(custom_format, str):
                date_format = custom_format
            if kwargs.get('newline') is False:
                trailing_newline = False

        date_string = date.today().strftime(date_format)
        if trailing_newline:
            date_string += '\n\n'

        local_parser = Parser(self.environment, date_string)
        datestamp_output = local_parser.parse()

        return datestamp_output.body
Beispiel #13
0
    def create_object(cls, model_alias: str,
                      object_type: Type[TopLevelOscalModel],
                      args: argparse.Namespace) -> int:
        """Create a sample top level OSCAL object file within the trestle directory.

        Args:
            model_alias: Alias of the model; used for the directory lookup, the file name and the element.
            object_type: Model class to generate a sample instance of.
            args: Parsed command line arguments (trestle_root, output, extension, include_optional_fields).

        Returns:
            The success return code.

        Raises:
            TrestleRootError: if args.trestle_root is not a valid trestle project root.
            TrestleError: if the file to be created already exists.
        """
        log.set_log_level_from_args(args)
        # trestle root is set via command line in args. Default is cwd.
        trestle_root = args.trestle_root
        if not (trestle_root and file_utils.is_valid_project_root(trestle_root)):
            raise err.TrestleRootError(
                f'Given directory {trestle_root} is not a trestle project.')

        model_dir = trestle_root / ModelUtils.model_type_to_model_dir(model_alias) / args.output
        desired_model_path = model_dir / '.'.join([model_alias, args.extension])

        if desired_model_path.exists():
            raise err.TrestleError(
                f'OSCAL file to be created here: {desired_model_path} exists.')

        # Create sample model and fill in its metadata.
        sample_model = generators.generate_sample_model(
            object_type, include_optional=args.include_optional_fields)
        # Presuming a top level model here; the typing cannot express that.
        metadata = sample_model.metadata  # type: ignore
        metadata.title = f'Generic {model_alias} created by trestle named {args.output}.'
        metadata.last_modified = datetime.now().astimezone()
        metadata.oscal_version = trestle.oscal.OSCAL_VERSION
        metadata.version = '0.0.0'

        top_element = Element(sample_model, model_alias)
        target = desired_model_path.resolve()
        content_type = FileContentType.to_content_type(desired_model_path.suffix)

        # plan: create the path first, then write the file.
        create_plan = Plan()
        for action in (CreatePathAction(target, True),
                       WriteFileAction(target, top_element, content_type)):
            create_plan.add_action(action)
        create_plan.execute()
        return CmdReturnCodes.SUCCESS.value
Beispiel #14
0
def get_inner_type(collection_field_type: Union[Type[List[Any]], Type[Dict[str, Any]]]) -> Type[Any]:
    """Return the item type of a generic collection annotation such as a List or a Dict.

    For a dict the return type is of the value and not the key.

    Args:
        collection_field_type: Provided type annotation from a pydantic object

    Returns:
        The desired type.

    Raises:
        TrestleError: if the inner type cannot be determined.
    """
    try:
        # Pydantic special cases take precedence over the plain generic arguments
        field_info = _get_model_field_info(collection_field_type)
        _, _, singular_type = field_info
        if singular_type is None:
            # last generic argument: the item type of a list, the value type of a dict
            return typing_extensions.get_args(collection_field_type)[-1]
        return singular_type
    except Exception as e:
        logger.debug(e)
        raise err.TrestleError('Model type is not a Dict or List') from e
Beispiel #15
0
    def parse(self, parser):
        """Parse an ``md_clean_include`` tag and return the nodes of the included markdown.

        The markdown source is loaded through the environment loader, its front matter is
        dropped, a blank line is appended, and headings are re-levelled when a
        ``heading_level`` argument was given.
        """
        kwargs = None
        iterations = 0
        while parser.stream.current.type != lexer.TOKEN_BLOCK_END:
            # the iteration cap is what ends the loop when a token is never consumed
            iterations += 1
            if iterations > self.max_tag_parse:
                raise err.TrestleError(
                    'Unexpected Jinja tag structure provided, please review docs.'
                )
            token = parser.stream.current
            if token.test('name:md_clean_include'):
                parser.stream.expect(lexer.TOKEN_NAME)
                markdown_source = parser.stream.expect(lexer.TOKEN_STRING)
            elif kwargs is not None:
                # consume one `name = expression` pair
                arg_name = token.value
                next(parser.stream)
                parser.stream.expect(lexer.TOKEN_ASSIGN)
                token = parser.stream.current
                kwargs[arg_name] = self.parse_expression(parser).value
            elif parser.stream.look().type == lexer.TOKEN_ASSIGN:
                # next token is `=`: switch to keyword-argument mode without consuming anything
                kwargs = {}

        source_text, _, _ = self.environment.loader.get_source(
            self.environment, markdown_source.value)
        # keep only the markdown body, followed by a blank line
        content = frontmatter.loads(source_text).content
        content += '\n\n'

        expected_heading_level = None if kwargs is None else kwargs.get('heading_level')
        if expected_heading_level is not None:
            content = adjust_heading_level(content, expected_heading_level)

        return Parser(self.environment, content).parse().body
Beispiel #16
0
 def model_type_to_model_dir(model_type: str) -> str:
     """Map a model type to the name of its plural model directory.

     Raises:
         TrestleError: if the model type is not listed in const.MODEL_TYPE_LIST.
     """
     if model_type in const.MODEL_TYPE_LIST:
         return const.MODEL_TYPE_TO_MODEL_DIR[model_type]
     raise err.TrestleError(f'Not a valid model type: {model_type}.')
Beispiel #17
0
 def load_file_mock(*args, **kwargs):
     """Always fail with a TrestleError, whatever arguments are passed."""
     failure = err.TrestleError('stuff')
     raise failure
Beispiel #18
0
 def execute_plan_mock(*args, **kwargs):
     """Always fail with a TrestleError, whatever arguments are passed."""
     failure = err.TrestleError('stuff')
     raise failure
Beispiel #19
0
def generate_sample_model(
    model: Union[Type[TG], List[TG], Dict[str, TG]], include_optional: bool = False, depth: int = -1
) -> TG:
    """Given a model class, generate an object of that class with sample values.

    Can generate optional variables with an enabled flag. Any array objects will have a single entry injected into it.

    Note: Trestle generate will not activate recursive loops irrespective of the depth flag.

    Args:
        model: The model type provided. Typically for a user as an OscalBaseModel Subclass. May also be a
            list / dict annotation, in which case a one-entry list / dict of the inner type is returned.
        include_optional: Whether or not to generate optional fields.
        depth: Depth of the tree at which optional fields are generated. Negative values (default) removes the limit.

    Returns:
        The generated instance with a pro-forma values filled out as best as possible.

    Raises:
        TrestleError: if the (inner) model is not an OscalBaseModel and the outer type is neither list nor dict.
    """
    # optional fields are only produced at this level when requested and the depth budget is not exhausted
    effective_optional = include_optional and not depth == 0

    model_type = model
    # This block normalizes a collection annotation: model_type becomes the collection origin
    # and model becomes the inner item type. For non-collections both stay the original model.
    if utils.is_collection_field_type(model):  # type: ignore
        model_type = utils.get_origin(model)  # type: ignore
        model = utils.get_inner_type(model)  # type: ignore
    model = cast(TG, model)

    # keyword arguments used to construct the model at the end
    model_dict = {}
    # this block is needed to avoid situations where an inbuilt is inside a list / dict.
    # the only time dict ever appears is with include_all, which is handled specially
    # the only type of collection possible after OSCAL 1.0.0 is list
    if safe_is_sub(model, OscalBaseModel):
        for field in model.__fields__:
            if field == 'include_all':
                # NOTE(review): this checks include_optional rather than effective_optional,
                # so the depth limit does not apply to include_all - confirm this is intended.
                if include_optional:
                    model_dict[field] = {}
                continue
            outer_type = model.__fields__[field].outer_type_
            # next appears to be needed for python 3.7
            # for a Union only the first member is considered
            if utils.get_origin(outer_type) == Union:
                outer_type = outer_type.__args__[0]
            if model.__fields__[field].required or effective_optional:
                # FIXME could be ForwardRef('SystemComponentStatus')
                if utils.is_collection_field_type(outer_type):
                    inner_type = utils.get_inner_type(outer_type)
                    # a collection of the model itself is skipped so generation cannot recurse forever
                    if inner_type == model:
                        continue
                    # recursion passes the original include_optional and a reduced depth
                    model_dict[field] = generate_sample_model(
                        outer_type, include_optional=include_optional, depth=depth - 1
                    )
                elif safe_is_sub(outer_type, OscalBaseModel):
                    model_dict[field] = generate_sample_model(
                        outer_type, include_optional=include_optional, depth=depth - 1
                    )
                else:
                    # Hacking here:
                    # Root models should ideally not exist, however, sometimes we are stuck with them.
                    # If that is the case we need sufficient information on the type in order to generate a model.
                    # E.g. we need the type of the container.
                    if field == '__root__' and hasattr(model, '__name__'):
                        # use the class name (as a field alias) in place of the uninformative '__root__'
                        model_dict[field] = generate_sample_value_by_type(
                            outer_type, str_utils.classname_to_alias(model.__name__, AliasMode.FIELD)
                        )
                    else:
                        model_dict[field] = generate_sample_value_by_type(outer_type, field)
        # Note: this assumes list constrains in oscal are always 1 as a minimum size. if two this may still fail.
    else:
        # the inner type is not an OscalBaseModel: wrap a single sample value in the collection
        if model_type is list:
            return [generate_sample_value_by_type(model, '')]
        if model_type is dict:
            return {'REPLACE_ME': generate_sample_value_by_type(model, '')}
        raise err.TrestleError('Unhandled collection type.')
    # build the model and wrap it in the collection type that was requested, if any
    if model_type is list:
        return [model(**model_dict)]
    if model_type is dict:
        return {'REPLACE_ME': model(**model_dict)}
    return model(**model_dict)
Beispiel #20
0
 def simulate_mock(*args, **kwargs):
     """Always fail with a TrestleError, whatever arguments are passed."""
     failure = err.TrestleError('simulate_fail')
     raise failure
Beispiel #21
0
 def oscal_read_mock(*args, **kwargs):
     """Always fail with a TrestleError carrying ``logged_error`` from the enclosing scope."""
     failure = err.TrestleError(logged_error)
     raise failure
Beispiel #22
0
 def get_mock():
     """Always fail with a TrestleError."""
     failure = err.TrestleError('get fails')
     raise failure
Beispiel #23
0
 def open_sftp_mock():
     """Always fail with a TrestleError."""
     failure = err.TrestleError('stuff')
     raise failure
Beispiel #24
0
 def validate_import_mock(*args, **kwargs):
     """Always fail with a TrestleError, whatever arguments are passed."""
     failure = err.TrestleError('validate run error')
     raise failure
Beispiel #25
0
 def mock_execute(*args, **kwargs):
     """Always fail with a TrestleError carrying ``logged_error`` from the enclosing scope."""
     failure = err.TrestleError(logged_error)
     raise failure
Beispiel #26
0
 def rollback_mock(*args, **kwargs):
     """Return a two-item list: ``None`` followed by a TrestleError instance (returned, not raised)."""
     rollback_failure = err.TrestleError('rollback error')
     return [None, rollback_failure]
Beispiel #27
0
    def get_singular_alias(
            alias_path: str,
            relative_path: Optional[pathlib.Path] = None) -> str:
        """
        Get the alias in the singular form from a jsonpath.

        When relative_path is given, the alias of the model found at that path is
        prepended to alias_path (without duplicating a shared boundary segment)
        before the path is resolved.

        Args:
            alias_path: The current alias element path as a string
            relative_path: Optional relative path (w.r.t. trestle_root) to cater for relative element paths.
        Returns:
            Alias as a string
        Raises:
            TrestleError: if the path is blank, its root alias is unknown, or the
                singular alias cannot be derived from the parent model.
        """
        if not alias_path.strip():
            raise err.TrestleError(f'Invalid jsonpath {alias_path}')

        singular_alias: str = ''

        full_alias_path = alias_path
        if relative_path:
            # previously this logged the builtin `str` instead of the path
            logger.debug(f'get_singular_alias contextual mode: {relative_path}')
            _, full_model_alias = ModelUtils.get_relative_model_type(
                relative_path)
            first_alias_a = full_model_alias.split('.')[-1]
            first_alias_b = alias_path.split('.')[0]
            # drop the trailing segment of the model alias when alias_path already starts with it
            if first_alias_a == first_alias_b:
                full_model_alias = '.'.join(full_model_alias.split('.')[:-1])
            full_alias_path = '.'.join([full_model_alias,
                                        alias_path]).strip('.')

        path_parts = full_alias_path.split(const.ALIAS_PATH_SEPARATOR)
        logger.debug(f'path parts: {path_parts}')

        model_types = []

        # find the root model whose alias matches the first path part
        root_model_alias = path_parts[0]
        found = False
        for module_name in const.MODEL_TYPE_TO_MODEL_MODULE.values():
            model_type, model_alias = ModelUtils.get_root_model(module_name)
            if root_model_alias == model_alias:
                found = True
                model_types.append(model_type)
                break

        if not found:
            raise err.TrestleError(
                f'{root_model_alias} is an invalid root model alias.')

        if len(path_parts) == 1:
            return root_model_alias

        model_type = model_types[0]
        # go through path parts skipping first one
        for i in range(1, len(path_parts)):
            if utils.is_collection_field_type(model_type):
                # if it is a collection type and last part is * then break
                if i == len(path_parts) - 1 and path_parts[i] == '*':
                    break
                # otherwise get the inner type of items in the collection
                model_type = utils.get_inner_type(model_type)
                # NOTE: the former `i = i + 1` here was removed: rebinding the loop
                # variable of a `for ... in range` has no effect on the iteration.
            else:
                path_part = path_parts[i]
                field_map = model_type.alias_to_field_map()
                # unknown parts are skipped without recording a type
                if path_part not in field_map:
                    continue
                field = field_map[path_part]
                model_type = field.outer_type_
            model_types.append(model_type)

        last_alias = path_parts[-1]
        if last_alias == '*':
            last_alias = path_parts[-2]

        # generic model and not list, so return itself fixme doc
        if not utils.is_collection_field_type(model_type):
            return last_alias

        # the final type is a collection: derive the singular alias from its item type
        parent_model_type = model_types[-2]
        try:
            field_map = parent_model_type.alias_to_field_map()
            field = field_map[last_alias]
            outer_type = field.outer_type_
            inner_type = utils.get_inner_type(outer_type)
            inner_type_name = inner_type.__name__
            singular_alias = str_utils.classname_to_alias(
                inner_type_name, AliasMode.JSON)
        except Exception as e:
            raise err.TrestleError(f'Error in json path {alias_path}: {e}') from e

        return singular_alias
Beispiel #28
0
    def create_stripped_model_type(
        cls,
        stripped_fields: Optional[List[str]] = None,
        stripped_fields_aliases: Optional[List[str]] = None
    ) -> Type['OscalBaseModel']:
        """Create a pydantic model, which is derived from the current model, but missing certain fields.

        OSCAL mandates a 'strict' schema (e.g. unless otherwise stated no additional fields), and certain fields
        are mandatory. Given this the corresponding dataclasses are also strict. Workflows with trestle require missing
        mandatory fields. This allows creation of derivative models missing certain fields.

        Args:
            stripped_fields: The fields to be removed from the current data class.
            stripped_fields_aliases: The fields to be removed from the current data class provided by alias.

        Returns:
            Pydantic data class that can be used to instantiate a model.

        Raises:
            TrestleError: If user provided both stripped_fields and stripped_field_aliases or neither.
            TrestleError: If an alias that does not exist in the model is provided.
        """
        by_name = stripped_fields is not None
        by_alias = stripped_fields_aliases is not None
        if by_name and by_alias:
            raise err.TrestleError(
                'Either "stripped_fields" or "stripped_fields_aliases" need to be passed, not both.'
            )
        if not (by_name or by_alias):
            raise err.TrestleError(
                'Exactly one of "stripped_fields" or "stripped_fields_aliases" must be provided'
            )

        # resolve the exclusion list to field names
        if by_name:
            excluded_fields = stripped_fields
        else:
            alias_to_field = cls.alias_to_field_map()
            try:
                excluded_fields = [
                    alias_to_field[alias].name for alias in stripped_fields_aliases
                ]
            except KeyError as e:
                raise err.TrestleError(
                    f'Field {str(e)} does not exist in the model')

        # Build the field definitions of the derived model
        new_fields_for_model = {}
        for model_field in cls.__fields__.values():
            field_name = model_field.name
            if field_name in excluded_fields:
                continue
            # required fields stay required; all others become Optional with a None default
            if model_field.required:
                annotation, default = model_field.outer_type_, ...
            else:
                annotation, default = Optional[model_field.outer_type_], None
            new_fields_for_model[field_name] = (
                annotation,
                Field(default, title=model_field.name, alias=model_field.alias))

        stripped_model = create_model(cls.__name__,
                                      __base__=OscalBaseModel,
                                      **new_fields_for_model)  # type: ignore
        # TODO: This typing cast should NOT be necessary. Potentially fixable with a fix to pydantic. Issue #175
        return cast(Type[OscalBaseModel], stripped_model)
Beispiel #29
0
 def ssh_connect_mock():
     """Always fail with a TrestleError.

     The exception was previously constructed but never raised, so this stand-in
     silently returned None instead of failing.
     """
     raise err.TrestleError('stuff')
Beispiel #30
0
 def execute_mock(*args, **kwargs):
     """Always fail with a TrestleError, whatever arguments are passed."""
     failure = err.TrestleError('execution failed')
     raise failure