Example #1
0
 def _search(cls,
             search_resource,
             search_item_factory,
             boto_next_token_name="NextToken",
             sagemaker_boto_client=None,
             **kwargs):
     """Lazily yield objects built from paginated ``search`` responses.

     Each loop iteration converts ``kwargs`` with ``_boto_functions.to_boto``,
     adds the ``"Resource"`` entry, calls ``sagemaker_boto_client.search`` and
     yields one object per result that contains a key equal to
     ``cls.__name__``. Paging continues for as long as the response carries a
     truthy continuation token.

     Args:
         search_resource: Value sent as the ``"Resource"`` request field.
         search_item_factory: Callable applied to ``item[cls.__name__]`` for
             each matching result; its return value is yielded.
         boto_next_token_name: Key used both to send and to read the
             continuation token. Defaults to ``"NextToken"``.
         sagemaker_boto_client: Client whose ``search`` method is called. When
             falsy, one is obtained from ``_utils.sagemaker_client()``.
         **kwargs: Extra request arguments, passed through
             ``_boto_functions.to_boto`` before every call.

     Yields:
         Whatever ``search_item_factory`` returns for each matching result.
     """
     if not sagemaker_boto_client:
         sagemaker_boto_client = _utils.sagemaker_client()

     page_token = None
     try:
         has_more_pages = True
         while has_more_pages:
             request = _boto_functions.to_boto(
                 kwargs, cls._custom_boto_names, cls._custom_boto_types)
             request["Resource"] = search_resource
             if page_token:
                 request[boto_next_token_name] = page_token

             response = sagemaker_boto_client.search(**request)
             results = response.get("Results", [])
             # Read the token before yielding so the paging state is already
             # updated when the consumer resumes the generator.
             page_token = response.get(boto_next_token_name)

             for entry in results:
                 # Results that lack a key named after this class are skipped.
                 if cls.__name__ not in entry:
                     continue
                 yield search_item_factory(entry[cls.__name__])

             has_more_pages = bool(page_token)
     except StopIteration:
         # A StopIteration raised inside the loop body ends iteration quietly.
         return
Example #2
0
 def _list(cls,
           boto_list_method,
           list_item_factory,
           boto_list_items_name,
           boto_next_token_name="NextToken",
           sagemaker_boto_client=None,
           **kwargs):
     """Lazily yield objects built from a paginated list call.

     Each loop iteration converts ``kwargs`` with ``_boto_functions.to_boto``,
     calls the client method named ``boto_list_method`` and yields one object
     per entry found under ``boto_list_items_name`` in the response. Paging
     continues for as long as the response carries a truthy continuation
     token.

     Args:
         boto_list_method: Name of the method to look up on
             ``sagemaker_boto_client`` and call.
         list_item_factory: Callable applied to each listed entry; its return
             value is yielded.
         boto_list_items_name: Response key holding the entries. A missing key
             is treated as an empty page.
         boto_next_token_name: Key used both to send and to read the
             continuation token. Defaults to ``"NextToken"``.
         sagemaker_boto_client: Client to call. When falsy, one is obtained
             from ``_utils.sagemaker_client()``.
         **kwargs: Extra request arguments, passed through
             ``_boto_functions.to_boto`` before every call.

     Yields:
         Whatever ``list_item_factory`` returns for each listed entry.
     """
     if not sagemaker_boto_client:
         sagemaker_boto_client = _utils.sagemaker_client()

     page_token = None
     try:
         has_more_pages = True
         while has_more_pages:
             request = _boto_functions.to_boto(
                 kwargs, cls._custom_boto_names, cls._custom_boto_types)
             if page_token:
                 request[boto_next_token_name] = page_token

             fetch_page = getattr(sagemaker_boto_client, boto_list_method)
             response = fetch_page(**request)
             entries = response.get(boto_list_items_name, [])
             # Read the token before yielding so the paging state is already
             # updated when the consumer resumes the generator.
             page_token = response.get(boto_next_token_name)

             for entry in entries:
                 yield list_item_factory(entry)

             has_more_pages = bool(page_token)
     except StopIteration:
         # A StopIteration raised inside the loop body ends iteration quietly.
         return
Example #3
0
 def _construct(cls,
                boto_method_name,
                sagemaker_boto_client=None,
                **kwargs):
     """Instantiate ``cls`` and immediately invoke an API method on it.

     Args:
         boto_method_name: Name handed to the new instance's ``_invoke_api``.
         sagemaker_boto_client: Client passed as the first positional argument
             to ``cls``. When falsy, one is obtained from
             ``_utils.sagemaker_client()``.
         **kwargs: Forwarded as keyword arguments to ``cls`` and, as a single
             dict, as the second argument to ``_invoke_api``.

     Returns:
         The value returned by ``_invoke_api(boto_method_name, kwargs)``.
     """
     if not sagemaker_boto_client:
         sagemaker_boto_client = _utils.sagemaker_client()
     return cls(sagemaker_boto_client, **kwargs)._invoke_api(
         boto_method_name, kwargs)
def test_sagemaker_client_endpoint_env_set():
    """Check that ``_utils.sagemaker_client`` honours ``SAGEMAKER_ENDPOINT``.

    The variable is pointed at a fake endpoint for the duration of the test
    and the previous state of the environment is always put back afterwards,
    including when the assertion fails or when the variable was not set at all
    before the test ran.
    """
    endpoint = "https://notexist.amazon.com"
    previous = os.environ.get("SAGEMAKER_ENDPOINT")
    os.environ["SAGEMAKER_ENDPOINT"] = endpoint
    try:
        client = _utils.sagemaker_client()

        assert client._endpoint.host == endpoint
    finally:
        # Restore unconditionally so a failure here cannot leak the fake
        # endpoint into other tests running in the same process.
        if previous is None:
            os.environ.pop("SAGEMAKER_ENDPOINT", None)
        else:
            os.environ["SAGEMAKER_ENDPOINT"] = previous
    def create(
        cls,
        display_name=None,
        artifact_bucket=None,
        artifact_prefix=None,
        boto3_session=None,
        sagemaker_boto_client=None,
    ):
        """Build a ``Tracker`` backed by a freshly created trial component.

        Examples:
            .. code-block:: python

                from smexperiments import tracker

                my_tracker = tracker.Tracker.create()

        Args:
            display_name: (str, optional) Display name given to the new trial component.
            artifact_bucket: (str, optional) Name of the S3 bucket artifacts are stored in.
            artifact_prefix: (str, optional) Prefix artifacts are written under within ``artifact_bucket``.
            boto3_session: (boto3.Session, optional) Session handed to the artifact uploader. When not
                given, one is obtained from ``_utils.boto_session()``.
            sagemaker_boto_client: (boto3.Client, optional) SageMaker client used to create the trial
                component. When not given, one is obtained from ``_utils.sagemaker_client()``.

        Returns:
            Tracker: The tracker for the new trial component.
        """
        if not boto3_session:
            boto3_session = _utils.boto_session()
        if not sagemaker_boto_client:
            sagemaker_boto_client = _utils.sagemaker_client()

        # The trial component name is generated by ``_utils.name``.
        new_component = trial_component.TrialComponent.create(
            trial_component_name=_utils.name("TrialComponent"),
            display_name=display_name,
            sagemaker_boto_client=sagemaker_boto_client,
        )

        file_writer = metrics.SageMakerFileMetricsWriter()
        uploader = _ArtifactUploader(
            new_component.trial_component_name,
            artifact_bucket,
            artifact_prefix,
            boto3_session,
        )

        return cls(new_component, file_writer, uploader)
    def load(
        cls,
        trial_component_name=None,
        artifact_bucket=None,
        artifact_prefix=None,
        boto3_session=None,
        sagemaker_boto_client=None,
    ):
        """Build a ``Tracker`` around an already existing trial component.

        Examples:
            .. code-block:: python

                from smexperiments import tracker

                my_tracker = tracker.Tracker.load(trial_component_name='xgboost')

        Args:
            trial_component_name: (str, optional) Name of the trial component to track. When given, that
                trial component is loaded by name. When omitted, the trial component is resolved from the
                environment returned by ``_environment.TrialComponentEnvironment.load()``, e.g. inside a
                running SageMaker training or processing job.
            artifact_bucket: (str, optional) Name of the S3 bucket artifacts are stored in.
            artifact_prefix: (str, optional) Prefix artifacts are written under within ``artifact_bucket``.
            boto3_session: (boto3.Session, optional) Session handed to the artifact uploader. When not
                given, one is obtained from ``_utils.boto_session()``.
            sagemaker_boto_client: (boto3.Client, optional) SageMaker client used to load the trial
                component. When not given, one is obtained from ``_utils.sagemaker_client()``.

        Returns:
            Tracker: The tracker for the given trial component.

        Raises:
            ValueError: If no ``trial_component_name`` was given and no environment could be loaded.
        """
        if not boto3_session:
            boto3_session = _utils.boto_session()
        if not sagemaker_boto_client:
            sagemaker_boto_client = _utils.sagemaker_client()

        job_env = _environment.TrialComponentEnvironment.load()

        # An explicit name wins; otherwise fall back to the detected environment, and give up when
        # neither is available.
        if trial_component_name:
            component = trial_component.TrialComponent.load(
                trial_component_name=trial_component_name,
                sagemaker_boto_client=sagemaker_boto_client,
            )
        elif job_env:
            component = job_env.get_trial_component(sagemaker_boto_client)
        else:
            raise ValueError(
                'Could not load TrialComponent. Specify a trial_component_name or invoke "create"'
            )

        # Metrics are written to file only when the component was resolved from the environment and
        # that environment is a SageMaker training job; in every other case there is no writer.
        use_file_metrics = (
            not trial_component_name
            and job_env.environment_type == _environment.EnvironmentType.SageMakerTrainingJob
        )
        metrics_writer = metrics.SageMakerFileMetricsWriter() if use_file_metrics else None

        uploader = _ArtifactUploader(
            component.trial_component_name,
            artifact_bucket,
            artifact_prefix,
            boto3_session,
        )

        tracker = cls(component, metrics_writer, uploader)
        tracker._in_sagemaker_job = bool(job_env)
        return tracker