def test_mxnet_with_custom_rule_and_debugger_hook_config(
    sagemaker_session,
    mxnet_training_latest_version,
    mxnet_training_latest_py_version,
    cpu_instance_type,
):
    """Train MXNet MNIST with a custom Debugger rule and an explicit hook config.

    Verifies that the resulting training-job description echoes back the rule
    configuration (name, evaluator image, volume size) and the hook config, and
    that no rule evaluation job errored.
    """
    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        custom_rules = [_get_custom_rule(sagemaker_session)]
        # Unique per-run S3 prefix so concurrent test runs do not collide.
        tensors_path = os.path.join(
            "s3://", sagemaker_session.default_bucket(), str(uuid.uuid4()), "tensors"
        )
        hook_config = DebuggerHookConfig(s3_output_path=tensors_path)

        mnist_dir = os.path.join(DATA_DIR, "mxnet_mnist")
        entry_point = os.path.join(mnist_dir, "mnist_gluon.py")

        estimator = MXNet(
            entry_point=entry_point,
            role="SageMakerRole",
            framework_version=mxnet_training_latest_version,
            py_version=mxnet_training_latest_py_version,
            instance_count=1,
            instance_type=cpu_instance_type,
            sagemaker_session=sagemaker_session,
            rules=custom_rules,
            debugger_hook_config=hook_config,
        )

        # Upload the train/test channels from the local fixture directory.
        channels = {
            name: estimator.sagemaker_session.upload_data(
                path=os.path.join(mnist_dir, name),
                key_prefix="integ-test-data/mxnet_mnist/" + name,
            )
            for name in ("train", "test")
        }

        estimator.fit(channels)

        description = estimator.latest_training_job.describe()

        for idx, rule in enumerate(custom_rules):
            rule_config = description["DebugRuleConfigurations"][idx]
            assert rule_config["RuleConfigurationName"] == rule.name
            assert rule_config["RuleEvaluatorImage"] == rule.image_uri
            assert rule_config["VolumeSizeInGB"] == 30
        assert description["DebugHookConfig"] == hook_config._to_request_dict()

        assert (
            _get_rule_evaluation_statuses(description)
            == estimator.latest_training_job.rule_job_summary()
        )

        _wait_and_assert_that_no_rule_jobs_errored(training_job=estimator.latest_training_job)
def test_mxnet_with_debugger_hook_config(
    sagemaker_session,
    mxnet_training_latest_version,
    mxnet_training_latest_py_version,
    cpu_instance_type,
):
    """Train MXNet MNIST with only a DebuggerHookConfig (no rules).

    Verifies the job description reflects the hook config and that no rule
    evaluation job errored.
    """
    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        # Unique per-run S3 prefix so concurrent test runs do not collide.
        hook_config = DebuggerHookConfig(
            s3_output_path=os.path.join(
                "s3://", sagemaker_session.default_bucket(), str(uuid.uuid4()), "tensors"
            )
        )

        mnist_dir = os.path.join(DATA_DIR, "mxnet_mnist")

        estimator = MXNet(
            entry_point=os.path.join(mnist_dir, "mnist_gluon.py"),
            role="SageMakerRole",
            framework_version=mxnet_training_latest_version,
            py_version=mxnet_training_latest_py_version,
            instance_count=1,
            instance_type=cpu_instance_type,
            sagemaker_session=sagemaker_session,
            debugger_hook_config=hook_config,
        )

        # Upload the train/test channels from the local fixture directory.
        channels = {
            name: estimator.sagemaker_session.upload_data(
                path=os.path.join(mnist_dir, name),
                key_prefix="integ-test-data/mxnet_mnist/" + name,
            )
            for name in ("train", "test")
        }

        estimator.fit(channels)

        description = estimator.latest_training_job.describe()
        assert description["DebugHookConfig"] == hook_config._to_request_dict()

        _wait_and_assert_that_no_rule_jobs_errored(training_job=estimator.latest_training_job)
# Beispiel #3
# 0
# NOTE(review): the two lines above are an extraction artifact from the original
# source listing (a code-sample site banner), not Python — commented out so the
# module parses.
def test_debug_hook_disabled_with_checkpointing(
    sagemaker_session,
    mxnet_training_latest_version,
    mxnet_training_latest_py_version,
    cpu_instance_type,
):
    """Verify when the Debugger hook is auto-disabled alongside checkpointing.

    Exercises `_prepare_for_training()` (no training job is launched) on four
    estimators and checks the resulting `debugger_hook_config`:

    * single-instance MXNet with checkpointing      -> hook stays enabled
    * multi-instance MXNet with checkpointing       -> hook disabled (False)
    * PyTorch + SMDataParallel with checkpointing   -> hook disabled (False)
    * TensorFlow + SMModelParallel with checkpointing -> hook disabled (False)
    * XGBoost, multi-instance, no checkpointing     -> hook stays enabled
    """
    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        # Unique per-run S3 prefix so concurrent test runs do not collide.
        s3_output_path = os.path.join("s3://",
                                      sagemaker_session.default_bucket(),
                                      str(uuid.uuid4()))
        debugger_hook_config = DebuggerHookConfig(
            s3_output_path=os.path.join(s3_output_path, "tensors"))

        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_gluon.py")

        # Case 1: single-instance estimator with checkpointing enabled.
        mx = MXNet(
            entry_point=script_path,
            role="SageMakerRole",
            framework_version=mxnet_training_latest_version,
            py_version=mxnet_training_latest_py_version,
            instance_count=1,
            instance_type=cpu_instance_type,
            sagemaker_session=sagemaker_session,
            debugger_hook_config=debugger_hook_config,
            checkpoint_local_path="/opt/ml/checkpoints",
            checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
        )
        mx._prepare_for_training()

        # Debug Hook should remain enabled: checkpointing on a single
        # instance does not conflict with the hook.
        assert mx.debugger_hook_config is not None

        # Case 2: checkpointing enabled AND instance_count > 1.
        mx = MXNet(
            entry_point=script_path,
            role="SageMakerRole",
            framework_version=mxnet_training_latest_version,
            py_version=mxnet_training_latest_py_version,
            instance_count=2,
            instance_type=cpu_instance_type,
            sagemaker_session=sagemaker_session,
            debugger_hook_config=debugger_hook_config,
            checkpoint_local_path="/opt/ml/checkpoints",
            checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
        )
        mx._prepare_for_training()
        # Debug Hook should be disabled (set to False, not merely None).
        assert mx.debugger_hook_config is False

        # Case 3: checkpointing enabled AND SMDataParallel enabled.
        # NOTE(review): reuses the MXNet MNIST script as the entry point; only
        # _prepare_for_training() runs, so the script is never executed.
        pt = PyTorch(
            base_job_name="pytorch-smdataparallel-mnist",
            entry_point=script_path,
            role="SageMakerRole",
            framework_version="1.8.0",
            py_version="py36",
            instance_count=1,
            # For training with p3dn instance use - ml.p3dn.24xlarge, with p4dn instance use - ml.p4d.24xlarge
            instance_type="ml.p3.16xlarge",
            sagemaker_session=sagemaker_session,
            # Training using SMDataParallel Distributed Training Framework
            distribution={
                "smdistributed": {
                    "dataparallel": {
                        "enabled": True
                    }
                }
            },
            checkpoint_local_path="/opt/ml/checkpoints",
            checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
        )
        pt._prepare_for_training()
        # Debug Hook should be disabled
        assert pt.debugger_hook_config is False

        # Case 4: checkpointing enabled AND SMModelParallel enabled.
        tf = TensorFlow(
            base_job_name="tf-smdataparallel-mnist",
            entry_point=script_path,
            role="SageMakerRole",
            framework_version="2.4.1",
            py_version="py36",
            instance_count=1,
            # For training with p3dn instance use - ml.p3dn.24xlarge, with p4dn instance use - ml.p4d.24xlarge
            instance_type="ml.p3.16xlarge",
            sagemaker_session=sagemaker_session,
            # Training using SMModelParallel Distributed Training Framework
            # (the distribution dict below enables "modelparallel", not
            # "dataparallel" as a previous comment claimed).
            distribution={
                "smdistributed": {
                    "modelparallel": {
                        "enabled": True
                    }
                }
            },
            checkpoint_local_path="/opt/ml/checkpoints",
            checkpoint_s3_uri=os.path.join(s3_output_path, "checkpoints"),
        )
        tf._prepare_for_training()
        # Debug Hook should be disabled
        assert tf.debugger_hook_config is False

        # Case 5: XGBoost estimator, multi-instance, with NO checkpointing and
        # NO distribution configured.
        xg = XGBoost(
            base_job_name="test_xgboost",
            entry_point=script_path,
            role="SageMakerRole",
            framework_version="1.2-1",
            py_version="py3",
            instance_count=2,
            # For training with p3dn instance use - ml.p3dn.24xlarge, with p4dn instance use - ml.p4d.24xlarge
            instance_type="ml.p3.16xlarge",
            sagemaker_session=sagemaker_session,
            # No checkpoint_s3_uri / distribution here, so the hook must stay on.
        )
        xg._prepare_for_training()
        # Debug Hook should be enabled
        assert xg.debugger_hook_config is not None