def create_experiment(Experiment_name, Experiment_description=None):
    """Load the named experiment, creating it when it does not exist yet.

    Args:
        Experiment_name: Name of the experiment to load or create.
        Experiment_description: Optional description, used only when the
            experiment has to be created.

    Returns:
        The loaded or newly created ``Experiment`` object.

    Raises:
        Exception: Any load failure whose message does not contain
            "ResourceNotFound" is re-raised instead of being dropped, and
            any failure of the create call propagates as well.
    """
    try:
        return Experiment.load(experiment_name=Experiment_name)
    except Exception as ex:
        # A missing experiment is recognised by the error text; everything
        # else is a real failure the caller needs to see.
        if "ResourceNotFound" not in str(ex):
            raise
    return Experiment.create(experiment_name=Experiment_name,
                             description=Experiment_description)
def cleanup_experiment(Experiment_name):
    """Delete an experiment together with its trials and trial components.

    Every trial component is detached from its trial and then deleted on a
    best-effort basis (a component that cannot be deleted is skipped). Each
    trial is deleted after its components, and the experiment last.

    Args:
        Experiment_name: Name of the experiment to remove.

    Returns:
        None. Failures are reported on stdout rather than raised: a
        "ResourceNotFound" error is treated as "nothing to delete", any
        other error is printed so that it is no longer silently lost.
    """
    try:
        experiment = Experiment.load(experiment_name=Experiment_name)
        for trial_summary in experiment.list_trials():
            trial = Trial.load(trial_name=trial_summary.trial_name)
            for tc_summary in trial.list_trial_components():
                tc = TrialComponent.load(
                    trial_component_name=tc_summary.trial_component_name)
                trial.remove_trial_component(tc)
                try:
                    # comment out to keep trial components
                    tc.delete()
                except Exception:
                    # tc is associated with another trial
                    continue
                # to prevent throttling
                time.sleep(.5)
            trial.delete()
        experiment.delete()
    except Exception as ex:
        if 'ResourceNotFound' in str(ex):
            print('%s is a new experiment. Nothing to delete' %
                  Experiment_name)
        else:
            # Keep the non-raising contract, but make the failure visible.
            print('Cleanup of experiment %s failed: %s' %
                  (Experiment_name, ex))
def test_search(sagemaker_boto_client):
    """Search experiments by name fragment and expect at least one match."""
    name_filter = Filter(name="ExperimentName", operator=Operator.CONTAINS, value="smexperiments-integ-")
    matches = Experiment.search(
        search_expression=SearchExpression(filters=[name_filter]),
        max_results=10,
        sagemaker_boto_client=sagemaker_boto_client,
    )
    experiment_names_searched = [match.experiment_name for match in matches]

    # A non-empty list is truthy, so this single check covers both the
    # length test and the sanity test.
    assert experiment_names_searched
Example #4
0
def set_experiment_config(experiment_basename=None):
    """Create a timestamped SageMaker experiment and return its name.

    The experiment is named ``<experiment_basename>-autogluon-<epoch>`` when
    a base name is given and ``autogluon-<epoch>`` otherwise, where
    ``<epoch>`` is the current time in whole seconds.

    Args:
        experiment_basename: Optional prefix for the experiment name.

    Returns:
        The name of the created experiment, or ``''`` when either the
        SageMaker client or the experiment could not be created. In that
        case a hint and the underlying error are printed.
    """
    now = int(time.time())

    if experiment_basename:
        experiment_name = '{}-autogluon-{}'.format(experiment_basename, now)
    else:
        experiment_name = 'autogluon-{}'.format(now)

    # ``except Exception`` (not a bare ``except``) so that KeyboardInterrupt
    # and SystemExit are no longer swallowed by the fallbacks below.
    try:
        client = boto3.Session().client('sagemaker')
    except Exception as err:
        print(
            'You need to install boto3 to create an experiment. Try pip install --upgrade boto3'
        )
        print('Underlying error: {}'.format(err))
        return ''

    try:
        Experiment.create(
            experiment_name=experiment_name,
            description="Running AutoGluon Tabular with SageMaker Experiments",
            sagemaker_boto_client=client)
    except Exception as err:
        print(
            'Could not create the experiment. Is your basename properly configured? Also try installing the sagemaker experiments SDK with pip install sagemaker-experiments.'
        )
        print('Underlying error: {}'.format(err))
        return ''

    print(
        'Created an experiment named {}, you should be able to see this in SageMaker Studio right now.'
        .format(experiment_name))
    return experiment_name
def cleanup_trial(Experiment_name, Trial_name):
    """Delete one trial of an experiment along with its trial components.

    Each trial component of the matching trial is printed, detached from
    the trial and then deleted on a best-effort basis; the trial itself is
    deleted afterwards. Trials with a different name are left untouched.

    Args:
        Experiment_name: Name of the experiment that owns the trial.
        Trial_name: Name of the trial to delete.
    """
    experiment = Experiment.load(experiment_name=Experiment_name)
    for trial_summary in experiment.list_trials():
        # Check the name first so non-matching trials are never loaded.
        if trial_summary.trial_name != Trial_name:
            continue
        trial = Trial.load(trial_name=trial_summary.trial_name)
        for tc_summary in trial.list_trial_components():
            tc = TrialComponent.load(
                trial_component_name=tc_summary.trial_component_name)
            print(tc_summary.trial_component_name)
            trial.remove_trial_component(tc)
            try:
                # comment out to keep trial components
                tc.delete()
            except Exception:
                # tc is associated with another trial
                continue
            # to prevent throttling
            time.sleep(.5)
        trial.delete()
Example #6
0
def _test_training_function(ecr_image, sagemaker_session, instance_type,
                            framework_version, py_version):
    """Train MNIST and check the job is tracked as a trial component.

    Creates a uniquely named experiment and trial, runs ``mnist.py`` as a
    training job with the given image, looks up the trial component whose
    source is that training job, attaches it to the trial and finally
    deletes the component, the trial and the experiment.

    Args:
        ecr_image: Image URI passed to the estimator.
        sagemaker_session: Session providing the SageMaker client and the
            data upload.
        instance_type: Training instance type.
        framework_version: Framework version passed to the estimator.
        py_version: Python version label; the test is skipped when it is
            ``None`` or contains ``'2'``.
    """
    if py_version is None or '2' in py_version:
        # pytest.skip raises, so no explicit return is needed after it.
        pytest.skip('Skipping python2 {}'.format(py_version))

    from smexperiments.experiment import Experiment
    from smexperiments.trial import Trial
    from smexperiments.trial_component import TrialComponent

    sm_client = sagemaker_session.sagemaker_client
    random.seed(f"{datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')}")
    unique_id = random.randint(1, 6000)

    experiment_name = f"tf-container-integ-test-{unique_id}-{int(time.time())}"

    experiment = Experiment.create(
        experiment_name=experiment_name,
        description="Integration test experiment from sagemaker-tf-container",
        sagemaker_boto_client=sm_client,
    )

    trial_name = f"tf-container-integ-test-{unique_id}-{int(time.time())}"
    trial = Trial.create(experiment_name=experiment_name,
                         trial_name=trial_name,
                         sagemaker_boto_client=sm_client)

    training_job_name = utils.unique_name_from_base(
        "test-tf-experiments-mnist")

    # create a training job and wait for it to complete
    with timeout(minutes=15):
        resource_path = os.path.join(os.path.dirname(__file__), "..", "..",
                                     "resources")
        script = os.path.join(resource_path, "mnist", "mnist.py")
        estimator = TensorFlow(
            entry_point=script,
            role="SageMakerRole",
            instance_type=instance_type,
            instance_count=1,
            sagemaker_session=sagemaker_session,
            image_uri=ecr_image,
            framework_version=framework_version,
            script_mode=True,
        )
        inputs = estimator.sagemaker_session.upload_data(
            path=os.path.join(resource_path, "mnist", "data"),
            key_prefix="scriptmode/mnist")
        estimator.fit(inputs, job_name=training_job_name)

    training_job = sm_client.describe_training_job(
        TrainingJobName=training_job_name)
    training_job_arn = training_job["TrainingJobArn"]

    # verify trial component auto created from the training job
    trial_components = list(
        TrialComponent.list(source_arn=training_job_arn,
                            sagemaker_boto_client=sm_client))
    # Fail with a readable message instead of an IndexError when the
    # training job produced no trial component.
    assert trial_components, (
        "no trial component found for training job {}".format(
            training_job_arn))

    trial_component_summary = trial_components[0]
    trial_component = TrialComponent.load(
        trial_component_name=trial_component_summary.trial_component_name,
        sagemaker_boto_client=sm_client,
    )

    # associate the trial component with the trial
    trial.add_trial_component(trial_component)

    # cleanup
    trial.remove_trial_component(trial_component_summary.trial_component_name)
    trial_component.delete()
    trial.delete()
    experiment.delete()
Example #7
0
def test_training(sagemaker_session, ecr_image, instance_type, instance_count):
    """Start an MXNet training job and check it is tracked as a trial component.

    Creates an experiment and a trial, uploads the train/test data, starts
    the training job without waiting for it, then polls for the trial
    component whose source is that job. The component is attached to the
    trial, after which component, trial and experiment are deleted.

    Args:
        sagemaker_session: Session providing the SageMaker client and the
            data upload.
        ecr_image: Image name passed to the estimator.
        instance_type: Training instance type.
        instance_count: Number of training instances.
    """

    from smexperiments.experiment import Experiment
    from smexperiments.trial import Trial
    from smexperiments.trial_component import TrialComponent

    sm_client = sagemaker_session.sagemaker_client

    experiment_name = "mxnet-container-integ-test-{}".format(int(time.time()))

    experiment = Experiment.create(
        experiment_name=experiment_name,
        description=
        "Integration test experiment from sagemaker-mxnet-container",
        sagemaker_boto_client=sm_client,
    )

    trial_name = "mxnet-container-integ-test-{}".format(int(time.time()))
    trial = Trial.create(experiment_name=experiment_name,
                         trial_name=trial_name,
                         sagemaker_boto_client=sm_client)

    hyperparameters = {
        "random_seed": True,
        "num_steps": 50,
        "smdebug_path": "/opt/ml/output/tensors",
        "epochs": 1,
    }

    mx = MXNet(
        entry_point=SCRIPT_PATH,
        role="SageMakerRole",
        train_instance_count=instance_count,
        train_instance_type=instance_type,
        sagemaker_session=sagemaker_session,
        image_name=ecr_image,
        hyperparameters=hyperparameters,
    )

    training_job_name = utils.unique_name_from_base("test-mxnet-image")

    # create a training job (fit is called with wait=False)
    with timeout(minutes=15):
        prefix = "mxnet_mnist_gluon_basic_hook_demo/{}".format(
            utils.sagemaker_timestamp())
        train_input = mx.sagemaker_session.upload_data(
            path=os.path.join(DATA_PATH, "train"),
            key_prefix=prefix + "/train")
        test_input = mx.sagemaker_session.upload_data(
            path=os.path.join(DATA_PATH, "test"), key_prefix=prefix + "/test")

        mx.fit({
            "train": train_input,
            "test": test_input
        },
               job_name=training_job_name,
               wait=False)

    training_job = sm_client.describe_training_job(
        TrainingJobName=training_job_name)
    training_job_arn = training_job["TrainingJobArn"]

    # verify trial component auto created from the training job
    # The lookup is bounded: one initial check plus one check after each of
    # ``max_attempts`` sleeps. Previously the loop kept spinning without
    # sleeping once the attempts were used up and could never reach the
    # assertion below.
    max_attempts = 10
    trial_component_summary = None
    for attempt in range(max_attempts + 1):
        trial_components = list(
            TrialComponent.list(source_arn=training_job_arn,
                                sagemaker_boto_client=sm_client))

        if trial_components:
            trial_component_summary = trial_components[0]
            break

        if attempt < max_attempts:
            # NOTE(review): if ``sleep`` is ``time.sleep`` this waits 500
            # seconds per attempt - confirm that seconds (not ms) is meant.
            sleep(500)

    assert trial_component_summary is not None

    trial_component = TrialComponent.load(
        trial_component_name=trial_component_summary.trial_component_name,
        sagemaker_boto_client=sm_client,
    )

    # associate the trial component with the trial
    trial.add_trial_component(trial_component)

    # cleanup
    trial.remove_trial_component(trial_component_summary.trial_component_name)
    trial_component.delete()
    trial.delete()
    experiment.delete()
Example #8
0
def test_training(sagemaker_session, ecr_image, instance_type,
                  framework_version):
    """Train MNIST in script mode and attach the job's trial component to a trial.

    An experiment and a trial are created, the training job is run to
    completion, the first trial component listed for the job's ARN is
    loaded and added to the trial, and then everything is deleted again.
    """

    boto_client = sagemaker_session.sagemaker_client

    experiment_name = "tf-container-integ-test-{}".format(int(time.time()))
    experiment = Experiment.create(
        description="Integration test experiment from sagemaker-tf-container",
        experiment_name=experiment_name,
        sagemaker_boto_client=boto_client,
    )

    trial = Trial.create(
        sagemaker_boto_client=boto_client,
        experiment_name=experiment_name,
        trial_name="tf-container-integ-test-{}".format(int(time.time())),
    )

    job_name = utils.unique_name_from_base("test-tf-experiments-mnist")

    # run the training job to completion inside the timeout guard
    with timeout(minutes=DEFAULT_TIMEOUT):
        resources_dir = os.path.join(os.path.dirname(__file__), "..", "..",
                                     "resources")
        mnist_dir = os.path.join(resources_dir, "mnist")
        estimator_settings = dict(
            entry_point=os.path.join(mnist_dir, "mnist.py"),
            role="SageMakerRole",
            train_instance_type=instance_type,
            train_instance_count=1,
            sagemaker_session=sagemaker_session,
            image_name=ecr_image,
            framework_version=framework_version,
            script_mode=True,
        )
        tf_estimator = TensorFlow(**estimator_settings)
        data_location = tf_estimator.sagemaker_session.upload_data(
            key_prefix="scriptmode/mnist",
            path=os.path.join(mnist_dir, "data"),
        )
        tf_estimator.fit(data_location, job_name=job_name)

    job_description = boto_client.describe_training_job(
        TrainingJobName=job_name)
    job_arn = job_description["TrainingJobArn"]

    # the training job is expected to have produced a trial component
    component_summaries = list(
        TrialComponent.list(sagemaker_boto_client=boto_client,
                            source_arn=job_arn))
    first_summary = component_summaries[0]
    component = TrialComponent.load(
        sagemaker_boto_client=boto_client,
        trial_component_name=first_summary.trial_component_name,
    )

    # link the component to the trial
    trial.add_trial_component(component)

    # tear down in dependency order: link, component, trial, experiment
    trial.remove_trial_component(first_summary.trial_component_name)
    component.delete()
    trial.delete()
    experiment.delete()
Example #9
0
def _build_arg_parser():
    """Build the command-line parser used by ``main``."""
    parser = argparse.ArgumentParser(
        "Creates or updates and runs the pipeline for the pipeline script.")

    parser.add_argument(
        "-n",
        "--module-name",
        dest="module_name",
        type=str,
        help="The module name of the pipeline to import.",
    )
    parser.add_argument(
        "-kwargs",
        "--kwargs",
        dest="kwargs",
        default=None,
        help=
        "Dict string of keyword arguments for the pipeline generation (if supported)",
    )
    parser.add_argument(
        "-role-arn",
        "--role-arn",
        dest="role_arn",
        type=str,
        help="The role arn for the pipeline service execution role.",
    )
    parser.add_argument(
        "-description",
        "--description",
        dest="description",
        type=str,
        default=None,
        help="The description of the pipeline.",
    )
    parser.add_argument(
        "-tags",
        "--tags",
        dest="tags",
        default=None,
        help=
        """List of dict strings of '[{"Key": "string", "Value": "string"}, ..]'""",
    )
    return parser


def _extract_processing_args(definition):
    """Return the container arguments of the 'Processing' step as a dict."""
    processing_param_dict = {}
    for step in definition['Steps']:
        print('step: {}'.format(step))
        if step['Name'] != 'Processing':
            continue
        print('Step Name is Processing...')
        arg_list = step['Arguments']['AppSpecification'][
            'ContainerArguments']
        print(arg_list)
        print(len(arg_list))

        # arguments are (key, value) pairs in this list; zipping the even
        # and odd positions pairs them up and ignores a trailing key that
        # has no value instead of raising IndexError
        for raw_key, value in zip(arg_list[::2], arg_list[1::2]):
            key = raw_key.replace('--', '')
            print('arg key: {}'.format(key))
            print('arg value: {}'.format(value))
            processing_param_dict[key] = value
    return processing_param_dict


def _log_processing_params(trial_component_name, params):
    """Log ``params`` on the named trial component and save it once."""
    processing_job_tracker = tracker.Tracker.load(
        trial_component_name=trial_component_name)

    for key, value in params.items():
        print('key: {}, value: {}'.format(key, value))
        processing_job_tracker.log_parameters({key: str(value)})

    # must save after logging; a single save covers all logged parameters
    if params:
        processing_job_tracker.trial_component.save()


def main():  # pragma: no cover
    """The main harness that creates or updates and runs the pipeline.

    Parses the command line, upserts and starts the pipeline, records the
    execution as a trial of an experiment named after the pipeline, waits
    for the execution to finish, associates the trial components of the
    'Processing' and 'Train' steps with the trial and logs the processing
    step's container arguments on the processing trial component.

    Exits with status 2 (after printing the help) when ``--module-name`` or
    ``--role-arn`` is missing, and with status 1 when any step of the run
    raises.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.module_name is None or args.role_arn is None:
        parser.print_help()
        sys.exit(2)
    tags = convert_struct(args.tags)

    try:
        pipeline = get_pipeline_driver(args.module_name, args.kwargs)
        print(
            "###### Creating/updating a SageMaker Pipeline with the following definition:"
        )
        parsed = json.loads(pipeline.definition())
        print(json.dumps(parsed, indent=2, sort_keys=True))

        upsert_response = pipeline.upsert(role_arn=args.role_arn,
                                          description=args.description,
                                          tags=tags)
        print(
            "\n###### Created/Updated SageMaker Pipeline: Response received:")
        print(upsert_response)

        execution = pipeline.start()
        print(
            f"\n###### Execution started with PipelineExecutionArn: {execution.arn}"
        )

        # Now we describe execution instance and list the steps in the execution to find out more about the execution.
        execution_run = execution.describe()
        print(execution_run)

        # Create or Load the 'Experiment'
        try:
            experiment = Experiment.create(
                experiment_name=pipeline.name,
                description='Amazon Customer Reviews BERT Pipeline Experiment')
        except Exception:
            # Creation failed (presumably because the experiment already
            # exists), so fall back to loading it.
            experiment = Experiment.load(experiment_name=pipeline.name)

        print('Experiment name: {}'.format(experiment.experiment_name))

        # Add Execution Run as Trial to Experiments
        execution_run_name = execution_run['PipelineExecutionDisplayName']
        print(execution_run_name)

        # Create the `Trial`
        trial = Trial.create(trial_name=execution_run_name,
                             experiment_name=experiment.experiment_name,
                             sagemaker_boto_client=sm)

        trial_name = trial.trial_name
        print('Trial name: {}'.format(trial_name))

        ######################################################
        ## Parse Pipeline Definition For Processing Job Args
        ######################################################

        processing_param_dict = _extract_processing_args(parsed)

        ##############################
        ## Wait For Execution To Finish
        ##############################

        print("Waiting for the execution to finish...")
        execution.wait()
        print("\n#####Execution completed. Execution step details:")

        # List Execution Steps (fetched once and reused below)
        execution_steps = execution.list_steps()
        print(execution_steps)

        # List All Artifacts Generated By The Pipeline
        processing_job_name = None
        training_job_name = None

        from sagemaker.lineage.visualizer import LineageTableVisualizer

        # Only needed by the commented-out display calls below; kept so they
        # can be re-enabled in a notebook.
        viz = LineageTableVisualizer(sagemaker.session.Session())
        for execution_step in reversed(execution_steps):
            print(execution_step)
            # We are doing this because there appears to be a bug of this LineageTableVisualizer handling the Processing Step
            if execution_step['StepName'] == 'Processing':
                processing_job_name = execution_step['Metadata'][
                    'ProcessingJob']['Arn'].split('/')[-1]
                print(processing_job_name)
                #display(viz.show(processing_job_name=processing_job_name))
            elif execution_step['StepName'] == 'Train':
                training_job_name = execution_step['Metadata']['TrainingJob'][
                    'Arn'].split('/')[-1]
                print(training_job_name)
                #display(viz.show(training_job_name=training_job_name))
            else:
                #display(viz.show(pipeline_execution_step=execution_step))
                # NOTE(review): presumably paced the display call above;
                # confirm whether the pause is still wanted.
                time.sleep(5)

        # Without both job names the component names below would be built
        # from ``None``, so stop with a clear message instead.
        if processing_job_name is None or training_job_name is None:
            raise RuntimeError(
                "Pipeline execution has no 'Processing' and/or 'Train' step; "
                "cannot associate trial components")

        # Add Trial Compontents To Experiment Trial
        processing_job_tc = '{}-aws-processing-job'.format(processing_job_name)
        print(processing_job_tc)

        # -aws-processing-job is the default name assigned by ProcessingJob
        sm.associate_trial_component(TrialComponentName=processing_job_tc,
                                     TrialName=trial_name)

        # -aws-training-job is the default name assigned by TrainingJob
        training_job_tc = '{}-aws-training-job'.format(training_job_name)
        print(training_job_tc)

        sm.associate_trial_component(TrialComponentName=training_job_tc,
                                     TrialName=trial_name)

        ##############
        # Log Additional Parameters within Trial
        ##############
        print('Logging Processing Job Parameters within Experiment Trial...')
        _log_processing_params(processing_job_tc, processing_param_dict)

    except Exception as e:  # pylint: disable=W0703
        print(f"Exception: {e}")
        sys.exit(1)
Example #10
0
import time, os, sys
import sagemaker, boto3
import numpy as np
import pandas as pd
import itertools
from pprint import pprint

# One boto3 session backs both the low-level SageMaker client (`sm`) and the
# SageMaker SDK session (`sagemaker_session`).
sess = boto3.Session()
sm = sess.client('sagemaker')
# Execution role as reported by the SageMaker SDK.
role = sagemaker.get_execution_role()
sagemaker_session = sagemaker.Session(boto_session=sess)

# Default bucket of the SageMaker session.
bucket_name = sagemaker_session.default_bucket()

from smexperiments.experiment import Experiment
from smexperiments.trial import Trial
from smexperiments.trial_component import TrialComponent
from smexperiments.tracker import Tracker

# Create a new experiment at import time; the current epoch second is part of
# its name, and the `sm` client defined above is used for the call.
training_experiment = Experiment.create(
    experiment_name=f"test-experiment-{int(time.time())}",
    description="This is a Test ",
    sagemaker_boto_client=sm)
Example #11
0
def setup_workflow(project, purpose, workflow_execution_role, script_dir,
                   ecr_repository):
    """Set up everything needed for a Step Functions workflow with SageMaker.

    Uploads the processing script to the session's default bucket, creates
    a timestamped experiment, defines a single script-processing step with
    a failure catch, and creates the workflow.

    Args:
        project: project name under sagemaker
        purpose: subproject
        workflow_execution_role: arn to execute step functions
        script_dir: processing file name, like a .py file
        ecr_repository: ecr repository name

    Returns:
        workflow: a stepfunctions.workflow.Workflow instance

    Example:
        PROJECT = '[dpt-proj-2022]'
        PURPOSE = '[processing]'
        WORKFLOW_EXECUTION_ROLE = "arn:aws-cn:iam::[*********]:role/[**************]"
        SCRIPT_DIR = "[processing].py"
        ECR_REPOSITORY = '[ecr-2022]'
    """
    # Local import: only needed to derive the uploaded script's file name.
    import os

    # SageMaker Session setup
    # ========================================================================================
    # SageMaker Session
    # ====================================
    account_id = boto3.client('sts').get_caller_identity().get('Account')
    role = sagemaker.get_execution_role()

    # Storage
    # ====================================
    session = sagemaker.Session()
    region = session.boto_region_name
    s3_output = session.default_bucket()

    # Code storage
    # ==================
    s3_prefix = '{}/{}'.format(project, purpose)
    s3_prefix_code = '{}/code'.format(s3_prefix)
    s3CodePath = 's3://{}/{}'.format(s3_output, s3_prefix_code)

    ## preprocess & prediction
    script_list = [script_dir]

    for script in script_list:
        session.upload_data(script,
                            bucket=s3_output,
                            key_prefix=s3_prefix_code)

    # ECR environment
    # ====================================
    # China regions ("cn-...") use the amazonaws.com.cn domain; other
    # regions use amazonaws.com. Results for cn- regions are unchanged.
    if region.startswith('cn-'):
        uri_suffix = 'amazonaws.com.cn'
    else:
        uri_suffix = 'amazonaws.com'
    tag = ':latest'
    ecr_repository_uri = '{}.dkr.ecr.{}.{}/{}'.format(account_id, region,
                                                      uri_suffix,
                                                      ecr_repository + tag)

    # SageMaker Experiments setup
    # ========================================================================================
    experiment = Experiment.create(
        experiment_name="{}-{}".format(project, int(time.time())),
        description="machine learning project",
        sagemaker_boto_client=boto3.client('sagemaker'))
    print(experiment)

    execution_input = ExecutionInput(schema={
        "ProcessingJobName": str,
        "ResultPath": str,
    })

    # setup script processor
    script_processor = ScriptProcessor(command=['python3'],
                                       image_uri=ecr_repository_uri,
                                       role=role,
                                       instance_count=1,
                                       instance_type='ml.m5.4xlarge')

    # Step
    # ========================================================================================

    # The upload stores the script under its file name only, so the entry
    # point must not include any local directory part of ``script_dir``.
    # For a plain file name this is identical to using ``script_dir``.
    script_file_name = os.path.basename(script_dir)

    optimizing_step = steps.ProcessingStep(
        "Processing Step",
        processor=script_processor,
        job_name=execution_input["ProcessingJobName"],
        inputs=[
            ProcessingInput(source=s3CodePath,
                            destination='/opt/ml/processing/input/code',
                            input_name='code')
        ],
        outputs=[
            ProcessingOutput(output_name=purpose,
                             destination=execution_input["ResultPath"],
                             source='/opt/ml/processing/{}'.format(purpose))
        ],
        container_entrypoint=[
            "python3", "/opt/ml/processing/input/code/" + script_file_name
        ],
    )

    # Fail State
    # ========================================================================================
    failed_state = steps.states.Fail("Processing Workflow failed",
                                     cause="SageMakerProcessingJobFailed")

    catch_state_processing = steps.states.Catch(
        error_equals=["States.TaskFailed"], next_step=failed_state)

    # Create Workflow
    # ========================================================================================
    optimizing_step.add_catch(catch_state_processing)

    workflow_name = "workflow-{}-{}".format(project, purpose).upper()
    workflow_graph = steps.Chain([optimizing_step])

    workflow = Workflow(name=workflow_name,
                        definition=workflow_graph,
                        role=workflow_execution_role)

    workflow.create()
    return workflow
def _test_training_function(ecr_image, sagemaker_session, instance_type,
                            framework_version):
    """Train MNIST and attach the job's trial component to a fresh trial.

    An experiment and a trial with a randomised, timestamped name are
    created, the training job is run to completion, the first trial
    component listed for the job's ARN is loaded and added to the trial,
    and then everything is deleted again.
    """
    boto_client = sagemaker_session.sagemaker_client
    random.seed(f"{datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')}")
    run_id = random.randint(1, 6000)

    experiment_name = f"tf-container-integ-test-{run_id}-{int(time.time())}"
    experiment = Experiment.create(
        description="Integration test experiment from sagemaker-tf-container",
        experiment_name=experiment_name,
        sagemaker_boto_client=boto_client,
    )

    trial = Trial.create(
        sagemaker_boto_client=boto_client,
        experiment_name=experiment_name,
        trial_name=f"tf-container-integ-test-{run_id}-{int(time.time())}",
    )

    job_name = utils.unique_name_from_base("test-tf-experiments-mnist")

    # run the training job to completion inside the timeout guard
    with timeout(minutes=DEFAULT_TIMEOUT):
        resources_dir = os.path.join(os.path.dirname(__file__), "..", "..",
                                     "resources")
        mnist_dir = os.path.join(resources_dir, "mnist")
        estimator_settings = dict(
            model_dir=False,
            entry_point=os.path.join(mnist_dir, "mnist.py"),
            role="SageMakerRole",
            instance_type=instance_type,
            instance_count=1,
            sagemaker_session=sagemaker_session,
            image_uri=ecr_image,
            framework_version=framework_version,
        )
        tf_estimator = TensorFlow(**estimator_settings)
        data_location = tf_estimator.sagemaker_session.upload_data(
            key_prefix="scriptmode/mnist",
            path=os.path.join(mnist_dir, "data"),
        )
        tf_estimator.fit(data_location, job_name=job_name)

    job_description = boto_client.describe_training_job(
        TrainingJobName=job_name)
    job_arn = job_description["TrainingJobArn"]

    # the training job is expected to have produced a trial component
    component_summaries = list(
        TrialComponent.list(sagemaker_boto_client=boto_client,
                            source_arn=job_arn))
    first_summary = component_summaries[0]
    component = TrialComponent.load(
        sagemaker_boto_client=boto_client,
        trial_component_name=first_summary.trial_component_name,
    )

    # link the component to the trial
    trial.add_trial_component(component)

    # tear down: link, component, trial
    trial.remove_trial_component(first_summary.trial_component_name)
    component.delete()
    trial.delete()
    # Pause so the experiment is not deleted before it reflects the trial
    # deletion (and to avoid throttling).
    time.sleep(1.2)
    experiment.delete()
def test_training(sagemaker_session, ecr_image, instance_type):
    """Train with the PyTorch image and attach the job's trial component to a trial.

    An experiment and a trial are created, the training job is run to
    completion, the first trial component listed for the job's ARN is
    loaded and added to the trial, and then everything is deleted again.
    """

    from smexperiments.experiment import Experiment
    from smexperiments.trial import Trial
    from smexperiments.trial_component import TrialComponent

    boto_client = sagemaker_session.sagemaker_client

    experiment_name = "pytorch-container-integ-test-{}".format(int(
        time.time()))
    experiment = Experiment.create(
        description=
        "Integration test full customer e2e from sagemaker-pytorch-container",
        experiment_name=experiment_name,
        sagemaker_boto_client=boto_client,
    )

    trial = Trial.create(
        sagemaker_boto_client=boto_client,
        experiment_name=experiment_name,
        trial_name="pytorch-container-integ-test-{}".format(int(time.time())),
    )

    job_hyperparameters = dict([
        ("random_seed", True),
        ("num_steps", 50),
        ("smdebug_path", "/opt/ml/output/tensors"),
        ("epochs", 1),
        ("data_dir", training_dir),
    ])

    job_name = utils.unique_name_from_base("test-pytorch-experiments-image")

    # run the training job to completion inside the timeout guard
    with timeout(minutes=DEFAULT_TIMEOUT):
        estimator_settings = dict(
            entry_point=smdebug_mnist_script,
            role="SageMakerRole",
            train_instance_count=1,
            train_instance_type=instance_type,
            sagemaker_session=sagemaker_session,
            image_name=ecr_image,
            hyperparameters=job_hyperparameters,
        )
        pt_estimator = PyTorch(**estimator_settings)
        data_location = pt_estimator.sagemaker_session.upload_data(
            key_prefix="pytorch/mnist", path=training_dir)
        pt_estimator.fit({"training": data_location}, job_name=job_name)

    job_description = boto_client.describe_training_job(
        TrainingJobName=job_name)
    job_arn = job_description["TrainingJobArn"]

    # the training job is expected to have produced a trial component
    component_summaries = list(
        TrialComponent.list(sagemaker_boto_client=boto_client,
                            source_arn=job_arn))
    first_summary = component_summaries[0]
    component = TrialComponent.load(
        sagemaker_boto_client=boto_client,
        trial_component_name=first_summary.trial_component_name,
    )

    # link the component to the trial
    trial.add_trial_component(component)

    # tear down in dependency order: link, component, trial, experiment
    trial.remove_trial_component(first_summary.trial_component_name)
    component.delete()
    trial.delete()
    experiment.delete()