def __init__(self, instance_type, instance_count, image, sagemaker_session=None):
    """Initialize a SageMakerContainer instance.

    It uses a :class:`sagemaker.session.Session` for general interaction with user
    configuration, such as getting the default SageMaker S3 bucket. However, this class
    does not call any of the SageMaker APIs.

    Args:
        instance_type (str): The instance type to use. Either 'local' or 'local_gpu'.
        instance_count (int): The number of instances to create.
        image (str): Docker image to use.
        sagemaker_session (sagemaker.session.Session): A SageMaker session to use when
            interacting with SageMaker.
    """
    from sagemaker.local.local_session import LocalSession

    self.sagemaker_session = sagemaker_session or LocalSession()
    self.instance_type = instance_type
    self.instance_count = instance_count
    self.image = image
    # Since we are using a single docker network, generate a random suffix to attach to
    # the container names. This way multiple jobs can run in parallel.
    suffix = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(5))
    self.hosts = ['{}-{}-{}'.format(CONTAINER_PREFIX, i, suffix)
                  for i in range(1, self.instance_count + 1)]
    self.container_root = None
    self.container = None
def __init__(
    self,
    instance_type,
    instance_count,
    image,
    sagemaker_session=None,
    container_entrypoint=None,
    container_arguments=None,
):
    """Initialize a SageMakerContainer instance.

    It uses a :class:`sagemaker.session.Session` for general interaction with user
    configuration, such as getting the default SageMaker S3 bucket. However, this class
    does not call any of the SageMaker APIs.

    Args:
        instance_type (str): The instance type to use. Either 'local' or 'local_gpu'.
        instance_count (int): The number of instances to create.
        image (str): Docker image to use.
        sagemaker_session (sagemaker.session.Session): A SageMaker session to use when
            interacting with SageMaker.
        container_entrypoint (str): The container entrypoint to execute.
        container_arguments (str): The container entrypoint arguments.
    """
    from sagemaker.local.local_session import LocalSession

    # Check if docker-compose is installed.
    if find_executable("docker-compose") is None:
        raise ImportError(
            "'docker-compose' is not installed. "
            "Local Mode features will not work without docker-compose. "
            "For more information on how to install 'docker-compose', please, see "
            "https://docs.docker.com/compose/install/")

    self.sagemaker_session = sagemaker_session or LocalSession()
    self.instance_type = instance_type
    self.instance_count = instance_count
    self.image = image
    self.container_entrypoint = container_entrypoint
    self.container_arguments = container_arguments
    # Since we are using a single docker network, generate a random suffix to attach to the
    # container names. This way multiple jobs can run in parallel.
    suffix = "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(5))
    self.hosts = [
        "{}-{}-{}".format(CONTAINER_PREFIX, i, suffix)
        for i in range(1, self.instance_count + 1)
    ]
    self.container_root = None
    self.container = None
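# A minimal usage sketch, assuming the __init__ above belongs to the
# ``_SageMakerContainer`` class in ``sagemaker/local/image.py`` (import path and class
# name may differ by SDK version) and that a local Docker daemon with docker-compose is
# available. The image tag is hypothetical. It shows how ``instance_count`` drives the
# generated container host names: all hosts share one random suffix.
from sagemaker.local.image import _SageMakerContainer

container = _SageMakerContainer(
    instance_type='local',
    instance_count=2,
    image='my-training-image:latest',  # hypothetical image tag
)
print(container.hosts)  # e.g. ['algo-1-ab12c', 'algo-2-ab12c'] when CONTAINER_PREFIX == 'algo'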
class ModelManager:
    """A model entity for the given experiment.

    This class handles model creation, model training, model evaluation
    and model metadata management.
    """

    def __init__(
        self,
        model_db_client: ModelDbClient,
        experiment_id,
        model_id,
        image=None,
        role=None,
        instance_config=None,
        boto_session=None,
        algor_config=None,
        train_state=None,
        evaluation_job_name=None,
        eval_state=None,
        eval_scores=None,
        input_model_id=None,
        rl_estimator=None,
        input_data_s3_prefix=None,
        manifest_file_path=None,
        eval_data_s3_path=None,
        s3_model_output_path=None,
        training_start_time=None,
        training_end_time=None,
    ):
        """Initialize a model entity in the current experiment

        Args:
            model_db_client (ModelDbClient): A DynamoDB client to query the model
                table. The 'Model' entity uses this client to read/update the model state.
            experiment_id (str): A unique id for the experiment. The created/loaded
                model will be associated with the given experiment.
            model_id (str): A unique id for the model. The model table uses the
                model id to manage associated model metadata.
            image (str): The container image to use for training/evaluation.
            role (str): An AWS IAM role (either name or full ARN). The Amazon
                SageMaker training jobs will use this role to access AWS resources.
            instance_config (dict): A dictionary that specifies the resource
                configuration for the model training/evaluation job.
            boto_session (boto3.session.Session): A session that stores configuration
                state and allows you to create service clients and resources.
            algor_config (dict): A dictionary that specifies the algorithm type
                and hyperparameters of the training/evaluation job.
            train_state (str): State of the model training job.
            evaluation_job_name (str): Job name of the latest evaluation job
                for this model.
            eval_state (str): State of the model evaluation job.
            eval_scores (dict): Scores of the model evaluation job(s).
            input_model_id (str): A unique model id to specify which model to use
                as a pre-trained model for the model training job.
            rl_estimator (sagemaker.rl.estimator.RLEstimator): A SageMaker RLEstimator
                entity that handles Reinforcement Learning (RL) execution within
                a SageMaker Training Job.
            input_data_s3_prefix (str): Input data path for the data source of the
                model training job.
            manifest_file_path (str): Manifest file path for the data source of the
                model training job.
            eval_data_s3_path (str): Input data path for the data source of the
                model evaluation job.
            s3_model_output_path (str): Output data path of model artifact for the
                model training job.
            training_start_time (str): Starting timestamp of the model training job.
            training_end_time (str): Finished timestamp of the model training job.

        Returns:
            orchestrator.model_manager.ModelManager: A ``Model`` object associated
            with the given experiment.
        """
        self.model_db_client = model_db_client
        self.experiment_id = experiment_id
        self.model_id = model_id

        # Currently we are not storing image/role and other model params in ModelDb
        self.image = image
        self.role = role
        # Avoid mutable default arguments by coalescing to fresh dicts here.
        self.instance_config = instance_config or {}
        self.algor_config = algor_config or {}
        eval_scores = eval_scores or {}

        # load configs
        self.instance_type = self.instance_config.get("instance_type", "local")
        self.instance_count = self.instance_config.get("instance_count", 1)
        self.algor_params = self.algor_config.get("algorithms_parameters", {})

        # create a local ModelRecord object.
        self.model_record = ModelRecord(
            experiment_id,
            model_id,
            train_state,
            evaluation_job_name,
            eval_state,
            eval_scores,
            input_model_id,
            input_data_s3_prefix,
            manifest_file_path,
            eval_data_s3_path,
            s3_model_output_path,
            training_start_time,
            training_end_time,
        )

        # Try to save this record; if it throws RecordAlreadyExistsException,
        # reload the record from ModelDb and recreate.
        try:
            self.model_db_client.create_new_model_record(
                self.model_record.to_ddb_record())
        except RecordAlreadyExistsException:
            logger.debug("Model already exists. Reloading from model record.")
            model_record = self.model_db_client.get_model_record(
                experiment_id, model_id)
            self.model_record = ModelRecord.load_from_ddb_record(model_record)
        except Exception as e:
            logger.error("Unhandled Exception! " + str(e))
            raise UnhandledWorkflowException(
                "Something went wrong while creating a new model") from e

        if boto_session is None:
            boto_session = boto3.Session()
        self.boto_session = boto_session

        if self.instance_type == "local":
            self.sagemaker_session = LocalSession()
        else:
            self.sagemaker_session = sagemaker.session.Session(self.boto_session)
        self.sagemaker_client = self.sagemaker_session.sagemaker_client

    def _jsonify(self):
        """Return a JSON dict with metadata of the ModelManager object
        stored in self.model_record
        """
        return self.model_record.to_ddb_record()

    @classmethod
    def name_next_model(cls, experiment_id):
        """Generate a unique model id for a new model in the experiment

        Args:
            experiment_id (str): A unique id for the experiment. The created/loaded
                model will be associated with the given experiment.

        Returns:
            str: A unique id for a new model
        """
        return experiment_id + "-model-id-" + str(int(time.time()))

    def _get_rl_estimator_args(self, eval=False):
        """Get required args to be used by RLEstimator class

        Args:
            eval (bool): Whether the estimator is running an evaluation job
                (True) or a training job (False).

        Returns:
            dict: RLEstimator args used to trigger a SageMaker training job
        """
        entry_point = "eval-cfa-vw.py" if eval else "train-vw.py"
        estimator_type = "Evaluation" if eval else "Training"
        job_types = "evaluation_jobs" if eval else "training_jobs"

        sagemaker_bucket = self.sagemaker_session.default_bucket()
        output_path = f"s3://{sagemaker_bucket}/{self.experiment_id}/{job_types}/"

        metric_definitions = [{
            "Name": "average_loss",
            "Regex": "average loss = ([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?).*$"
        }]

        args = dict(
            entry_point=entry_point,
            source_dir="src",
            dependencies=["common/sagemaker_rl"],
            image_uri=self.image,
            role=self.role,
            sagemaker_session=self.sagemaker_session,
            instance_type=self.instance_type,
            instance_count=self.instance_count,
            metric_definitions=metric_definitions,
            hyperparameters=self.algor_params,
            output_path=output_path,
            code_location=output_path.strip("/"),
        )

        if self.instance_type == "local":
            logger.info(f"{estimator_type} job will be executed in 'local' mode")
        else:
            logger.info(f"{estimator_type} job will be executed in 'SageMaker' mode")
        return args

    def _fit_first_model(self,
                         input_data_s3_prefix=None,
                         manifest_file_path=None,
                         wait=False,
                         logs=True):
        """An Estimator fit() call to initiate the first model of the experiment"""
        rl_estimator_args = self._get_rl_estimator_args()
        self.rl_estimator = RLEstimator(**rl_estimator_args)

        if manifest_file_path:
            input_data = sagemaker.session.s3_input(
                s3_data=manifest_file_path,
                input_mode="File",
                s3_data_type="ManifestFile")
            self.rl_estimator.fit(job_name=self.model_id,
                                  inputs=input_data,
                                  wait=wait,
                                  logs=logs)
        else:
            self.rl_estimator.fit(job_name=self.model_id,
                                  inputs=input_data_s3_prefix,
                                  wait=wait,
                                  logs=logs)
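    # Illustrative note (a sketch, not part of the original module): for a first model,
    # _get_rl_estimator_args() above yields roughly the following for a hypothetical
    # default bucket and experiment id, which _fit_first_model() passes to RLEstimator:
    #   entry_point='train-vw.py', source_dir='src',
    #   output_path='s3://<default-bucket>/exp-1/training_jobs/',
    #   metric_definitions=[{'Name': 'average_loss', 'Regex': 'average loss = ...'}]
    # Training with a pre-trained model (see fit() below) additionally sets
    # model_channel_name and model_uri, so SageMaker mounts the prior artifact as
    # the 'pretrained_model' channel.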
    def fit(self,
            input_model_id=None,
            input_data_s3_prefix=None,
            manifest_file_path=None,
            wait=False,
            logs=True):
        """An Estimator fit() call to start a model training job.

        Args:
            input_model_id (str): Model id of the model to use as the pre-trained
                model for the training job.
            input_data_s3_prefix (str): Defines the location of the s3 data to train on.
            manifest_file_path (str): Manifest file used to provide training data.
            wait (bool): Whether the call should wait until the job completes. Only
                meaningful when running in SageMaker mode.
            logs (bool): Whether to show the logs produced by the job.
                Only meaningful when wait is True (default: True).
        """
        # update object var, to be reflected in DDb Record as well.
        self.model_record.add_new_training_job_info(
            input_model_id=input_model_id,
            input_data_s3_prefix=input_data_s3_prefix,
            manifest_file_path=manifest_file_path,
        )
        self.model_db_client.update_model_record(self._jsonify())

        if input_model_id is None:
            self._fit_first_model(input_data_s3_prefix=input_data_s3_prefix,
                                  manifest_file_path=manifest_file_path,
                                  wait=wait,
                                  logs=logs)
        else:
            # use 'input_model_id' as pretrained model for training
            input_model_record = self.model_db_client.get_model_record(
                self.experiment_id, input_model_id)
            model_artifact_path = input_model_record.get("s3_model_output_path")
            rl_estimator_args = self._get_rl_estimator_args()
            rl_estimator_args["model_channel_name"] = "pretrained_model"
            rl_estimator_args["model_uri"] = model_artifact_path
            self.rl_estimator = RLEstimator(**rl_estimator_args)

            if manifest_file_path:
                inputs = sagemaker.session.s3_input(
                    s3_data=manifest_file_path,
                    s3_data_type="ManifestFile")
            else:
                inputs = input_data_s3_prefix

            self.rl_estimator.fit(job_name=self.model_id,
                                  inputs=inputs,
                                  wait=wait,
                                  logs=logs)

    def evaluate(
        self,
        input_data_s3_prefix=None,
        manifest_file_path=None,
        evaluation_job_name=None,
        local_mode=True,
        wait=False,
        logs=True,
    ):
        """An Estimator fit() call to start a model evaluation job.

        Args:
            input_data_s3_prefix (str): Defines the location of the s3 data used
                for evaluation.
            manifest_file_path (str): Manifest file used to provide evaluation data.
            evaluation_job_name (str): Unique SageMaker job name to identify the
                evaluation job.
            local_mode (bool): Whether the evaluation job is running in local mode.
            wait (bool): Whether the call should wait until the job completes. Only
                meaningful when running in SageMaker mode.
            logs (bool): Whether to show the logs produced by the job.
                Only meaningful when wait is True.
        """
        # use self.model_id, self._s3_model_output_path as the model to evaluate
        # Model object has already been initialized with up-to-date DDb record.
        model_artifact_path = self.model_record.get_model_artifact_path()
        rl_estimator_args = self._get_rl_estimator_args(eval=True)
        rl_estimator_args["model_channel_name"] = "pretrained_model"
        rl_estimator_args["model_uri"] = model_artifact_path

        if manifest_file_path:
            inputs = sagemaker.session.s3_input(s3_data=manifest_file_path,
                                                s3_data_type="ManifestFile")
            if local_mode:
                rl_estimator_args["hyperparameters"].update(
                    {"local_mode_manifest": True})
        else:
            inputs = input_data_s3_prefix

        # (dict[str, str] or dict[str, sagemaker.session.s3_input]) for evaluation channel
        eval_channel_inputs = {EVAL_CHANNEL: inputs}
        self.rl_estimator = RLEstimator(**rl_estimator_args)

        # TODO: update to save eval_data_s3_path in DDb as well, or update to read from
        # the SageMaker describe call (maybe will not work in local mode, though).
        eval_data_s3_path = manifest_file_path if (
            manifest_file_path is not None) else input_data_s3_prefix

        # We keep the eval job state as pending before the SM job has been submitted.
        # The syncer function should update this state, based on SM job status.
        self.model_record.add_new_evaluation_job_info(
            evaluation_job_name=evaluation_job_name,
            eval_data_s3_path=eval_data_s3_path)
        self.model_db_client.update_model_record(self._jsonify())

        # The following local variables (unsaved to DDb) make the evaluation job
        # non-resumable.
        self.log_output = None
        self.local_mode = local_mode

        if local_mode:
            # Capture the eval score with a regex; the log should contain exactly one
            # "average loss = <number>" pattern.
            with CaptureStdout() as log_output:
                self.rl_estimator.fit(job_name=evaluation_job_name,
                                      inputs=eval_channel_inputs,
                                      wait=wait,
                                      logs=logs)
            self.log_output = "\n".join(log_output)
            logger.debug(self.log_output)
        else:
            self.rl_estimator.fit(job_name=evaluation_job_name,
                                  inputs=eval_channel_inputs,
                                  wait=wait,
                                  logs=logs)
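    # Usage sketch (hypothetical names, not part of the original module):
    #   model.evaluate(input_data_s3_prefix='s3://my-bucket/exp-1/eval/',
    #                  evaluation_job_name='exp-1-model-id-1700000000-eval',
    #                  local_mode=True, wait=True)
    # In local mode the VW log is captured from stdout so the syncer below can parse
    # the 'average loss' score; in SageMaker mode the score is read from CloudWatch
    # metrics via TrainingJobAnalytics instead.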
    def update_model_training_state(self):
        self._update_model_table_training_states()

    def update_model_evaluation_state(self):
        self._update_model_table_evaluation_states()

    def _update_model_table_training_states(self):
        """Update the training states in the model table.

        This method will poll the SageMaker training job and then update the
        training job metadata of the model, including:
            train_state, s3_model_output_path, training_start_time, training_end_time
        """
        if self.model_record.model_in_terminal_state():
            # Model is already in one of the final states; nothing to do.
            self.model_db_client.update_model_record(self._jsonify())
            return self._jsonify()

        # Else, try and fetch updated SageMaker TrainingJob status
        sm_job_info = {}

        max_describe_retries = 100
        sleep_between_describe_retries = 10

        for i in range(max_describe_retries):
            try:
                sm_job_info = self.sagemaker_client.describe_training_job(
                    TrainingJobName=self.model_id)
                break  # got a response; stop retrying
            except Exception as e:
                if "ValidationException" in str(e):
                    if i >= max_describe_retries - 1:
                        # Exhausted attempts for DescribeTrainingJob; fail the job.
                        logger.warning(
                            "Looks like the SageMaker job was not submitted successfully."
                            f" Failing training job with ModelId {self.model_id}")
                        self.model_record.update_model_as_failed()
                        self.model_db_client.update_model_as_failed(self._jsonify())
                        return
                    else:
                        time.sleep(sleep_between_describe_retries)
                        continue
                else:
                    # Do not raise the exception; it is most probably throttling.
                    logger.warning(
                        f"Failed to check SageMaker training job state for ModelId {self.model_id}."
                        " This exception will be ignored, and retried.")
                    logger.debug(e)
                    time.sleep(sleep_between_describe_retries)
                    return self._jsonify()

        train_state = sm_job_info.get("TrainingJobStatus", "Pending")
        training_start_time = sm_job_info.get("TrainingStartTime", None)
        training_end_time = sm_job_info.get("TrainingEndTime", None)
        if training_start_time is not None:
            training_start_time = training_start_time.strftime("%Y-%m-%d %H:%M:%S")
        if training_end_time is not None:
            training_end_time = training_end_time.strftime("%Y-%m-%d %H:%M:%S")

        model_artifacts = sm_job_info.get("ModelArtifacts", None)
        if model_artifacts is not None:
            s3_model_output_path = model_artifacts.get("S3ModelArtifacts", None)
        else:
            s3_model_output_path = None

        self.model_record.update_model_job_status(training_start_time,
                                                  training_end_time,
                                                  train_state,
                                                  s3_model_output_path)
        self.model_db_client.update_model_job_state(self._jsonify())
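    # Illustrative note (a sketch, not part of the original module): in local mode the
    # syncer below recovers the score from the captured log output. A log line such as
    #   average loss = 0.285220
    # matched with re.findall against the pattern used below yields [('0.285220', '')],
    # so eval_score becomes '0.285220'.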
    def _update_model_table_evaluation_states(self):
        """Update the evaluation states in the model table.

        This method will poll the SageMaker evaluation job and then update the
        evaluation job metadata of the model, including:
            eval_state, eval_scores
        """
        if self.model_record.eval_in_terminal_state():
            self.model_db_client.update_model_record(self._jsonify())
            return self._jsonify()

        # Try and fetch updated SageMaker evaluation job status
        sm_eval_job_info = {}

        max_describe_retries = 100
        sleep_between_describe_retries = 10

        for i in range(max_describe_retries):
            try:
                sm_eval_job_info = self.sagemaker_client.describe_training_job(
                    TrainingJobName=self.model_record._evaluation_job_name)
                break  # got a response; stop retrying
            except Exception as e:
                if "ValidationException" in str(e):
                    logger.debug(e)
                    if i >= max_describe_retries - 1:
                        # Exhausted attempts for DescribeTrainingJob; fail the job.
                        logger.warning(
                            "Looks like the SageMaker job was not submitted successfully."
                            f" Failing EvaluationJob {self.model_record._evaluation_job_name}")
                        self.model_record.update_eval_job_as_failed()
                        self.model_db_client.update_model_eval_as_failed(self._jsonify())
                        return
                    else:
                        time.sleep(sleep_between_describe_retries)
                        continue
                else:
                    # Do not raise the exception; it is most probably throttling.
                    logger.warning(
                        "Failed to check SageMaker training job state for EvaluationJob:"
                        f" {self.model_record._evaluation_job_name}. This exception will"
                        " be ignored, and retried.")
                    time.sleep(sleep_between_describe_retries)
                    return self._jsonify()

        eval_state = sm_eval_job_info.get("TrainingJobStatus", "Pending")
        if eval_state == "Completed":
            eval_score = "n.a."

            if self.local_mode:
                rgx = re.compile(
                    "average loss = ([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?).*$", re.M)
                eval_score_rgx = rgx.findall(self.log_output)

                if len(eval_score_rgx) == 0:
                    logger.warning("No eval score available from vw job log.")
                else:
                    eval_score = eval_score_rgx[0][0]  # [('eval_score', '')]
            else:
                attempts = 0
                while eval_score == "n.a." and attempts < 4:
                    attempts += 1
                    try:
                        metric_df = TrainingJobAnalytics(
                            self.model_record._evaluation_job_name,
                            ["average_loss"]).dataframe()
                        eval_score = str(metric_df[metric_df["metric_name"] ==
                                                   "average_loss"]["value"][0])
                    except Exception:
                        # Back off to avoid throttling, then retry.
                        time.sleep(5)
            self.model_record._eval_state = eval_state
            self.model_record.add_model_eval_scores(eval_score)
            self.model_db_client.update_model_eval_job_state(self._jsonify())
        else:
            # update eval state via the ddb client
            self.model_record.update_eval_job_state(eval_state)
            self.model_db_client.update_model_eval_job_state(self._jsonify())
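# A minimal end-to-end sketch of the class above (hypothetical ids, image URI, role and
# S3 prefix; assumes a ModelDbClient already wired to the experiment's DynamoDB model
# table). It walks the lifecycle ModelManager implements: create/load the model record,
# launch a training job, then sync the job state back into the model table.
model = ModelManager(
    model_db_client=model_db_client,  # assumed to exist
    experiment_id="exp-1",            # hypothetical experiment id
    model_id=ModelManager.name_next_model(experiment_id="exp-1"),
    image="123456789012.dkr.ecr.us-west-2.amazonaws.com/vw:latest",  # hypothetical
    role="SageMakerRole",             # hypothetical IAM role name
    instance_config={"instance_type": "local", "instance_count": 1},
)
model.fit(input_data_s3_prefix="s3://my-bucket/exp-1/data/", wait=True)  # hypothetical prefix
model.update_model_training_state()   # poll SageMaker and sync the DDb record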
class EstimatorBase(with_metaclass(ABCMeta, object)):
    """Handle end-to-end Amazon SageMaker training and deployment tasks.

    For introduction to model training and deployment, see
    http://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-training.html

    Subclasses must define a way to determine what image to use for training,
    what hyperparameters to use, and how to create an appropriate predictor instance.
    """

    def __init__(self, role, train_instance_count, train_instance_type,
                 train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File',
                 output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None):
        """Initialize an ``EstimatorBase`` instance.

        Args:
            role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs
                that create Amazon SageMaker endpoints use this role to access training data and model artifacts.
                After the endpoint is created, the inference code might use the IAM role,
                if it needs to access an AWS resource.
            train_instance_count (int): Number of Amazon EC2 instances to use for training.
            train_instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'.
            train_volume_size (int): Size in GB of the EBS volume to use for storing input data
                during training (default: 30). Must be large enough to store training data if File Mode is used
                (which is the default).
            train_max_run (int): Timeout in seconds for training (default: 24 * 60 * 60).
                After this amount of time Amazon SageMaker terminates the job regardless of its current status.
            input_mode (str): The input mode that the algorithm supports (default: 'File'). Valid modes:
                'File' - Amazon SageMaker copies the training dataset from the S3 location to a local directory.
                'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a Unix named pipe.
            output_path (str): S3 location for saving the training result (model artifacts and output files).
                If not specified, results are stored to a default bucket. If the bucket with the specific name
                does not exist, the estimator creates the bucket during the
                :meth:`~sagemaker.estimator.EstimatorBase.fit` method execution.
            output_kms_key (str): Optional. KMS key ID for encrypting the training output (default: None).
            base_job_name (str): Prefix for training job name when the :meth:`~sagemaker.estimator.EstimatorBase.fit`
                method launches. If not specified, the estimator generates a default job name,
                based on the training image name and current timestamp.
            sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
                Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates
                one using the default AWS configuration chain.
        """
        self.role = role
        self.train_instance_count = train_instance_count
        self.train_instance_type = train_instance_type
        self.train_volume_size = train_volume_size
        self.train_max_run = train_max_run
        self.input_mode = input_mode

        if self.train_instance_type in ('local', 'local_gpu'):
            self.local_mode = True
            if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1:
                raise RuntimeError("Distributed Training in Local GPU is not supported")
            self.sagemaker_session = LocalSession()
        else:
            self.local_mode = False
            self.sagemaker_session = sagemaker_session or Session()

        self.base_job_name = base_job_name
        self._current_job_name = None
        self.output_path = output_path
        self.output_kms_key = output_kms_key
        self.latest_training_job = None
""" self.role = role self.train_instance_count = train_instance_count self.train_instance_type = train_instance_type self.train_volume_size = train_volume_size self.train_max_run = train_max_run self.input_mode = input_mode if self.train_instance_type in ('local', 'local_gpu'): self.local_mode = True if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1: raise RuntimeError("Distributed Training in Local GPU is not supported") self.sagemaker_session = LocalSession() else: self.local_mode = False self.sagemaker_session = sagemaker_session or Session() self.base_job_name = base_job_name self._current_job_name = None self.output_path = output_path self.output_kms_key = output_kms_key self.latest_training_job = None @abstractmethod def train_image(self): """Return the Docker image to use for training. The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does the model training, calls this method to find the image to use for model training. Returns: str: The URI of the Docker image. """ pass @abstractmethod def hyperparameters(self): """Return the hyperparameters as a dictionary to use for training. The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which trains the model, calls this method to find the hyperparameters. Returns: dict[str, str]: The hyperparameters. """ pass def fit(self, inputs, wait=True, logs=True, job_name=None): """Train a model using the input training dataset. The API calls the Amazon SageMaker CreateTrainingJob API to start model training. The API uses configuration you provided to create the estimator and the specified input training data to send the CreatingTrainingJob request to Amazon SageMaker. This is a synchronous operation. After the model training successfully completes, you can call the ``deploy()`` method to host the model using the Amazon SageMaker hosting services. Args: inputs (str or dict or sagemaker.session.s3_input): Information about the training data. This can be one of three types: * (str) the S3 location where training data is saved. * (dict[str, str] or dict[str, sagemaker.session.s3_input]) If using multiple channels for training data, you can specify a dict mapping channel names to strings or :func:`~sagemaker.session.s3_input` objects. * (sagemaker.session.s3_input) - channel configuration for S3 data sources that can provide additional information about the training dataset. See :func:`sagemaker.session.s3_input` for full details. wait (bool): Whether the call should wait until the job completes (default: True). logs (bool): Whether to show the logs produced by the job. Only meaningful when wait is True (default: True). job_name (str): Training job name. If not specified, the estimator generates a default job name, based on the training image name and current timestamp. """ if job_name is not None: self._current_job_name = job_name else: # make sure the job name is unique for each invocation, honor supplied base_job_name or generate it base_name = self.base_job_name or base_name_from_image(self.train_image()) self._current_job_name = name_from_base(base_name) # if output_path was specified we use it otherwise initialize here if self.output_path is None: self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket()) self.latest_training_job = _TrainingJob.start_new(self, inputs) if wait: self.latest_training_job.wait(logs=logs) @classmethod def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session): """Create an Estimator from existing training job data. 
    @classmethod
    def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
        """Create an Estimator from existing training job data.

        Args:
            init_params (dict): The init_params the training job was created with.
            hyperparameters (dict): The hyperparameters the training job was created with.
            image (str): Container image (if any) the training job was created with.
            sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.

        Returns:
            An instance of the calling Estimator Class.
        """
        raise NotImplementedError()

    @classmethod
    def attach(cls, training_job_name, sagemaker_session=None, job_details=None):
        """Attach to an existing training job.

        Create an Estimator bound to an existing training job. Each subclass is responsible for
        implementing ``_prepare_init_params_from_job_description()``, as this method delegates
        the actual conversion of a training job description to the arguments that the class
        constructor expects.

        After attaching, if the training job has a Complete status, it can be ``deploy()`` ed
        to create a SageMaker Endpoint and return a ``Predictor``.

        If the training job is in progress, attach will block and display log messages
        from the training job, until the training job completes.

        Args:
            training_job_name (str): The name of the training job to attach to.
            sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
                Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates
                one using the default AWS configuration chain.
            job_details (dict): The returned job details from a DescribeTrainingJob API call. If not
                specified, they are fetched using the training job name.

        Examples:
            >>> my_estimator.fit(wait=False)
            >>> training_job_name = my_estimator.latest_training_job.name
            Later on:
            >>> attached_estimator = Estimator.attach(training_job_name)
            >>> attached_estimator.deploy()

        Returns:
            Instance of the calling ``Estimator`` Class with the attached training job.
        """
        sagemaker_session = sagemaker_session or Session()

        # Honor job_details if supplied; otherwise describe the training job by name.
        job_details = job_details or sagemaker_session.sagemaker_client.describe_training_job(
            TrainingJobName=training_job_name)
        init_params = cls._prepare_init_params_from_job_description(job_details)

        estimator = cls(sagemaker_session=sagemaker_session, **init_params)
        estimator.latest_training_job = _TrainingJob(sagemaker_session=sagemaker_session,
                                                     training_job_name=init_params['base_job_name'])
        estimator.latest_training_job.wait()
        return estimator

    def deploy(self, initial_instance_count, instance_type, endpoint_name=None, **kwargs):
        """Deploy the trained model to an Amazon SageMaker endpoint and return a
        ``sagemaker.RealTimePredictor`` object.

        More information:
        http://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-training.html

        Args:
            initial_instance_count (int): Minimum number of EC2 instances to deploy to an endpoint for prediction.
            instance_type (str): Type of EC2 instance to deploy to an endpoint for prediction,
                for example, 'ml.c4.xlarge'.
            endpoint_name (str): Name to use for creating an Amazon SageMaker endpoint. If not specified,
                the name of the training job is used.
            **kwargs: Passed to invocation of ``create_model()``. Implementations may customize
                ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy.
                For more, see the implementation docs.

        Returns:
            sagemaker.predictor.RealTimePredictor: A predictor that provides a ``predict()`` method,
                which can be used to send requests to the Amazon SageMaker endpoint and obtain inferences.
        """
        if not self.latest_training_job:
            raise RuntimeError('Estimator has not been fit yet.')
        endpoint_name = endpoint_name or self.latest_training_job.name
        self.deploy_instance_type = instance_type
        return self.create_model(**kwargs).deploy(
            instance_type=instance_type,
            initial_instance_count=initial_instance_count,
            endpoint_name=endpoint_name)
""" if not self.latest_training_job: raise RuntimeError('Estimator has not been fit yet.') endpoint_name = endpoint_name or self.latest_training_job.name self.deploy_instance_type = instance_type return self.create_model(**kwargs).deploy( instance_type=instance_type, initial_instance_count=initial_instance_count, endpoint_name=endpoint_name) @property def model_data(self): """str: The model location in S3. Only set if Estimator has been ``fit()``.""" return self.sagemaker_session.sagemaker_client.describe_training_job( TrainingJobName=self.latest_training_job.name)['ModelArtifacts']['S3ModelArtifacts'] @abstractmethod def create_model(self, **kwargs): """Create a SageMaker ``Model`` object that can be deployed to an ``Endpoint``. Args: **kwargs: Keyword arguments used by the implemented method for creating the ``Model``. Returns: sagemaker.model.Model: A SageMaker ``Model`` object. See :func:`~sagemaker.model.Model` for full details. """ pass @classmethod def _prepare_init_params_from_job_description(cls, job_details): """Convert the job description to init params that can be handled by the class constructor Args: job_details: the returned job details from a describe_training_job API call. Returns: dictionary: The transformed init_params """ init_params = dict() init_params['role'] = job_details['RoleArn'] init_params['train_instance_count'] = job_details['ResourceConfig']['InstanceCount'] init_params['train_instance_type'] = job_details['ResourceConfig']['InstanceType'] init_params['train_volume_size'] = job_details['ResourceConfig']['VolumeSizeInGB'] init_params['train_max_run'] = job_details['StoppingCondition']['MaxRuntimeInSeconds'] init_params['input_mode'] = job_details['AlgorithmSpecification']['TrainingInputMode'] init_params['base_job_name'] = job_details['TrainingJobName'] init_params['output_path'] = job_details['OutputDataConfig']['S3OutputPath'] init_params['output_kms_key'] = job_details['OutputDataConfig']['KmsKeyId'] init_params['hyperparameters'] = job_details['HyperParameters'] init_params['image'] = job_details['AlgorithmSpecification']['TrainingImage'] return init_params def delete_endpoint(self): """Delete an Amazon SageMaker ``Endpoint``. Raises: ValueError: If the endpoint does not exist. """ if self.latest_training_job is None: raise ValueError('Endpoint was not created yet') self.sagemaker_session.delete_endpoint(self.latest_training_job.name)