コード例 #1
0
def test_file_metrics_writer_context_manager(timestamp, filepath):
    """Metrics logged inside a ``with`` block end up as a JSON line in ``filepath``."""
    with metrics.SageMakerFileMetricsWriter(filepath) as writer:
        writer.log_metric('foo', value=1.0, timestamp=timestamp)
    # Read through a context manager so the handle is closed deterministically
    # instead of being left for the garbage collector.
    with open(filepath, 'r') as metrics_file:
        entry = json.loads(metrics_file.read().strip())
    # Subset check: the entry must contain these pairs but may carry extra keys.
    assert {
        'MetricName': 'foo',
        'Value': 1.0,
        'Timestamp': timestamp.timestamp()
    }.items() <= entry.items()
コード例 #2
0
def test_file_metrics_writer_context_manager(timestamp, filepath):
    """Metrics logged inside a ``with`` block end up as a JSON line in ``filepath``."""
    with metrics.SageMakerFileMetricsWriter(filepath) as writer:
        writer.log_metric("foo", value=1.0, timestamp=timestamp)
    # Read through a context manager so the handle is closed deterministically
    # instead of being left for the garbage collector.
    with open(filepath, "r") as metrics_file:
        entry = json.loads(metrics_file.read().strip())
    # Subset check: the entry must contain these pairs but may carry extra keys.
    assert {
        "MetricName": "foo",
        "Value": 1.0,
        "Timestamp": timestamp.timestamp()
    }.items() <= entry.items()
コード例 #3
0
    def create(
        cls,
        display_name=None,
        artifact_bucket=None,
        artifact_prefix=None,
        boto3_session=None,
        sagemaker_boto_client=None,
    ):
        """Create a new ``Tracker`` by creating a new trial component.

        Examples:
            .. code-block:: python

                from smexperiments import tracker

                my_tracker = tracker.Tracker.create()

        Args:
            display_name: (str, optional) The display name of the trial component to track.
            artifact_bucket: (str, optional) The name of the S3 bucket to store artifacts to.
            artifact_prefix: (str, optional) The prefix to write artifacts to within ``artifact_bucket``
            boto3_session: (boto3.Session, optional) The boto3.Session handed to the artifact uploader.
                If not specified, one is obtained from ``_utils.boto_session()``.
            sagemaker_boto_client: (boto3.Client, optional) The SageMaker AWS service client used to create
                the trial component. If not specified, one is obtained from ``_utils.sagemaker_client()``;
                note that this default client is NOT built from ``boto3_session``.

        Returns:
            Tracker: The tracker for the new trial component, which logs metrics through a
            ``metrics.SageMakerFileMetricsWriter``.
        """
        boto3_session = boto3_session or _utils.boto_session()
        # NOTE(review): the default client ignores ``boto3_session``; if a caller-supplied session should
        # also configure the SageMaker client, thread it through here -- TODO confirm the signature of
        # ``_utils.sagemaker_client`` before changing this call.
        sagemaker_boto_client = sagemaker_boto_client or _utils.sagemaker_client()

        # The trial component name is generated by ``_utils.name``; only the display name is caller-chosen.
        tc = trial_component.TrialComponent.create(
            trial_component_name=_utils.name("TrialComponent"),
            display_name=display_name,
            sagemaker_boto_client=sagemaker_boto_client,
        )

        metrics_writer = metrics.SageMakerFileMetricsWriter()

        return cls(
            tc,
            metrics_writer,
            _ArtifactUploader(tc.trial_component_name, artifact_bucket, artifact_prefix, boto3_session),
        )
コード例 #4
0
def test_file_metrics_writer_log_metric(timestamp, filepath):
    """Each ``log_metric`` call writes one JSON line carrying the logged fields."""
    now = datetime.datetime.now(datetime.timezone.utc)
    writer = metrics.SageMakerFileMetricsWriter(filepath)
    writer.log_metric(metric_name="foo", value=1.0)
    writer.log_metric(metric_name="foo", value=2.0, iteration_number=1)
    writer.log_metric(metric_name="foo", value=3.0, timestamp=timestamp)
    writer.log_metric(metric_name="foo", value=4.0, timestamp=timestamp, iteration_number=2)
    writer.close()

    # Close the read handle deterministically instead of leaking it.
    with open(filepath) as metrics_file:
        lines = [x for x in metrics_file.read().split("\n") if x]
    [entry_one, entry_two, entry_three, entry_four] = [json.loads(line) for line in lines]

    assert "foo" == entry_one["MetricName"]
    assert 1.0 == entry_one["Value"]
    # An entry logged without an explicit timestamp is expected to be stamped near ``now``.
    # Without abs() the check is vacuous for any timestamp at or after ``now`` (the
    # difference is negative), so a wildly-future timestamp would slip through.
    assert abs(now.timestamp() - entry_one["Timestamp"]) < 1
    assert "IterationNumber" not in entry_one

    assert 2.0 == entry_two["Value"]
    assert 1 == entry_two["IterationNumber"]

    assert 3.0 == entry_three["Value"]
    assert timestamp.timestamp() == entry_three["Timestamp"]

    assert 4.0 == entry_four["Value"]
    assert timestamp.timestamp() == entry_four["Timestamp"]
    assert 2 == entry_four["IterationNumber"]
コード例 #5
0
def test_file_metrics_writer_flushes_buffer_every_line_log_metric(filepath):
    """Every ``log_metric`` call is visible on disk before the writer is closed."""

    def read_entries():
        # Re-open the file on every call so we observe exactly what has been flushed so far,
        # and close the handle each time instead of leaking it.
        with open(filepath) as metrics_file:
            return [json.loads(line) for line in metrics_file.read().split("\n") if line]

    writer = metrics.SageMakerFileMetricsWriter(filepath)
    try:
        writer.log_metric(metric_name="foo", value=1.0)
        # List unpacking also asserts the exact number of entries written so far.
        [entry_one] = read_entries()
        assert "foo" == entry_one["MetricName"]
        assert 1.0 == entry_one["Value"]

        writer.log_metric(metric_name="bar", value=2.0)
        [_, entry_two] = read_entries()
        assert "bar" == entry_two["MetricName"]
        assert 2.0 == entry_two["Value"]

        writer.log_metric(metric_name="biz", value=3.0)
        [_, _, entry_three] = read_entries()
        assert "biz" == entry_three["MetricName"]
        assert 3.0 == entry_three["Value"]
    finally:
        # Close even when an assertion fails so the writer never leaks.
        writer.close()
コード例 #6
0
    def load(
        cls,
        trial_component_name=None,
        artifact_bucket=None,
        artifact_prefix=None,
        boto3_session=None,
        sagemaker_boto_client=None,
    ):
        """Create a new ``Tracker`` by loading an existing trial component.

        Examples:
            .. code-block:: python

                from smexperiments import tracker

                my_tracker = tracker.Tracker.load(trial_component_name='xgboost')

        Args:
            trial_component_name: (str, optional) The name of the trial component to track. If specified, this
                trial component must exist in SageMaker. If you invoke this method in a running SageMaker training
                or processing job, then trial_component_name can be left empty. In this case, the Tracker will
                resolve the trial component automatically created for your SageMaker Job.
            artifact_bucket: (str, optional) The name of the S3 bucket to store artifacts to.
            artifact_prefix: (str, optional) The prefix to write artifacts to within ``artifact_bucket``
            boto3_session: (boto3.Session, optional) The boto3.Session handed to the artifact uploader.
                If not specified, one is obtained from ``_utils.boto_session()``.
            sagemaker_boto_client: (boto3.Client, optional) The SageMaker AWS service client used to load the
                trial component. If not specified, one is obtained from ``_utils.sagemaker_client()``;
                note that this default client is NOT built from ``boto3_session``.

        Returns:
            Tracker: The tracker for the given trial component. Metrics are written through a
            ``metrics.SageMakerFileMetricsWriter`` only when ``trial_component_name`` is omitted and the
            environment is a SageMaker training job; otherwise the tracker has no metrics writer.

        Raises:
            ValueError: If ``trial_component_name`` is not given and no trial component environment
                could be loaded.
        """
        boto3_session = boto3_session or _utils.boto_session()
        # NOTE(review): the default client ignores ``boto3_session``; if a caller-supplied session should
        # also configure the SageMaker client, thread it through here -- TODO confirm the signature of
        # ``_utils.sagemaker_client`` before changing this call.
        sagemaker_boto_client = sagemaker_boto_client or _utils.sagemaker_client()

        tce = _environment.TrialComponentEnvironment.load()

        # Resolve the trial component for this tracker to track: If a trial component name was passed in, then load
        # and track that trial component. Otherwise, try to find a trial component given the current environment,
        # failing if we're unable to load one.
        if trial_component_name:
            tc = trial_component.TrialComponent.load(
                trial_component_name=trial_component_name,
                sagemaker_boto_client=sagemaker_boto_client,
            )
        elif tce:
            tc = tce.get_trial_component(sagemaker_boto_client)
        else:
            raise ValueError(
                'Could not load TrialComponent. Specify a trial_component_name or invoke "create"'
            )

        # Write metrics to file only when tracking the trial component resolved from a SageMaker training job
        # environment. The explicit ``tce`` check makes it obvious that ``tce`` is never dereferenced when falsy.
        if not trial_component_name and tce and tce.environment_type == _environment.EnvironmentType.SageMakerTrainingJob:
            metrics_writer = metrics.SageMakerFileMetricsWriter()
        else:
            metrics_writer = None

        tracker = cls(
            tc,
            metrics_writer,
            _ArtifactUploader(tc.trial_component_name, artifact_bucket, artifact_prefix, boto3_session),
        )
        # True whenever a trial component environment was detected, even if an explicit name was passed.
        tracker._in_sagemaker_job = bool(tce)
        return tracker
コード例 #7
0
def test_file_metrics_writer_no_write(filepath):
    """Closing a writer that never logged a metric must not create the output file."""
    unused_writer = metrics.SageMakerFileMetricsWriter(filepath)
    unused_writer.close()

    file_was_created = os.path.exists(filepath)
    assert not file_was_created
コード例 #8
0
def test_file_metrics_writer_fail_write_on_close(filepath):
    """Logging through a writer after ``close()`` raises ``SageMakerMetricsWriterException``."""
    closed_writer = metrics.SageMakerFileMetricsWriter(filepath)
    closed_writer.log_metric(metric_name='foo', value=1.0)
    closed_writer.close()

    expected_error = metrics.SageMakerMetricsWriterException
    with pytest.raises(expected_error):
        closed_writer.log_metric(metric_name='foo', value=1.0)