def _search(
    cls,
    search_resource,
    search_item_factory,
    boto_next_token_name="NextToken",
    sagemaker_session=None,
    **kwargs
):
    """Search for objects with the SageMaker API.

    Pages through the SageMaker client's ``search`` operation. It keeps
    following the pagination token until a response no longer carries one.
    It yields an object for every result that has an entry keyed by this
    class's name.

    Args:
        search_resource: Value placed in the ``Resource`` field of every request.
        search_item_factory (callable): Called with ``result[cls.__name__]`` to
            build each yielded object.
        boto_next_token_name (str): Request/response field holding the
            pagination token. Defaults to ``"NextToken"``.
        sagemaker_session: Session whose ``sagemaker_client`` is used. A default
            session is created when this is not provided.
        **kwargs: Search parameters. They are converted to request fields with
            ``_boto_functions.to_boto`` using the class's custom boto names/types.

    Yields:
        The objects returned by ``search_item_factory``.
    """
    session = sagemaker_session or _utils.default_session()
    client = session.sagemaker_client
    page_token = None
    keep_paging = True
    try:
        while keep_paging:
            # The request is rebuilt from kwargs on every page; only the token differs.
            request = _boto_functions.to_boto(
                kwargs, cls._custom_boto_names, cls._custom_boto_types
            )
            request["Resource"] = search_resource
            if page_token:
                request[boto_next_token_name] = page_token

            response = client.search(**request)
            results = response.get("Results", [])
            page_token = response.get(boto_next_token_name)

            for result in results:
                if cls.__name__ not in result:
                    # Skip results that carry no entry for this class.
                    continue
                yield search_item_factory(result[cls.__name__])

            # A missing/empty token means this was the last page.
            keep_paging = bool(page_token)
    except StopIteration:
        # Any StopIteration raised inside the loop ends the generator quietly.
        return
def _list(
    cls,
    boto_list_method,
    list_item_factory,
    boto_list_items_name,
    boto_next_token_name="NextToken",
    sagemaker_session=None,
    **kwargs
):
    """List objects from the SageMaker API.

    Calls the SageMaker client method named ``boto_list_method`` page by
    page. It follows the pagination token until a response no longer returns
    one, and yields one object per listed item.

    Args:
        boto_list_method (str): Name of the client method to call.
        list_item_factory (callable): Called with each listed item to build the
            yielded object.
        boto_list_items_name (str): Response field that holds the page's items.
        boto_next_token_name (str): Request/response field holding the
            pagination token. Defaults to ``"NextToken"``.
        sagemaker_session: Session whose ``sagemaker_client`` is used. A default
            session is created when this is not provided.
        **kwargs: List parameters. They are converted to request fields with
            ``_boto_functions.to_boto`` using the class's custom boto names/types.

    Yields:
        The objects returned by ``list_item_factory``.
    """
    session = sagemaker_session or _utils.default_session()
    client = session.sagemaker_client
    page_token = None
    keep_paging = True
    try:
        while keep_paging:
            # The request is rebuilt from kwargs on every page; only the token differs.
            request = _boto_functions.to_boto(
                kwargs, cls._custom_boto_names, cls._custom_boto_types
            )
            if page_token:
                request[boto_next_token_name] = page_token

            call_api = getattr(client, boto_list_method)
            response = call_api(**request)
            entries = response.get(boto_list_items_name, [])
            page_token = response.get(boto_next_token_name)

            for entry in entries:
                yield list_item_factory(entry)

            # A missing/empty token means this was the last page.
            keep_paging = bool(page_token)
    except StopIteration:
        # Any StopIteration raised inside the loop ends the generator quietly.
        return
def downstream_trials(self, sagemaker_session=None) -> list:
    """Retrieve the trials whose trial components use this artifact.

    The lookup proceeds in three steps:

    1. Follow this artifact's outgoing associations to their destination ARNs.
    2. Search those trial components in batches of at most 60 ARNs.
    3. Collect the ``TrialName`` of every parent listed on the matching
       trial components.

    Args:
        sagemaker_session (obj): SageMaker session to use. If not provided a
            default session will be created.

    Returns:
        list: The distinct ``TrialName`` values taken from the ``parents`` of the
        matching trial components. They are gathered in a set, so the order is
        not guaranteed. Returns an empty list when the artifact has no outgoing
        associations.
    """
    # Don't specify the destination type, because for trial components it could be one of
    # SageMaker[TrainingJob|ProcessingJob|TransformJob|ExperimentTrialComponent].
    outgoing_associations: Iterator = Association.list(
        source_arn=self.artifact_arn, sagemaker_session=sagemaker_session
    )
    trial_component_arns: list = [
        association.destination_arn for association in outgoing_associations
    ]
    if not trial_component_arns:
        # No outgoing associations for this artifact.
        return []

    get_module("smexperiments")
    from smexperiments import trial_component, search_expression

    # Maximum number of ARNs placed into a single search expression.
    max_search_by_arn: int = 60

    sagemaker_session = sagemaker_session or _utils.default_session()
    sagemaker_client = sagemaker_session.sagemaker_client

    trial_components: list = []
    # Step through the ARNs in fixed-size slices. This covers every ARN exactly
    # once, including when the count is an exact multiple of the batch size,
    # and never searches an empty batch.
    for start in range(0, len(trial_component_arns), max_search_by_arn):
        arn_batch: list = trial_component_arns[start : start + max_search_by_arn]
        se: Any = self._get_search_expression(arn_batch, search_expression)
        search_result: Any = trial_component.TrialComponent.search(
            search_expression=se, sagemaker_boto_client=sagemaker_client
        )
        trial_components.extend(search_result)

    trials: set = set()
    for tc in trial_components:
        for parent in tc.parents:
            trials.add(parent["TrialName"])
    return list(trials)
def _get_trial_from_trial_component(self, trial_component_arns: list) -> List: """Retrieve all upstream trial runs which that use the trial component arns. Args: trial_component_arns (list): list of trial component arns Returns: [Trial]: A list of SageMaker `Trial` objects. """ if not trial_component_arns: # no outgoing associations for this artifact return [] get_module("smexperiments") from smexperiments import trial_component, search_expression max_search_by_arn: int = 60 num_search_batches = math.ceil( len(trial_component_arns) % max_search_by_arn) trial_components: list = [] sagemaker_session = self.sagemaker_session or _utils.default_session() sagemaker_client = sagemaker_session.sagemaker_client for i in range(num_search_batches): start: int = i * max_search_by_arn end: int = start + max_search_by_arn arn_batch: list = trial_component_arns[start:end] se: Any = self._get_search_expression(arn_batch, search_expression) search_result: Any = trial_component.TrialComponent.search( search_expression=se, sagemaker_boto_client=sagemaker_client) trial_components: list = trial_components + list(search_result) trials: set = set() for tc in list(trial_components): for parent in tc.parents: trials.add(parent["TrialName"]) return list(trials)
def _construct(cls, boto_method_name, sagemaker_session=None, **kwargs): """Create and invoke a SageMaker API call request.""" sagemaker_session = sagemaker_session or _utils.default_session() instance = cls(sagemaker_session, **kwargs) return instance._invoke_api(boto_method_name, kwargs)