def _search(
    cls,
    search_resource,
    search_item_factory,
    boto_next_token_name="NextToken",
    sagemaker_boto_client=None,
    **kwargs
):
    """Yield objects built from the paginated SageMaker ``search`` API.

    Args:
        search_resource: Value sent as the ``Resource`` field of every search request.
        search_item_factory: Callable applied to the ``cls.__name__`` entry of each
            search result; its return value is what this generator yields.
        boto_next_token_name: Key under which the pagination token is read from each
            response and sent back on the following request.
        sagemaker_boto_client: SageMaker client to call. When falsy,
            ``_utils.sagemaker_client()`` is used instead.
        **kwargs: Search parameters, converted on every request with
            ``_boto_functions.to_boto`` using ``cls._custom_boto_names`` and
            ``cls._custom_boto_types``.

    Yields:
        One ``search_item_factory(...)`` result per search result that contains a
        ``cls.__name__`` key; results without that key are skipped.
    """
    client = sagemaker_boto_client or _utils.sagemaker_client()
    token = None
    try:
        while True:
            request = _boto_functions.to_boto(kwargs, cls._custom_boto_names, cls._custom_boto_types)
            request["Resource"] = search_resource
            if token:
                request[boto_next_token_name] = token
            response = client.search(**request)
            results = response.get("Results", [])
            token = response.get(boto_next_token_name)
            for result in results:
                if cls.__name__ in result:
                    yield search_item_factory(result[cls.__name__])
            # No (or an empty) token means the last page has been consumed.
            if not token:
                return
    except StopIteration:
        # A StopIteration raised inside the loop (e.g. by the factory) ends the
        # iteration silently rather than surfacing as an error.
        return
def _list(
    cls,
    boto_list_method,
    list_item_factory,
    boto_list_items_name,
    boto_next_token_name="NextToken",
    sagemaker_boto_client=None,
    **kwargs
):
    """Yield objects built from a paginated SageMaker ``list_*`` API.

    Args:
        boto_list_method: Name of the client method to call (looked up with ``getattr``).
        list_item_factory: Callable applied to every item of each response page; its
            return value is what this generator yields.
        boto_list_items_name: Response key that holds the page of items.
        boto_next_token_name: Key under which the pagination token is read from each
            response and sent back on the following request.
        sagemaker_boto_client: SageMaker client to call. When falsy,
            ``_utils.sagemaker_client()`` is used instead.
        **kwargs: Request parameters, converted on every request with
            ``_boto_functions.to_boto`` using ``cls._custom_boto_names`` and
            ``cls._custom_boto_types``.

    Yields:
        ``list_item_factory(item)`` for every item across all pages, in order.
    """
    client = sagemaker_boto_client or _utils.sagemaker_client()
    token = None
    try:
        while True:
            request = _boto_functions.to_boto(kwargs, cls._custom_boto_names, cls._custom_boto_types)
            if token:
                request[boto_next_token_name] = token
            response = getattr(client, boto_list_method)(**request)
            page = response.get(boto_list_items_name, [])
            token = response.get(boto_next_token_name)
            # A plain loop (not ``yield from map(...)``) keeps a StopIteration raised
            # by the factory flowing into the handler below.
            for entry in page:
                yield list_item_factory(entry)
            if not token:
                return
    except StopIteration:
        # A StopIteration raised inside the loop ends the iteration silently.
        return
def _construct(cls, boto_method_name, sagemaker_boto_client=None, **kwargs):
    """Instantiate ``cls`` and invoke a boto API method through the new instance.

    Args:
        boto_method_name: Name of the API method handed to ``_invoke_api``.
        sagemaker_boto_client: SageMaker client passed to the ``cls`` constructor.
            When falsy, ``_utils.sagemaker_client()`` is used instead.
        **kwargs: Forwarded both as constructor keyword arguments and, as a dict,
            to ``_invoke_api``.

    Returns:
        Whatever ``instance._invoke_api(boto_method_name, kwargs)`` returns.
    """
    client = sagemaker_boto_client if sagemaker_boto_client else _utils.sagemaker_client()
    return cls(client, **kwargs)._invoke_api(boto_method_name, kwargs)
def test_sagemaker_client_endpoint_env_set():
    """``_utils.sagemaker_client`` uses the endpoint from ``SAGEMAKER_ENDPOINT``.

    The variable's previous state is restored afterwards, even if the assertion
    fails. If it was unset before the test, it is removed again, so later tests do
    not inherit the bogus endpoint.
    """
    previous = os.environ.get("SAGEMAKER_ENDPOINT")
    os.environ["SAGEMAKER_ENDPOINT"] = "https://notexist.amazon.com"
    try:
        client = _utils.sagemaker_client()
        assert client._endpoint.host == "https://notexist.amazon.com"
    finally:
        if previous is None:
            os.environ.pop("SAGEMAKER_ENDPOINT", None)
        else:
            os.environ["SAGEMAKER_ENDPOINT"] = previous
def create(
    cls,
    display_name=None,
    artifact_bucket=None,
    artifact_prefix=None,
    boto3_session=None,
    sagemaker_boto_client=None,
):
    """Create a new ``Tracker`` by creating a new trial component.

    Examples:
        .. code-block:: python

            from smexperiments import tracker

            my_tracker = tracker.Tracker.create()

    Args:
        display_name: (str, optional). The display name of the trial component to track.
        artifact_bucket: (str, optional) The name of the S3 bucket to store artifacts to.
        artifact_prefix: (str, optional) The prefix to write artifacts to within ``artifact_bucket``
        boto3_session: (boto3.Session, optional) The boto3.Session handed to the artifact
            uploader. If not specified a new default boto3 session will be created.
        sagemaker_boto_client: (boto3.Client, optional) The SageMaker AWS service client
            used to create the trial component. If not specified a new client is created
            with ``_utils.sagemaker_client()``; ``boto3_session`` is not passed to it.

    Returns:
        Tracker: The tracker for the new trial component, writing metrics through a
        ``SageMakerFileMetricsWriter``.
    """
    boto3_session = boto3_session or _utils.boto_session()
    # NOTE(review): boto3_session is not forwarded to the default client, so the
    # SageMaker client and the artifact uploader may use different sessions —
    # confirm whether _utils.sagemaker_client should accept a session.
    sagemaker_boto_client = sagemaker_boto_client or _utils.sagemaker_client()

    tc = trial_component.TrialComponent.create(
        trial_component_name=_utils.name("TrialComponent"),
        display_name=display_name,
        sagemaker_boto_client=sagemaker_boto_client,
    )
    metrics_writer = metrics.SageMakerFileMetricsWriter()
    return cls(
        tc,
        metrics_writer,
        _ArtifactUploader(tc.trial_component_name, artifact_bucket, artifact_prefix, boto3_session),
    )
def load(
    cls,
    trial_component_name=None,
    artifact_bucket=None,
    artifact_prefix=None,
    boto3_session=None,
    sagemaker_boto_client=None,
):
    """Create a new ``Tracker`` by loading an existing trial component.

    Examples:
        .. code-block:: python

            from smexperiments import tracker

            my_tracker = tracker.Tracker.load(trial_component_name='xgboost')

    Args:
        trial_component_name: (str, optional). The name of the trial component to track. If specified, this
            trial component must exist in SageMaker. If you invoke this method in a running SageMaker training
            or processing job, then trial_component_name can be left empty. In this case, the Tracker will
            resolve the trial component automatically created for your SageMaker Job.
        artifact_bucket: (str, optional) The name of the S3 bucket to store artifacts to.
        artifact_prefix: (str, optional) The prefix to write artifacts to within ``artifact_bucket``
        boto3_session: (boto3.Session, optional) The boto3.Session handed to the artifact
            uploader. If not specified a new default boto3 session will be created.
        sagemaker_boto_client: (boto3.Client, optional) The SageMaker AWS service client
            used to load the trial component. If not specified a new client is created
            with ``_utils.sagemaker_client()``; ``boto3_session`` is not passed to it.

    Returns:
        Tracker: The tracker for the given trial component. Its ``_in_sagemaker_job``
        flag is set when a trial component environment was detected.

    Raises:
        ValueError: If ``trial_component_name`` is not given and no trial component
            can be resolved from the current environment.
    """
    boto3_session = boto3_session or _utils.boto_session()
    sagemaker_boto_client = sagemaker_boto_client or _utils.sagemaker_client()

    tce = _environment.TrialComponentEnvironment.load()

    # Resolve the trial component for this tracker to track: If a trial component name was passed in, then load
    # and track that trial component. Otherwise, try to find a trial component given the current environment,
    # failing if we're unable to load one.
    if trial_component_name:
        tc = trial_component.TrialComponent.load(
            trial_component_name=trial_component_name, sagemaker_boto_client=sagemaker_boto_client
        )
    elif tce:
        tc = tce.get_trial_component(sagemaker_boto_client)
    else:
        raise ValueError('Could not load TrialComponent. Specify a trial_component_name or invoke "create"')

    # Write metrics to file only when the trial component was resolved from the
    # environment and that environment is a SageMaker training job. ``tce`` is
    # guaranteed truthy here whenever ``trial_component_name`` is falsy (else the
    # ValueError above was raised).
    if not trial_component_name and tce.environment_type == _environment.EnvironmentType.SageMakerTrainingJob:
        metrics_writer = metrics.SageMakerFileMetricsWriter()
    else:
        metrics_writer = None

    tracker = cls(
        tc,
        metrics_writer,
        _ArtifactUploader(tc.trial_component_name, artifact_bucket, artifact_prefix, boto3_session),
    )
    tracker._in_sagemaker_job = bool(tce)
    return tracker