def __init__(
        self,
        model_db_client: ModelDbClient,
        experiment_id,
        model_id,
        image=None,
        role=None,
        instance_config=None,
        boto_session=None,
        algor_config=None,
        train_state=None,
        evaluation_job_name=None,
        eval_state=None,
        eval_scores=None,
        input_model_id=None,
        rl_estimator=None,
        input_data_s3_prefix=None,
        manifest_file_path=None,
        eval_data_s3_path=None,
        s3_model_output_path=None,
        training_start_time=None,
        training_end_time=None,
    ):
        """Initialize a model entity in the current experiment.

        Args:
            model_db_client (ModelDbClient): A DynamoDB client
                to query the model table. The 'Model' entity uses this client
                to read/update the model state.
            experiment_id (str): A unique id for the experiment. The created/loaded
                model will be associated with the given experiment.
            model_id (str): A unique id for the model. The model table uses
                the model id to manage associated model metadata.
            image (str): The container image to use for training/evaluation.
            role (str): An AWS IAM role (either name or full ARN). The Amazon
                SageMaker training jobs will use this role to access AWS resources.
            instance_config (dict): Resource configuration for the model
                training/evaluation job. Defaults to an empty dict.
            boto_session (boto3.session.Session): A session that stores configuration
                state and allows you to create service clients and resources.
            algor_config (dict): Algorithm type and hyperparameters of the
                training/evaluation job. Defaults to an empty dict.
            train_state (str): State of the model training job.
            evaluation_job_name (str): Job name of the latest evaluation job
                for this model.
            eval_state (str): State of the model evaluation job.
            eval_scores (dict): Evaluation scores for this model, keyed by
                evaluation job name. Defaults to an empty dict.
            input_model_id (str): A unique model id to specify which model to use
                as a pre-trained model for the model training job.
            rl_estimator (sagemaker.rl.estimator.RLEstimator): A SageMaker RLEstimator
                entity that handles Reinforcement Learning (RL) execution within
                a SageMaker Training Job.
            input_data_s3_prefix (str): Input data path for the data source of the
                model training job.
            manifest_file_path (str): Manifest file path for the data source of the
                model training job.
            eval_data_s3_path (str): S3 data path of the evaluation data.
            s3_model_output_path (str): Output data path of model artifact for the
                model training job.
            training_start_time (str): Starting timestamp of the model training job.
            training_end_time (str): Finished timestamp of the model training job.

        Raises:
            UnhandledWorkflowException: If creating a new model record fails
                for any reason other than the record already existing.
        """
        # Guard against shared mutable default arguments: a `{}` default is
        # created once at function-definition time and shared across calls.
        instance_config = {} if instance_config is None else instance_config
        algor_config = {} if algor_config is None else algor_config
        eval_scores = {} if eval_scores is None else eval_scores

        self.model_db_client = model_db_client
        self.experiment_id = experiment_id
        self.model_id = model_id

        # Currently we are not storing image/role and other model params in ModelDb
        self.image = image
        self.role = role
        self.instance_config = instance_config
        self.algor_config = algor_config

        # load configs (fall back to a single local instance when unspecified)
        self.instance_type = self.instance_config.get("instance_type", "local")
        self.instance_count = self.instance_config.get("instance_count", 1)
        self.algor_params = self.algor_config.get("algorithms_parameters", {})

        # create a local ModelRecord object.
        self.model_record = ModelRecord(
            experiment_id,
            model_id,
            train_state,
            evaluation_job_name,
            eval_state,
            eval_scores,
            input_model_id,
            input_data_s3_prefix,
            manifest_file_path,
            eval_data_s3_path,
            s3_model_output_path,
            training_start_time,
            training_end_time,
        )

        # try to save this record file. if it throws RecordAlreadyExistsException
        # reload the record from ModelDb, and recreate
        try:
            self.model_db_client.create_new_model_record(
                self.model_record.to_ddb_record())
        except RecordAlreadyExistsException:
            logger.debug("Model already exists. Reloading from model record.")
            model_record = self.model_db_client.get_model_record(
                experiment_id, model_id)
            self.model_record = ModelRecord.load_from_ddb_record(model_record)
        except Exception as e:
            # Chain the original exception so the root cause is preserved
            # in the traceback; use lazy %-formatting for logging.
            logger.error("Unhandled Exception! %s", e)
            raise UnhandledWorkflowException(
                "Something went wrong while creating a new model") from e

        if boto_session is None:
            boto_session = boto3.Session()
        self.boto_session = boto_session

        # "local" instance type trains in local containers via LocalSession
        # instead of on managed SageMaker infrastructure.
        if self.instance_type == "local":
            self.sagemaker_session = LocalSession()
        else:
            self.sagemaker_session = sagemaker.session.Session(
                self.boto_session)
        self.sagemaker_client = self.sagemaker_session.sagemaker_client
# Example #2 (scrape-artifact separator; original marker: "Beispiel #2" / "0")
    def __init__(self,
                 join_db_client: JoinDbClient,
                 experiment_id,
                 join_job_id,
                 current_state=None,
                 input_obs_data_s3_path=None,
                 obs_start_time=None,
                 obs_end_time=None,
                 input_reward_data_s3_path=None,
                 output_joined_train_data_s3_path=None,
                 output_joined_eval_data_s3_path=None,
                 join_query_ids=None,
                 boto_session=None):
        """Initialize a joining job entity in the current experiment.

        Args:
            join_db_client (JoinDbClient): A DynamoDB client
                to query the joining job table. The 'JoinJob' entity uses this
                client to read/update the job state.
            experiment_id (str): A unique id for the experiment. The created/loaded
                joining job will be associated with the experiment.
            join_job_id (str): A unique id for the join job. The join job table uses
                join_job_id to manage associated job metadata.
            current_state (str): Current state of the joining job.
            input_obs_data_s3_path (str): Input S3 data path for observation data.
            obs_start_time (datetime): Datetime object to specify starting time of the
                observation data.
            obs_end_time (datetime): Datetime object to specify ending time of the
                observation data.
            input_reward_data_s3_path (str): S3 data path for rewards data.
            output_joined_train_data_s3_path (str): Output S3 data path for the
                training data split.
            output_joined_eval_data_s3_path (str): Output S3 data path for the
                evaluation data split.
            join_query_ids (list of str): Athena join query ids for the joining
                requests. Defaults to an empty list.
            boto_session (boto3.session.Session): A session that stores configuration
                state and allows you to create service clients and resources.

        Raises:
            UnhandledWorkflowException: If creating a new join job record fails
                for any reason other than the record already existing.
        """
        # Guard against the shared mutable default argument pitfall: a `[]`
        # default is created once and shared across every call.
        join_query_ids = [] if join_query_ids is None else join_query_ids

        self.join_db_client = join_db_client
        self.experiment_id = experiment_id
        self.join_job_id = join_job_id

        if boto_session is None:
            boto_session = boto3.Session()
        self.boto_session = boto_session

        # formatted athena table name
        self.obs_table_partitioned = self._formatted_table_name(
            f"obs-{experiment_id}-partitioned")
        self.obs_table_non_partitioned = self._formatted_table_name(
            f"obs-{experiment_id}")
        self.rewards_table = self._formatted_table_name(
            f"rewards-{experiment_id}")

        self.query_s3_output_bucket = self._create_athena_s3_bucket_if_not_exist(
        )
        self.athena_client = self.boto_session.client("athena")

        # create a local JoinJobRecord object.
        self.join_job_record = JoinJobRecord(
            experiment_id, join_job_id, current_state, input_obs_data_s3_path,
            obs_start_time, obs_end_time, input_reward_data_s3_path,
            output_joined_train_data_s3_path, output_joined_eval_data_s3_path,
            join_query_ids)

        # create obs partitioned/non-partitioned table if not exists;
        # the sentinel path means a local join where no Athena table applies.
        if input_obs_data_s3_path and input_obs_data_s3_path != "local-join-does-not-apply":
            self._create_obs_table_if_not_exist()
        # create reward table if not exists
        if input_reward_data_s3_path and input_reward_data_s3_path != "local-join-does-not-apply":
            self._create_rewards_table_if_not_exist()
        # add partitions if input_obs_time_window is not None
        if obs_start_time and obs_end_time:
            self._add_time_partitions(obs_start_time, obs_end_time)

        # try to save this record file. if it throws RecordAlreadyExistsException
        # reload the record from JoinJobDb, and recreate
        try:
            self.join_db_client.create_new_join_job_record(
                self.join_job_record.to_ddb_record())
        except RecordAlreadyExistsException:
            logger.debug(
                "Join job already exists. Reloading from join job record.")
            join_job_record = self.join_db_client.get_join_job_record(
                experiment_id, join_job_id)
            self.join_job_record = JoinJobRecord.load_from_ddb_record(
                join_job_record)
        except Exception as e:
            # Chain the original exception so the root cause is preserved
            # in the traceback; use lazy %-formatting for logging.
            logger.error("Unhandled Exception! %s", e)
            raise UnhandledWorkflowException(
                "Something went wrong while creating a new join job") from e