def _search(
        cls,
        search_resource,
        search_item_factory,
        boto_next_token_name="NextToken",
        sagemaker_session=None,
        **kwargs
    ):
        """Search for objects with the SageMaker API.

        Calls ``sagemaker_client.search`` page by page, following the pagination
        token until a response no longer carries one.

        Args:
            search_resource (str): Value sent as the ``Resource`` field of every
                search request.
            search_item_factory (callable): Called with ``item[cls.__name__]`` for
                each search result that has a ``cls.__name__`` entry; its return
                value is yielded.
            boto_next_token_name (str): Key of the pagination token in both the
                request and the response. Defaults to ``"NextToken"``.
            sagemaker_session: Session whose ``sagemaker_client`` is used. If not
                provided, ``_utils.default_session()`` is used.
            **kwargs: Search parameters, converted to boto request fields with
                ``_boto_functions.to_boto`` using ``cls._custom_boto_names`` and
                ``cls._custom_boto_types``.

        Yields:
            The objects produced by ``search_item_factory``. Results that do not
            contain a ``cls.__name__`` entry are skipped.
        """
        sagemaker_session = sagemaker_session or _utils.default_session()
        sagemaker_client = sagemaker_session.sagemaker_client

        next_token = None
        try:
            # The request fields do not depend on the page, so convert them once;
            # only the pagination token changes between calls. The dict is
            # unpacked with ``**`` for each call, so the client never holds it.
            search_request_kwargs = _boto_functions.to_boto(
                kwargs, cls._custom_boto_names, cls._custom_boto_types
            )
            search_request_kwargs["Resource"] = search_resource
            search_method = sagemaker_client.search

            while True:
                if next_token:
                    search_request_kwargs[boto_next_token_name] = next_token
                search_method_response = search_method(**search_request_kwargs)
                search_items = search_method_response.get("Results", [])
                next_token = search_method_response.get(boto_next_token_name)
                for item in search_items:
                    if cls.__name__ in item:
                        yield search_item_factory(item[cls.__name__])
                if not next_token:
                    break
        except StopIteration:
            # A StopIteration raised by the client or the factory ends the search
            # quietly instead of escaping the generator (which PEP 479 would turn
            # into a RuntimeError).
            return
 def _list(
     cls,
     boto_list_method,
     list_item_factory,
     boto_list_items_name,
     boto_next_token_name="NextToken",
     sagemaker_session=None,
     **kwargs
 ):
     """List objects from the SageMaker API.

     Calls the ``sagemaker_client`` method named ``boto_list_method`` page by
     page, following the pagination token until a response no longer carries
     one.

     Args:
         boto_list_method (str): Name of the ``sagemaker_client`` method to call.
         list_item_factory (callable): Called on each entry of the response's
             ``boto_list_items_name`` list; its return value is yielded.
         boto_list_items_name (str): Response key holding the page's items.
         boto_next_token_name (str): Key of the pagination token in both the
             request and the response. Defaults to ``"NextToken"``.
         sagemaker_session: Session whose ``sagemaker_client`` is used. If not
             provided, ``_utils.default_session()`` is used.
         **kwargs: Request parameters, converted to boto request fields with
             ``_boto_functions.to_boto`` using ``cls._custom_boto_names`` and
             ``cls._custom_boto_types``.

     Yields:
         The objects produced by ``list_item_factory`` for every listed item.
     """
     sagemaker_session = sagemaker_session or _utils.default_session()
     sagemaker_client = sagemaker_session.sagemaker_client
     next_token = None
     try:
         # Neither the request fields nor the client method depend on the page,
         # so resolve them once; only the pagination token changes per call.
         # The dict is unpacked with ``**`` for each call, so the client never
         # holds it.
         list_request_kwargs = _boto_functions.to_boto(
             kwargs, cls._custom_boto_names, cls._custom_boto_types
         )
         list_method = getattr(sagemaker_client, boto_list_method)

         while True:
             if next_token:
                 list_request_kwargs[boto_next_token_name] = next_token
             list_method_response = list_method(**list_request_kwargs)
             list_items = list_method_response.get(boto_list_items_name, [])
             next_token = list_method_response.get(boto_next_token_name)
             for item in list_items:
                 yield list_item_factory(item)
             if not next_token:
                 break
     except StopIteration:
         # A StopIteration raised by the client or the factory ends the listing
         # quietly instead of escaping the generator (which PEP 479 would turn
         # into a RuntimeError).
         return
# Example No. 3
# 0
    def downstream_trials(self, sagemaker_session=None) -> list:
        """Retrieve the names of all trials whose trial components use this artifact.

        Follows every outgoing association of this artifact to its destination
        ARN, looks those trial components up in batches through
        ``smexperiments``, and collects the ``TrialName`` of each of their
        parents.

        Args:
            sagemaker_session (obj): SageMaker session to use. If not provided, a
                default session will be created.

        Returns:
            list[str]: Unique trial names, in first-seen order. Empty when the
            artifact has no outgoing associations.
        """
        # don't specify destination type because for Trial Components it could be one of
        # SageMaker[TrainingJob|ProcessingJob|TransformJob|ExperimentTrialComponent]
        outgoing_associations: Iterator = Association.list(
            source_arn=self.artifact_arn, sagemaker_session=sagemaker_session
        )
        trial_component_arns: list = [
            association.destination_arn for association in outgoing_associations
        ]

        if not trial_component_arns:
            # no outgoing associations for this artifact
            return []

        get_module("smexperiments")
        from smexperiments import trial_component, search_expression

        # Maximum number of ARNs placed into a single search expression.
        max_search_by_arn: int = 60

        sagemaker_session = sagemaker_session or _utils.default_session()
        sagemaker_client = sagemaker_session.sagemaker_client

        trial_components: list = []
        # Walk the ARNs in consecutive slices of at most max_search_by_arn, so
        # every ARN is searched exactly once and no empty batch is sent.
        for start in range(0, len(trial_component_arns), max_search_by_arn):
            arn_batch: list = trial_component_arns[start:start + max_search_by_arn]
            se: Any = self._get_search_expression(arn_batch, search_expression)
            search_result: Any = trial_component.TrialComponent.search(
                search_expression=se, sagemaker_boto_client=sagemaker_client
            )
            trial_components.extend(search_result)

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # result order is deterministic (unlike iterating a set of strings).
        trial_names = dict.fromkeys(
            parent["TrialName"] for tc in trial_components for parent in tc.parents
        )
        return list(trial_names)
# Example No. 4
# 0
    def _get_trial_from_trial_component(self,
                                        trial_component_arns: list) -> List:
        """Retrieve all upstream trial runs which that use the trial component arns.

        Args:
            trial_component_arns (list): list of trial component arns

        Returns:
            [Trial]: A list of SageMaker `Trial` objects.
        """
        if not trial_component_arns:
            # no outgoing associations for this artifact
            return []

        get_module("smexperiments")
        from smexperiments import trial_component, search_expression

        max_search_by_arn: int = 60
        num_search_batches = math.ceil(
            len(trial_component_arns) % max_search_by_arn)
        trial_components: list = []

        sagemaker_session = self.sagemaker_session or _utils.default_session()
        sagemaker_client = sagemaker_session.sagemaker_client

        for i in range(num_search_batches):
            start: int = i * max_search_by_arn
            end: int = start + max_search_by_arn
            arn_batch: list = trial_component_arns[start:end]
            se: Any = self._get_search_expression(arn_batch, search_expression)
            search_result: Any = trial_component.TrialComponent.search(
                search_expression=se, sagemaker_boto_client=sagemaker_client)

            trial_components: list = trial_components + list(search_result)

        trials: set = set()

        for tc in list(trial_components):
            for parent in tc.parents:
                trials.add(parent["TrialName"])

        return list(trials)
 def _construct(cls, boto_method_name, sagemaker_session=None, **kwargs):
     """Create and invoke a SageMaker API call request."""
     sagemaker_session = sagemaker_session or _utils.default_session()
     instance = cls(sagemaker_session, **kwargs)
     return instance._invoke_api(boto_method_name, kwargs)