def test_mxnet_with_tensorboard_output_config(
    sagemaker_session,
    mxnet_training_latest_version,
    mxnet_training_latest_py_version,
    cpu_instance_type,
):
    """Train MXNet MNIST with a TensorBoard output config and check it is echoed back.

    The test builds a ``TensorBoardOutputConfig`` pointing at a unique prefix in the
    session's default bucket, runs a single-instance training job with it, and asserts
    that the described training job reports the same configuration. Finally it waits
    on the job's rule evaluations via ``_wait_and_assert_that_no_rule_jobs_errored``.

    Args:
        sagemaker_session: Session used to resolve the default bucket and handed to
            the estimator.
        mxnet_training_latest_version: Value passed as the estimator's
            ``framework_version``.
        mxnet_training_latest_py_version: Value passed as the estimator's
            ``py_version``.
        cpu_instance_type: Value passed as the estimator's ``instance_type``.
    """
    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        # S3 URIs always use "/" as the separator. os.path.join would insert the host
        # OS separator (a backslash on Windows), so the URI is built explicitly.
        tensorboard_s3_uri = "s3://{}/{}/tensorboard".format(
            sagemaker_session.default_bucket(), uuid.uuid4()
        )
        tensorboard_output_config = TensorBoardOutputConfig(s3_output_path=tensorboard_s3_uri)

        # These are local filesystem paths, so os.path.join is the right tool here.
        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_gluon.py")
        data_path = os.path.join(DATA_DIR, "mxnet_mnist")

        mx = MXNet(
            entry_point=script_path,
            role="SageMakerRole",
            framework_version=mxnet_training_latest_version,
            py_version=mxnet_training_latest_py_version,
            instance_count=1,
            instance_type=cpu_instance_type,
            sagemaker_session=sagemaker_session,
            tensorboard_output_config=tensorboard_output_config,
        )

        train_input = mx.sagemaker_session.upload_data(
            path=os.path.join(data_path, "train"),
            key_prefix="integ-test-data/mxnet_mnist/train",
        )
        test_input = mx.sagemaker_session.upload_data(
            path=os.path.join(data_path, "test"),
            key_prefix="integ-test-data/mxnet_mnist/test",
        )

        mx.fit({"train": train_input, "test": test_input})

        job_description = mx.latest_training_job.describe()
        # The described job must carry exactly the request dict we submitted.
        assert (
            job_description["TensorBoardOutputConfig"]
            == tensorboard_output_config._to_request_dict()
        )

        _wait_and_assert_that_no_rule_jobs_errored(training_job=mx.latest_training_job)
def test_mxnet_with_all_rules_and_configs(
    sagemaker_session,
    mxnet_training_latest_version,
    mxnet_training_latest_py_version,
    cpu_instance_type,
):
    """Train MXNet MNIST with rules, a debugger hook and TensorBoard output together.

    The test configures three built-in rules plus the custom rule returned by
    ``_get_custom_rule``, a ``DebuggerHookConfig`` and a ``TensorBoardOutputConfig``
    (each writing to its own unique prefix in the session's default bucket), runs a
    single-instance training job, and asserts that the described training job
    reflects every rule and both configs. Finally it waits on the job's rule
    evaluations via ``_wait_and_assert_that_no_rule_jobs_errored``.

    Args:
        sagemaker_session: Session used to resolve the default bucket, build the
            custom rule, and handed to the estimator.
        mxnet_training_latest_version: Value passed as the estimator's
            ``framework_version``.
        mxnet_training_latest_py_version: Value passed as the estimator's
            ``py_version``.
        cpu_instance_type: Value passed as the estimator's ``instance_type``.
    """
    with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
        rules = [
            Rule.sagemaker(rule_configs.vanishing_gradient()),
            Rule.sagemaker(
                base_config=rule_configs.all_zero(), rule_parameters={"tensor_regex": ".*"}
            ),
            Rule.sagemaker(rule_configs.loss_not_decreasing()),
            _get_custom_rule(sagemaker_session),
        ]

        # S3 URIs always use "/" as the separator. os.path.join would insert the host
        # OS separator (a backslash on Windows), so the URIs are built explicitly.
        # Each config gets its own random prefix, as in the original.
        tensors_s3_uri = "s3://{}/{}/tensors".format(
            sagemaker_session.default_bucket(), uuid.uuid4()
        )
        debugger_hook_config = DebuggerHookConfig(s3_output_path=tensors_s3_uri)

        tensorboard_s3_uri = "s3://{}/{}/tensorboard".format(
            sagemaker_session.default_bucket(), uuid.uuid4()
        )
        tensorboard_output_config = TensorBoardOutputConfig(s3_output_path=tensorboard_s3_uri)

        # These are local filesystem paths, so os.path.join is the right tool here.
        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_gluon.py")
        data_path = os.path.join(DATA_DIR, "mxnet_mnist")

        mx = MXNet(
            entry_point=script_path,
            role="SageMakerRole",
            framework_version=mxnet_training_latest_version,
            py_version=mxnet_training_latest_py_version,
            instance_count=1,
            instance_type=cpu_instance_type,
            sagemaker_session=sagemaker_session,
            rules=rules,
            debugger_hook_config=debugger_hook_config,
            tensorboard_output_config=tensorboard_output_config,
        )

        train_input = mx.sagemaker_session.upload_data(
            path=os.path.join(data_path, "train"),
            key_prefix="integ-test-data/mxnet_mnist/train",
        )
        test_input = mx.sagemaker_session.upload_data(
            path=os.path.join(data_path, "test"),
            key_prefix="integ-test-data/mxnet_mnist/test",
        )

        mx.fit({"train": train_input, "test": test_input})

        job_description = mx.latest_training_job.describe()

        # Rules are compared positionally: the i-th described rule configuration must
        # match the i-th rule that was submitted.
        described_rules = job_description["DebugRuleConfigurations"]
        for index, rule in enumerate(rules):
            assert described_rules[index]["RuleConfigurationName"] == rule.name
            assert described_rules[index]["RuleEvaluatorImage"] == rule.image_uri

        # Both configs must come back exactly as the request dicts we submitted.
        assert job_description["DebugHookConfig"] == debugger_hook_config._to_request_dict()
        assert (
            job_description["TensorBoardOutputConfig"]
            == tensorboard_output_config._to_request_dict()
        )
        assert (
            _get_rule_evaluation_statuses(job_description)
            == mx.latest_training_job.rule_job_summary()
        )

        _wait_and_assert_that_no_rule_jobs_errored(training_job=mx.latest_training_job)