def __init__(self, role, train_instance_count, train_instance_type,
                 train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File',
                 output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None):
        """Set up the shared configuration of an ``EstimatorBase``.

        Args:
            role (str): AWS IAM role, given as a name or a full ARN. Amazon SageMaker training jobs
                and the APIs that create endpoints assume this role to reach training data and
                model artifacts; inference code may also use it once the endpoint exists.
            train_instance_count (int): How many Amazon EC2 instances to train on.
            train_instance_type (str): EC2 instance type for training, e.g. 'ml.c4.xlarge'.
                The values 'local' and 'local_gpu' select local mode.
            train_volume_size (int): EBS volume size in GB used to hold input data during training
                (default: 30). In File mode (the default) it must fit the whole training dataset.
            train_max_run (int): Training timeout in seconds (default: 24 * 60 * 60). Amazon
                SageMaker stops the job once this much time has passed, whatever its state.
            input_mode (str): Input mode supported by the algorithm (default: 'File'):
                'File' - the training dataset is copied from S3 into a local directory.
                'Pipe' - data is streamed from S3 into the container through a Unix-named pipe.
            output_path (str): S3 location where training results (model artifacts and output
                files) are written. When omitted, a default bucket is used; a bucket that does not
                exist yet is created while :meth:`~sagemaker.estimator.EstimatorBase.fit` runs.
            output_kms_key (str): Optional KMS key ID used to encrypt the training output
                (default: None).
            base_job_name (str): Prefix for the training job name used when
                :meth:`~sagemaker.estimator.EstimatorBase.fit` launches a job. When omitted, a
                name is generated from the training image name and the current timestamp.
            sagemaker_session (sagemaker.session.Session): Session that manages calls to Amazon
                SageMaker and other AWS services. Ignored in local mode, where a ``LocalSession``
                is always created; otherwise a default ``Session`` is created when omitted.

        Raises:
            RuntimeError: if 'local_gpu' is requested with more than one instance.
        """
        # Training resources, stored exactly as supplied.
        self.role = role
        self.train_instance_count = train_instance_count
        self.train_instance_type = train_instance_type
        self.train_volume_size = train_volume_size
        self.train_max_run = train_max_run
        self.input_mode = input_mode

        instance_type = self.train_instance_type
        runs_locally = instance_type in ('local', 'local_gpu')
        self.local_mode = runs_locally

        if not runs_locally:
            self.sagemaker_session = sagemaker_session or Session()
        else:
            # Local GPU training only works on a single machine.
            if instance_type == 'local_gpu' and self.train_instance_count > 1:
                raise RuntimeError("Distributed Training in Local GPU is not supported")
            self.sagemaker_session = LocalSession()

        # Job bookkeeping and output configuration.
        self.base_job_name = base_job_name
        self._current_job_name = None
        self.output_path = output_path
        self.output_kms_key = output_kms_key
        self.latest_training_job = None
示例#2
0
    def __init__(self,
                 instance_type,
                 instance_count,
                 image,
                 sagemaker_session=None):
        """Initialize a SageMakerContainer instance.

        It uses a :class:`sagemaker.session.Session` for general interaction with user configuration
        such as getting the default sagemaker S3 bucket. However this class does not call any of the
        SageMaker APIs.

        Args:
            instance_type (str): The instance type to use. Either 'local' or 'local_gpu'
            instance_count (int): The number of instances to create.
            image (str): docker image to use.
            sagemaker_session (sagemaker.session.Session): a sagemaker session to use when interacting
                with SageMaker. A ``LocalSession`` is created when omitted.
        """
        from sagemaker.local.local_session import LocalSession
        self.sagemaker_session = sagemaker_session or LocalSession()
        self.instance_type = instance_type
        self.instance_count = instance_count
        self.image = image
        # Since we are using a single docker network, generate a random suffix to attach to the
        # container names. This way multiple jobs can run in parallel.
        # The suffix is lowercase-only (as in the newer version of this constructor) so the
        # generated host names stay valid where Docker / docker-compose expect lowercase
        # identifiers. NOTE(review): confirm against the docker-compose version in use.
        suffix = ''.join(
            random.choice(string.ascii_lowercase + string.digits)
            for _ in range(5))
        # Host names are 1-based: <prefix>-1-<suffix> ... <prefix>-N-<suffix>.
        self.hosts = [
            '{}-{}-{}'.format(CONTAINER_PREFIX, i, suffix)
            for i in range(1, self.instance_count + 1)
        ]
        self.container_root = None
        self.container = None
示例#3
0
    def __init__(
        self,
        instance_type,
        instance_count,
        image,
        sagemaker_session=None,
        container_entrypoint=None,
        container_arguments=None,
    ):
        """Create a SageMakerContainer for running jobs in Local Mode.

        A :class:`sagemaker.session.Session` is kept for reading user
        configuration (for example the default SageMaker S3 bucket); this class
        itself never calls SageMaker APIs.

        Args:
            instance_type (str): Either 'local' or 'local_gpu'.
            instance_count (int): How many containers (hosts) to create.
            image (str): Docker image the containers run.
            sagemaker_session (sagemaker.session.Session): Session used for
                SageMaker interaction; a ``LocalSession`` is built when omitted.
            container_entrypoint (str): Entrypoint executed in the container.
            container_arguments (str): Arguments passed to that entrypoint.

        Raises:
            ImportError: if the 'docker-compose' executable cannot be found.
        """
        from sagemaker.local.local_session import LocalSession

        # Local Mode is driven by docker-compose; fail fast when it is missing.
        compose_executable = find_executable("docker-compose")
        if compose_executable is None:
            raise ImportError(
                "'docker-compose' is not installed. "
                "Local Mode features will not work without docker-compose. "
                "For more information on how to install 'docker-compose', please, see "
                "https://docs.docker.com/compose/install/")

        if sagemaker_session:
            self.sagemaker_session = sagemaker_session
        else:
            self.sagemaker_session = LocalSession()
        self.instance_type = instance_type
        self.instance_count = instance_count
        self.image = image
        self.container_entrypoint = container_entrypoint
        self.container_arguments = container_arguments

        # Every container joins one shared docker network, so the names get a
        # random 5-character suffix; that lets several jobs run side by side.
        alphabet = string.ascii_lowercase + string.digits
        suffix = "".join(random.choice(alphabet) for _ in range(5))
        host_numbers = range(1, self.instance_count + 1)
        self.hosts = [
            "{}-{}-{}".format(CONTAINER_PREFIX, number, suffix)
            for number in host_numbers
        ]

        self.container_root = None
        self.container = None
示例#4
0
    def __init__(
        self,
        model_db_client: ModelDbClient,
        experiment_id,
        model_id,
        image=None,
        role=None,
        instance_config=None,
        boto_session=None,
        algor_config=None,
        train_state=None,
        evaluation_job_name=None,
        eval_state=None,
        eval_scores=None,
        input_model_id=None,
        rl_estimator=None,
        input_data_s3_prefix=None,
        manifest_file_path=None,
        eval_data_s3_path=None,
        s3_model_output_path=None,
        training_start_time=None,
        training_end_time=None,
    ):
        """Initialize a model entity in the current experiment

        Args:
            model_db_client (ModelDBClient): A DynamoDB client
                to query the model table. The 'Model' entity uses this client
                to read/update the model state.
            experiment_id (str): A unique id for the experiment. The created/loaded
                model will be associated with the given experiment.
            model_id (str): A unique id for the model. The model table uses
                model id to manage associated model metadata.
            image (str): The container image to use for training/evaluation.
            role (str): An AWS IAM role (either name or full ARN). The Amazon
                SageMaker training jobs will use this role to access AWS resources.
            instance_config (dict): A dictionary that specifies the resource
                configuration for the model training/evaluation job. Keys read
                here: "instance_type" (default "local") and "instance_count"
                (default 1). Defaults to an empty dict.
            boto_session (boto3.session.Session): A session stores configuration
                state and allows you to create service clients and resources.
                A new ``boto3.Session()`` is created when omitted.
            algor_config (dict): A dictionary that specifies the algorithm type
                and hyper parameters of the training/evaluation job. The key
                "algorithms_parameters" (default {}) is read here. Defaults to
                an empty dict.
            train_state (str): State of the model training job.
            evaluation_job_name (str): Job name for Latest Evaluation Job for this model
            eval_state (str): State of the model evaluation job.
            eval_scores (dict): Initial evaluation scores stored in the model
                record. Defaults to an empty dict.
            input_model_id (str): A unique model id to specify which model to use
                as a pre-trained model for the model training job.
            rl_estimator (sagemaker.rl.estimator.RLEstimator): A Sagemaker RLEstimator
                entity that handles Reinforcement Learning (RL) execution within
                a SageMaker Training Job. Stored as ``self.rl_estimator``.
            input_data_s3_prefix (str): Input data path for the data source of the
                model training job.
            manifest_file_path (str): Manifest file used to provide training data.
            eval_data_s3_path (str): Data path used by the evaluation job.
            s3_model_output_path (str): Output data path of model artifact for the
                model training job.
            training_start_time (str): Starting timestamp of the model training job.
            training_end_time (str): Finished timestamp of the model training job.

        Raises:
            UnhandledWorkflowException: if creating the model record fails for
                any reason other than the record already existing.
        """
        # Fresh containers per instance: mutable default arguments would be
        # shared (and could be mutated) across every ModelManager.
        if instance_config is None:
            instance_config = {}
        if algor_config is None:
            algor_config = {}
        if eval_scores is None:
            eval_scores = {}

        self.model_db_client = model_db_client
        self.experiment_id = experiment_id
        self.model_id = model_id

        # Currently we are not storing image/role and other model params in ModelDb
        self.image = image
        self.role = role
        self.instance_config = instance_config
        self.algor_config = algor_config
        # Previously accepted but silently discarded; keep what the caller passed.
        self.rl_estimator = rl_estimator

        # load configs
        self.instance_type = self.instance_config.get("instance_type", "local")
        self.instance_count = self.instance_config.get("instance_count", 1)
        self.algor_params = self.algor_config.get("algorithms_parameters", {})

        # create a local ModelRecord object.
        self.model_record = ModelRecord(
            experiment_id,
            model_id,
            train_state,
            evaluation_job_name,
            eval_state,
            eval_scores,
            input_model_id,
            input_data_s3_prefix,
            manifest_file_path,
            eval_data_s3_path,
            s3_model_output_path,
            training_start_time,
            training_end_time,
        )

        # try to save this record file. if it throws RecordAlreadyExistsException
        # reload the record from ModelDb, and recreate
        try:
            self.model_db_client.create_new_model_record(
                self.model_record.to_ddb_record())
        except RecordAlreadyExistsException:
            logger.debug("Model already exists. Reloading from model record.")
            model_record = self.model_db_client.get_model_record(
                experiment_id, model_id)
            self.model_record = ModelRecord.load_from_ddb_record(model_record)
        except Exception as e:
            logger.error("Unhandled Exception! " + str(e))
            # Chain the original error so its traceback is not lost.
            raise UnhandledWorkflowException(
                "Something went wrong while creating a new model") from e

        if boto_session is None:
            boto_session = boto3.Session()
        self.boto_session = boto_session

        # 'local' instances run through a LocalSession; everything else uses a
        # real SageMaker session bound to the boto session.
        if self.instance_type == "local":
            self.sagemaker_session = LocalSession()
        else:
            self.sagemaker_session = sagemaker.session.Session(
                self.boto_session)
        self.sagemaker_client = self.sagemaker_session.sagemaker_client
示例#5
0
class ModelManager:
    """A model entity with the given experiment. This class will handle
    the model creation, model training, model evaluation and model metadata
    management.
    """
    def __init__(
        self,
        model_db_client: ModelDbClient,
        experiment_id,
        model_id,
        image=None,
        role=None,
        instance_config=None,
        boto_session=None,
        algor_config=None,
        train_state=None,
        evaluation_job_name=None,
        eval_state=None,
        eval_scores=None,
        input_model_id=None,
        rl_estimator=None,
        input_data_s3_prefix=None,
        manifest_file_path=None,
        eval_data_s3_path=None,
        s3_model_output_path=None,
        training_start_time=None,
        training_end_time=None,
    ):
        """Initialize a model entity in the current experiment

        Args:
            model_db_client (ModelDBClient): A DynamoDB client
                to query the model table. The 'Model' entity uses this client
                to read/update the model state.
            experiment_id (str): A unique id for the experiment. The created/loaded
                model will be associated with the given experiment.
            model_id (str): A unique id for the model. The model table uses
                model id to manage associated model metadata.
            image (str): The container image to use for training/evaluation.
            role (str): An AWS IAM role (either name or full ARN). The Amazon
                SageMaker training jobs will use this role to access AWS resources.
            instance_config (dict): A dictionary that specifies the resource
                configuration for the model training/evaluation job. Keys read
                here: "instance_type" (default "local") and "instance_count"
                (default 1). Defaults to an empty dict.
            boto_session (boto3.session.Session): A session stores configuration
                state and allows you to create service clients and resources.
                A new ``boto3.Session()`` is created when omitted.
            algor_config (dict): A dictionary that specifies the algorithm type
                and hyper parameters of the training/evaluation job. The key
                "algorithms_parameters" (default {}) is read here. Defaults to
                an empty dict.
            train_state (str): State of the model training job.
            evaluation_job_name (str): Job name for Latest Evaluation Job for this model
            eval_state (str): State of the model evaluation job.
            eval_scores (dict): Initial evaluation scores stored in the model
                record. Defaults to an empty dict.
            input_model_id (str): A unique model id to specify which model to use
                as a pre-trained model for the model training job.
            rl_estimator (sagemaker.rl.estimator.RLEstimator): A Sagemaker RLEstimator
                entity that handles Reinforcement Learning (RL) execution within
                a SageMaker Training Job. Stored as ``self.rl_estimator``.
            input_data_s3_prefix (str): Input data path for the data source of the
                model training job.
            manifest_file_path (str): Manifest file used to provide training data.
            eval_data_s3_path (str): Data path used by the evaluation job.
            s3_model_output_path (str): Output data path of model artifact for the
                model training job.
            training_start_time (str): Starting timestamp of the model training job.
            training_end_time (str): Finished timestamp of the model training job.

        Raises:
            UnhandledWorkflowException: if creating the model record fails for
                any reason other than the record already existing.
        """
        # Fresh containers per instance: mutable default arguments would be
        # shared (and could be mutated) across every ModelManager.
        if instance_config is None:
            instance_config = {}
        if algor_config is None:
            algor_config = {}
        if eval_scores is None:
            eval_scores = {}

        self.model_db_client = model_db_client
        self.experiment_id = experiment_id
        self.model_id = model_id

        # Currently we are not storing image/role and other model params in ModelDb
        self.image = image
        self.role = role
        self.instance_config = instance_config
        self.algor_config = algor_config
        # Previously accepted but silently discarded; keep what the caller passed.
        self.rl_estimator = rl_estimator

        # load configs
        self.instance_type = self.instance_config.get("instance_type", "local")
        self.instance_count = self.instance_config.get("instance_count", 1)
        self.algor_params = self.algor_config.get("algorithms_parameters", {})

        # Evaluation bookkeeping normally set by evaluate(). Initialized here so
        # polling the evaluation state on a reloaded ModelManager does not fail
        # with AttributeError; the mode follows the configured instance type.
        self.local_mode = self.instance_type == "local"
        self.log_output = None

        # create a local ModelRecord object.
        self.model_record = ModelRecord(
            experiment_id,
            model_id,
            train_state,
            evaluation_job_name,
            eval_state,
            eval_scores,
            input_model_id,
            input_data_s3_prefix,
            manifest_file_path,
            eval_data_s3_path,
            s3_model_output_path,
            training_start_time,
            training_end_time,
        )

        # try to save this record file. if it throws RecordAlreadyExistsException
        # reload the record from ModelDb, and recreate
        try:
            self.model_db_client.create_new_model_record(
                self.model_record.to_ddb_record())
        except RecordAlreadyExistsException:
            logger.debug("Model already exists. Reloading from model record.")
            model_record = self.model_db_client.get_model_record(
                experiment_id, model_id)
            self.model_record = ModelRecord.load_from_ddb_record(model_record)
        except Exception as e:
            logger.error("Unhandled Exception! " + str(e))
            # Chain the original error so its traceback is not lost.
            raise UnhandledWorkflowException(
                "Something went wrong while creating a new model") from e

        if boto_session is None:
            boto_session = boto3.Session()
        self.boto_session = boto_session

        # 'local' instances run through a LocalSession; everything else uses a
        # real SageMaker session bound to the boto session.
        if self.instance_type == "local":
            self.sagemaker_session = LocalSession()
        else:
            self.sagemaker_session = sagemaker.session.Session(
                self.boto_session)
        self.sagemaker_client = self.sagemaker_session.sagemaker_client

    def _jsonify(self):
        """Return a JSON Dict with metadata of the ModelManager Object stored in
        self.model_record
        """
        return self.model_record.to_ddb_record()

    @classmethod
    def name_next_model(cls, experiment_id):
        """Generate unique model id of a new model in the experiment

        Args:
            experiment_id (str): A unique id for the experiment. The created/loaded
                model will be associated with the given experiment.

        Returns:
            str: A unique id for a new model, of the form
            "<experiment_id>-model-id-<unix seconds>".
        """
        return experiment_id + "-model-id-" + str(int(time.time()))

    def _get_rl_estimator_args(self, eval=False):
        """Get required args to be used by RLEstimator class

        Args:
            eval (boolean): Boolean value to tell if the estimator is
                running a training/evaluation job.

        Return:
            dict: RLEstimator args used to trigger a SageMaker training job.
            ``hyperparameters`` is a fresh copy, so callers may modify it
            without touching ``self.algor_config``.
        """
        entry_point = "eval-cfa-vw.py" if eval else "train-vw.py"
        estimator_type = "Evaluation" if eval else "Training"
        job_types = "evaluation_jobs" if eval else "training_jobs"

        sagemaker_bucket = self.sagemaker_session.default_bucket()
        output_path = f"s3://{sagemaker_bucket}/{self.experiment_id}/{job_types}/"

        # Emitted by the VW scripts; used by SageMaker to extract the metric.
        metric_definitions = [{
            "Name":
            "average_loss",
            "Regex":
            "average loss = ([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?).*$"
        }]

        args = dict(
            entry_point=entry_point,
            source_dir="src",
            dependencies=["common/sagemaker_rl"],
            image_uri=self.image,
            role=self.role,
            sagemaker_session=self.sagemaker_session,
            instance_type=self.instance_type,
            instance_count=self.instance_count,
            metric_definitions=metric_definitions,
            # Copy: evaluate() adds keys to this dict, which previously leaked
            # into self.algor_config and every later training job.
            hyperparameters=dict(self.algor_params or {}),
            output_path=output_path,
            code_location=output_path.strip("/"),
        )

        if self.instance_type == "local":
            logger.info(
                f"{estimator_type} job will be executed in 'local' mode")
        else:
            logger.info(
                f"{estimator_type} job will be executed in 'SageMaker' mode")
        return args

    def _fit_first_model(self,
                         input_data_s3_prefix=None,
                         manifest_file_path=None,
                         wait=False,
                         logs=True):
        """A Estimator fit() call to initiate the first model of the experiment
        (no pre-trained model channel).

        Args:
            input_data_s3_prefix (str): S3 location of the training data; used
                when no manifest file is given.
            manifest_file_path (str): Manifest file providing the training data.
            wait (bool): Whether the call should wait until the job completes.
            logs (bool): Whether to show the logs produced by the job.
        """

        rl_estimator_args = self._get_rl_estimator_args()
        self.rl_estimator = RLEstimator(**rl_estimator_args)

        if manifest_file_path:
            input_data = sagemaker.session.s3_input(
                s3_data=manifest_file_path,
                input_mode="File",
                s3_data_type="ManifestFile")
            self.rl_estimator.fit(job_name=self.model_id,
                                  inputs=input_data,
                                  wait=wait,
                                  logs=logs)
        else:
            self.rl_estimator.fit(job_name=self.model_id,
                                  inputs=input_data_s3_prefix,
                                  wait=wait,
                                  logs=logs)

    def fit(self,
            input_model_id=None,
            input_data_s3_prefix=None,
            manifest_file_path=None,
            wait=False,
            logs=True):
        """A Estimator fit() call to start a model training job.

        Args:
            input_model_id (str): Model id of model to used as pre-trained model of the training job
            input_data_s3_prefix (str): Defines the location of s3 data to train on.
            manifest_file_path (str): Manifest file used to provide training data.
            wait (bool): Whether the call should wait until the job completes. Only
                meaningful when running in SageMaker mode.
            logs (bool): Whether to show the logs produced by the job.
                Only meaningful when wait is True (default: True).
        """
        # update object var, to be reflected in DDb Record as well.
        self.model_record.add_new_training_job_info(
            input_model_id=input_model_id,
            input_data_s3_prefix=input_data_s3_prefix,
            manifest_file_path=manifest_file_path,
        )
        self.model_db_client.update_model_record(self._jsonify())

        if input_model_id is None:
            self._fit_first_model(input_data_s3_prefix=input_data_s3_prefix,
                                  manifest_file_path=manifest_file_path,
                                  wait=wait,
                                  logs=logs)
        else:
            # use 'input_model_id' as pretrained model for training
            input_model_record = self.model_db_client.get_model_record(
                self.experiment_id, input_model_id)
            model_artifact_path = input_model_record.get(
                "s3_model_output_path")
            rl_estimator_args = self._get_rl_estimator_args()
            rl_estimator_args["model_channel_name"] = "pretrained_model"
            rl_estimator_args["model_uri"] = model_artifact_path
            self.rl_estimator = RLEstimator(**rl_estimator_args)

            if manifest_file_path:
                inputs = sagemaker.session.s3_input(
                    s3_data=manifest_file_path, s3_data_type="ManifestFile")
            else:
                inputs = input_data_s3_prefix

            self.rl_estimator.fit(job_name=self.model_id,
                                  inputs=inputs,
                                  wait=wait,
                                  logs=logs)

    def evaluate(
        self,
        input_data_s3_prefix=None,
        manifest_file_path=None,
        evaluation_job_name=None,
        local_mode=True,
        wait=False,
        logs=True,
    ):
        """A Estimator fit() call to start a model evaluation job.

        Args:
            input_data_s3_prefix (str): Defines the location of s3 data used for evaluation
            manifest_file_path (str): Manifest file used to provide evaluation data.
            evaluation_job_name (str): Unique Sagemaker job name to identify the evaluation job
            local_mode (bool): Whether the evaluation job is running on local mode
            wait (bool): Whether the call should wait until the job completes. Only
                meaningful when running in SageMaker mode.
            logs (bool): Whether to show the logs produced by the job.
                Only meaningful when wait is True.
        """
        # use self.model_id, self._s3_model_output_path as the model to evaluate
        # Model object has already been initialized with up-to-date DDb record.
        model_artifact_path = self.model_record.get_model_artifact_path()
        rl_estimator_args = self._get_rl_estimator_args(eval=True)
        rl_estimator_args["model_channel_name"] = "pretrained_model"
        rl_estimator_args["model_uri"] = model_artifact_path

        if manifest_file_path:
            inputs = sagemaker.session.s3_input(s3_data=manifest_file_path,
                                                s3_data_type="ManifestFile")
            if local_mode:
                # Safe: 'hyperparameters' is a per-call copy (see _get_rl_estimator_args).
                rl_estimator_args["hyperparameters"].update(
                    {"local_mode_manifest": True})

        else:
            inputs = input_data_s3_prefix

        # (dict[str, str] or dict[str, sagemaker.session.s3_input]) for evaluation channel
        eval_channel_inputs = {EVAL_CHANNEL: inputs}
        self.rl_estimator = RLEstimator(**rl_estimator_args)

        # update to save eval_data_s3_path in DDb as well, or
        # update to read from SM describe call... maybe will not work in local mode but.
        eval_data_s3_path = manifest_file_path if (
            manifest_file_path is not None) else input_data_s3_prefix

        # we keep eval job state as pending, before the SM job has been submitted.
        # the syncer function should update this state, based on SM job status.
        self.model_record.add_new_evaluation_job_info(
            evaluation_job_name=evaluation_job_name,
            eval_data_s3_path=eval_data_s3_path)
        self.model_db_client.update_model_record(self._jsonify())

        # The following local variables (unsaved to DDb) make evaluation job non-resumable.
        self.log_output = None
        self.local_mode = local_mode

        if local_mode:
            # Capture eval score by regex expression
            # log should contain only one "average loss = some number" pattern
            with CaptureStdout() as log_output:
                self.rl_estimator.fit(job_name=evaluation_job_name,
                                      inputs=eval_channel_inputs,
                                      wait=wait,
                                      logs=logs)

            self.log_output = "\n".join(log_output)
            logger.debug(self.log_output)
        else:
            self.rl_estimator.fit(job_name=evaluation_job_name,
                                  inputs=eval_channel_inputs,
                                  wait=wait,
                                  logs=logs)

    def update_model_training_state(self):
        """Refresh the training state of this model in the model table."""
        self._update_model_table_training_states()

    def update_model_evaluation_state(self):
        """Refresh the evaluation state of this model in the model table."""
        self._update_model_table_evaluation_states()

    def _update_model_table_training_states(self):
        """
        Update the training states in the model table. This method
        will poll the Sagemaker training job (named ``self.model_id``) and
        then update training job metadata of the model, including:
            train_state,
            s3_model_output_path,
            training_start_time,
            training_end_time

        DescribeTrainingJob failures containing "ValidationException" are
        retried up to ``max_describe_retries`` times, sleeping between tries;
        once those attempts are exhausted the model is marked as failed. Any
        other describe error is logged and ignored so a later poll can retry.

        Returns:
            dict or None: the model's DDb record on the early-return paths
            (terminal state, ignored describe error); None otherwise.
        """
        if self.model_record.model_in_terminal_state():
            # model already in one of the final states
            # need not do anything.
            self.model_db_client.update_model_record(self._jsonify())
            return self._jsonify()

        # Else, try and fetch updated SageMaker TrainingJob status
        sm_job_info = {}

        max_describe_retries = 100
        sleep_between_describe_retries = 10

        for attempt in range(max_describe_retries):
            try:
                sm_job_info = self.sagemaker_client.describe_training_job(
                    TrainingJobName=self.model_id)
            except Exception as e:
                if "ValidationException" in str(e):
                    if attempt >= max_describe_retries - 1:
                        # max attempts for DescribeTrainingJob.  Fail with ValidationException
                        logger.warning(
                            f"Looks like SageMaker Job was not submitted successfully."
                            f" Failing Training Job with ModelId {self.model_id}"
                        )
                        self.model_record.update_model_as_failed()
                        self.model_db_client.update_model_as_failed(
                            self._jsonify())
                        return
                    time.sleep(sleep_between_describe_retries)
                    continue
                # Do not raise exception, most probably throttling.
                logger.warning(
                    f"Failed to check SageMaker Training Job state for ModelId {self.model_id}."
                    " This exception will be ignored, and retried.")
                logger.debug(e)
                time.sleep(sleep_between_describe_retries)
                return self._jsonify()
            else:
                # Describe succeeded; previously the loop kept calling it 100 times.
                break

        train_state = sm_job_info.get("TrainingJobStatus", "Pending")
        training_start_time = sm_job_info.get("TrainingStartTime", None)
        training_end_time = sm_job_info.get("TrainingEndTime", None)

        if training_start_time is not None:
            training_start_time = training_start_time.strftime(
                "%Y-%m-%d %H:%M:%S")
        if training_end_time is not None:
            training_end_time = training_end_time.strftime("%Y-%m-%d %H:%M:%S")

        model_artifacts = sm_job_info.get("ModelArtifacts", None)
        if model_artifacts is not None:
            s3_model_output_path = model_artifacts.get("S3ModelArtifacts",
                                                       None)
        else:
            s3_model_output_path = None

        self.model_record.update_model_job_status(training_start_time,
                                                  training_end_time,
                                                  train_state,
                                                  s3_model_output_path)

        self.model_db_client.update_model_job_state(self._jsonify())

    def _update_model_table_evaluation_states(self):
        """Update the evaluation states in the model table. This method
        will poll the Sagemaker evaluation job (named by the model record's
        evaluation job name) and then update evaluation job metadata of the
        model, including:
            eval_state,
            eval_scores

        Retry behavior for DescribeTrainingJob matches
        ``_update_model_table_training_states``. When the job is "Completed",
        the score is parsed from the captured log in local mode, or fetched
        via ``TrainingJobAnalytics`` (at most 4 attempts) otherwise; "n.a."
        is stored when no score can be obtained.

        Returns:
            dict or None: the model's DDb record on the early-return paths
            (terminal state, ignored describe error); None otherwise.
        """

        if self.model_record.eval_in_terminal_state():
            self.model_db_client.update_model_record(self._jsonify())
            return self._jsonify()

        # Try and fetch updated SageMaker Training Job Status
        sm_eval_job_info = {}

        max_describe_retries = 100
        sleep_between_describe_retries = 10

        for attempt in range(max_describe_retries):
            try:
                sm_eval_job_info = self.sagemaker_client.describe_training_job(
                    TrainingJobName=self.model_record._evaluation_job_name)
            except Exception as e:
                if "ValidationException" in str(e):
                    logger.debug(e)
                    if attempt >= max_describe_retries - 1:
                        # Last attempt for DescribeTrainingJob with validation failure
                        logger.warning(
                            "Looks like SageMaker Job was not submitted successfully."
                            f" Failing EvaluationJob {self.model_record._evaluation_job_name}"
                        )
                        self.model_record.update_eval_job_as_failed()
                        self.model_db_client.update_model_eval_as_failed(
                            self._jsonify())
                        return
                    time.sleep(sleep_between_describe_retries)
                    continue
                # Do not raise exception, most probably throttling.
                logger.warning(
                    "Failed to check SageMaker Training Job state for EvaluationJob: "
                    f" {self.model_record._evaluation_job_name}. This exception will be ignored,"
                    " and retried.")
                time.sleep(sleep_between_describe_retries)
                return self._jsonify()
            else:
                # Describe succeeded; stop polling.
                break

        eval_state = sm_eval_job_info.get("TrainingJobStatus", "Pending")
        if eval_state == "Completed":
            eval_score = "n.a."

            if self.local_mode:
                rgx = re.compile(
                    "average loss = ([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?).*$",
                    re.M)
                # log_output is only captured by evaluate() in this process;
                # treat a missing log as "no score" instead of crashing.
                eval_score_rgx = rgx.findall(self.log_output or "")

                if len(eval_score_rgx) == 0:
                    logger.warning("No eval score available from vw job log.")
                else:
                    eval_score = eval_score_rgx[0][0]  # [('eval_score', '')]
            else:
                max_metric_attempts = 4
                attempts = 0
                while eval_score == "n.a." and attempts < max_metric_attempts:
                    # Count every attempt, including failed ones; previously a
                    # persistent error here looped forever.
                    attempts += 1
                    try:
                        metric_df = TrainingJobAnalytics(
                            self.model_record._evaluation_job_name,
                            ["average_loss"]).dataframe()
                        eval_score = str(metric_df[metric_df["metric_name"] ==
                                                   "average_loss"]["value"].iloc[0])
                    except Exception:
                        # to avoid throttling
                        time.sleep(5)
            self.model_record._eval_state = eval_state
            self.model_record.add_model_eval_scores(eval_score)
            self.model_db_client.update_model_eval_job_state(self._jsonify())
        else:
            # update eval state via ddb client
            self.model_record.update_eval_job_state(eval_state)
            self.model_db_client.update_model_eval_job_state(self._jsonify())
class EstimatorBase(with_metaclass(ABCMeta, object)):
    """Handle end-to-end Amazon SageMaker training and deployment tasks.

    For introduction to model training and deployment, see
    http://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-training.html

    Subclasses must define a way to determine what image to use for training,
    what hyperparameters to use, and how to create an appropriate predictor instance.
    """

    def __init__(self, role, train_instance_count, train_instance_type,
                 train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File',
                 output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None):
        """Initialize an ``EstimatorBase`` instance.

        Args:
            role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs
                that create Amazon SageMaker endpoints use this role to access training data and model artifacts.
                After the endpoint is created, the inference code might use the IAM role,
                if it needs to access an AWS resource.
            train_instance_count (int): Number of Amazon EC2 instances to use for training.
            train_instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'.
                The special values 'local' and 'local_gpu' switch the estimator into local mode.
            train_volume_size (int): Size in GB of the EBS volume to use for storing input data
                during training (default: 30). Must be large enough to store training data if File Mode is used
                (which is the default).
            train_max_run (int): Timeout in seconds for training (default: 24 * 60 * 60).
                After this amount of time Amazon SageMaker terminates the job regardless of its current status.
            input_mode (str): The input mode that the algorithm supports (default: 'File'). Valid modes:
                'File' - Amazon SageMaker copies the training dataset from the S3 location to a local directory.
                'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a Unix-named pipe.
            output_path (str): S3 location for saving the training result (model artifacts and output files).
                If not specified, results are stored to a default bucket. If the bucket with the specific name
                does not exist, the estimator creates the bucket during the
                :meth:`~sagemaker.estimator.EstimatorBase.fit` method execution.
            output_kms_key (str): Optional. KMS key ID for encrypting the training output (default: None).
            base_job_name (str): Prefix for training job name when the :meth:`~sagemaker.estimator.EstimatorBase.fit`
                method launches. If not specified, the estimator generates a default job name, based on
                the training image name and current timestamp.
            sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
                Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
                using the default AWS configuration chain. In local mode ('local' / 'local_gpu'), a supplied
                ``LocalSession`` is used as-is; any other supplied session is replaced by a new ``LocalSession``.

        Raises:
            RuntimeError: If ``train_instance_type`` is 'local_gpu' and ``train_instance_count`` is greater than 1.
        """
        self.role = role
        self.train_instance_count = train_instance_count
        self.train_instance_type = train_instance_type
        self.train_volume_size = train_volume_size
        self.train_max_run = train_max_run
        self.input_mode = input_mode

        if self.train_instance_type in ('local', 'local_gpu'):
            self.local_mode = True
            if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1:
                raise RuntimeError("Distributed Training in Local GPU is not supported")

            # Honor a caller-supplied LocalSession (it may carry custom local configuration). Any other
            # session (or None) is still replaced by a fresh LocalSession, exactly as before, so callers that
            # pass a regular Session together with a local instance type keep their current behavior.
            if isinstance(sagemaker_session, LocalSession):
                self.sagemaker_session = sagemaker_session
            else:
                self.sagemaker_session = LocalSession()
        else:
            self.local_mode = False
            self.sagemaker_session = sagemaker_session or Session()

        self.base_job_name = base_job_name
        self._current_job_name = None
        self.output_path = output_path
        self.output_kms_key = output_kms_key
        self.latest_training_job = None

    @abstractmethod
    def train_image(self):
        """Docker image URI that training runs in; every concrete subclass must supply it.

        :meth:`~sagemaker.estimator.EstimatorBase.fit` calls this to derive a default job name when neither
        ``job_name`` nor ``base_job_name`` is given. The training-job launch presumably also reads it to
        pick the container image.

        Returns:
            str: The URI of the Docker image.
        """

    @abstractmethod
    def hyperparameters(self):
        """Hyperparameters for the training job, as a mapping of names to values.

        Concrete subclasses must implement this. It is presumably read when the training job is launched
        by :meth:`~sagemaker.estimator.EstimatorBase.fit`.

        Returns:
            dict[str, str]: The hyperparameters.
        """

    def fit(self, inputs, wait=True, logs=True, job_name=None):
        """Start a training job on the given input data and, by default, wait for it to finish.

        The job is launched through ``_TrainingJob.start_new`` using this estimator's configuration and
        ``inputs``; the resulting job object is stored in ``latest_training_job``. Once training has completed
        successfully, ``deploy()`` can host the model on Amazon SageMaker.

        Args:
            inputs (str or dict or sagemaker.session.s3_input): Information about the training data.
                This can be one of three types:

                * (str) the S3 location where training data is saved.

                * (dict[str, str] or dict[str, sagemaker.session.s3_input]) If using multiple channels for
                    training data, you can specify a dict mapping channel names
                    to strings or :func:`~sagemaker.session.s3_input` objects.
                * (sagemaker.session.s3_input) - channel configuration for S3 data sources that can provide
                    additional information about the training dataset. See :func:`sagemaker.session.s3_input`
                    for full details.
            wait (bool): If True (default), block until the job completes; if False, return right after launch.
            logs (bool): Whether to show the logs produced by the job.
                Only meaningful when wait is True (default: True).
            job_name (str): Training job name. If not given, a fresh unique name is built from ``base_job_name``
                or, when that is unset, from the training image name, plus the current timestamp.

        Side effects:
            Sets ``_current_job_name`` and ``latest_training_job``; fills in ``output_path`` with
            ``s3://<session default bucket>/`` if it was None.
        """
        if job_name is None:
            # A new name on every call keeps repeated fits from colliding; base_job_name wins over the image.
            prefix = self.base_job_name or base_name_from_image(self.train_image())
            job_name = name_from_base(prefix)
        self._current_job_name = job_name

        # Lazily fall back to the session's default bucket for outputs.
        if self.output_path is None:
            bucket = self.sagemaker_session.default_bucket()
            self.output_path = 's3://{}/'.format(bucket)

        job = _TrainingJob.start_new(self, inputs)
        self.latest_training_job = job
        if wait:
            job.wait(logs=logs)

    @classmethod
    def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
        """Build an instance of the calling estimator class from the data of an existing training job.

        The base class has no implementation; subclasses that support this must override it.

        Args:
            init_params (dict): Constructor arguments recovered from the training job.
            hyperparameters (dict): Hyperparameters the training job ran with.
            image (str): Container image (if any) the training job ran with.
            sagemaker_session (sagemaker.session.Session): Session to hand to the new estimator.

        Returns:
            An instance of the calling Estimator class (in overriding subclasses).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    @classmethod
    def attach(cls, training_job_name, sagemaker_session=None, job_details=None):
        """Attach to an existing training job.

        Create an Estimator bound to an existing training job, each subclass is responsible to implement
        ``_prepare_init_params_from_job_description()`` as this method delegates the actual conversion of a training
        job description to the arguments that the class constructor expects. After attaching, if the training job has a
        Complete status, it can be ``deploy()`` ed to create a SageMaker Endpoint and return a ``Predictor``.

        If the training job is in progress, attach will block and display log messages
        from the training job, until the training job completes.

        Args:
            training_job_name (str): The name of the training job to attach to.
            sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
                Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
                using the default AWS configuration chain.
            job_details (dict): Optional. A ``describe_training_job`` response for ``training_job_name`` that the
                caller already holds. If given, it is used instead of issuing another describe call.

        Examples:
            >>> my_estimator.fit(inputs, wait=False)
            >>> training_job_name = my_estimator.latest_training_job.name
            Later on:
            >>> attached_estimator = Estimator.attach(training_job_name)
            >>> attached_estimator.deploy()

        Returns:
            Instance of the calling ``Estimator`` Class with the attached training job.
        """
        sagemaker_session = sagemaker_session or Session()

        # Previously job_details was silently overwritten; only fetch the description when it was not supplied.
        if job_details is None:
            job_details = sagemaker_session.sagemaker_client.describe_training_job(
                TrainingJobName=training_job_name)
        init_params = cls._prepare_init_params_from_job_description(job_details)

        estimator = cls(sagemaker_session=sagemaker_session, **init_params)
        # Bind to the job the caller asked for, independent of how a subclass shapes 'base_job_name'.
        estimator.latest_training_job = _TrainingJob(sagemaker_session=sagemaker_session,
                                                     training_job_name=training_job_name)
        estimator.latest_training_job.wait()
        return estimator

    def deploy(self, initial_instance_count, instance_type, endpoint_name=None, **kwargs):
        """Host the trained model on an Amazon SageMaker endpoint and return its predictor.

        More information:
        http://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-training.html

        The model is built with ``create_model(**kwargs)`` and its ``deploy`` is invoked with the given
        instance settings. The chosen ``instance_type`` is also recorded on ``self.deploy_instance_type``.

        Args:
            initial_instance_count (int): Minimum number of EC2 instances to deploy to an endpoint for prediction.
            instance_type (str): Type of EC2 instance to deploy to an endpoint for prediction,
                for example, 'ml.c4.xlarge'.
            endpoint_name (str): Name to use for creating an Amazon SageMaker endpoint. If not specified, the name of
                the training job is used.
            **kwargs: Forwarded to ``create_model()``, so implementations can customize model creation during deploy.

        Returns:
            sagemaker.predictor.RealTimePredictor: A predictor that provides a ``predict()`` method,
                which can be used to send requests to the Amazon SageMaker endpoint and obtain inferences.

        Raises:
            RuntimeError: If the estimator has no training job yet.
        """
        job = self.latest_training_job
        if not job:
            raise RuntimeError('Estimator has not been fit yet.')

        target_endpoint = endpoint_name or job.name
        self.deploy_instance_type = instance_type
        model = self.create_model(**kwargs)
        return model.deploy(instance_type=instance_type,
                            initial_instance_count=initial_instance_count,
                            endpoint_name=target_endpoint)

    @property
    def model_data(self):
        """str: S3 URI of the model artifacts of the latest training job. Only set if Estimator has been ``fit()``.

        Each access issues a ``describe_training_job`` call for ``latest_training_job.name``.
        """
        client = self.sagemaker_session.sagemaker_client
        description = client.describe_training_job(TrainingJobName=self.latest_training_job.name)
        artifacts = description['ModelArtifacts']
        return artifacts['S3ModelArtifacts']

    @abstractmethod
    def create_model(self, **kwargs):
        """Build a deployable SageMaker ``Model`` from this estimator; concrete subclasses must implement it.

        ``deploy()`` calls this with its extra keyword arguments and then deploys the returned model.

        Args:
            **kwargs: Implementation-specific options for constructing the ``Model``.

        Returns:
            sagemaker.model.Model: A SageMaker ``Model`` object. See :func:`~sagemaker.model.Model` for full details.
        """

    @classmethod
    def _prepare_init_params_from_job_description(cls, job_details):
        """Convert a training-job description into keyword arguments for the class constructor.

        Fields that are optional in the description are read leniently: a missing
        ``OutputDataConfig.KmsKeyId`` maps to ``None`` (the constructor default), missing ``HyperParameters``
        to ``{}``, and a missing ``AlgorithmSpecification.TrainingImage`` to ``None``.

        Args:
            job_details (dict): The returned job details from a describe_training_job API call.

        Returns:
            dict: The transformed init params. Besides ``EstimatorBase.__init__`` arguments it also contains
                ``hyperparameters`` and ``image``, which ``EstimatorBase.__init__`` itself does not accept;
                presumably subclass constructors consume them.

        Raises:
            KeyError: If a required field of the description is missing.
        """
        resource_config = job_details['ResourceConfig']
        algorithm_spec = job_details['AlgorithmSpecification']
        output_config = job_details['OutputDataConfig']

        init_params = dict()

        init_params['role'] = job_details['RoleArn']
        init_params['train_instance_count'] = resource_config['InstanceCount']
        init_params['train_instance_type'] = resource_config['InstanceType']
        init_params['train_volume_size'] = resource_config['VolumeSizeInGB']
        init_params['train_max_run'] = job_details['StoppingCondition']['MaxRuntimeInSeconds']
        init_params['input_mode'] = algorithm_spec['TrainingInputMode']
        init_params['base_job_name'] = job_details['TrainingJobName']
        init_params['output_path'] = output_config['S3OutputPath']
        # Optional in the API response: only present when an output KMS key was configured.
        init_params['output_kms_key'] = output_config.get('KmsKeyId')

        # Optional in the API response as well; absent keys previously raised KeyError.
        init_params['hyperparameters'] = job_details.get('HyperParameters', {})
        init_params['image'] = algorithm_spec.get('TrainingImage')

        return init_params

    def delete_endpoint(self, endpoint_name=None):
        """Delete an Amazon SageMaker ``Endpoint``.

        Args:
            endpoint_name (str): Optional. Name of the endpoint to delete. Defaults to the name of the latest
                training job, which is the name ``deploy()`` uses when it is not given an ``endpoint_name``.
                Pass the same name given to ``deploy(endpoint_name=...)`` to delete a custom-named endpoint.

        Raises:
            ValueError: If ``endpoint_name`` is not given and the estimator has no training job yet.
        """
        if endpoint_name is None:
            if self.latest_training_job is None:
                raise ValueError('Endpoint was not created yet')
            endpoint_name = self.latest_training_job.name
        self.sagemaker_session.delete_endpoint(endpoint_name)