def __init__(
        self,
        model_db_client: ModelDbClient,
        experiment_id,
        model_id,
        image=None,
        role=None,
        instance_config=None,
        boto_session=None,
        algor_config=None,
        train_state=None,
        evaluation_job_name=None,
        eval_state=None,
        eval_scores=None,
        input_model_id=None,
        rl_estimator=None,
        input_data_s3_prefix=None,
        manifest_file_path=None,
        eval_data_s3_path=None,
        s3_model_output_path=None,
        training_start_time=None,
        training_end_time=None,
):
    """Initialize a model entity in the current experiment

    Args:
        model_db_client (ModelDbClient): A DynamoDB client
            to query the model table. The 'Model' entity use
            this client to read/update the model state.
        experiment_id (str): A unique id for the experiment.
            The created/loaded model will be associated with
            the given experiment.
        model_id (str): A unique id for the model. The model
            table uses model id to manage associated model metadata.
        image (str): The container image to use for training/evaluation.
        role (str): An AWS IAM role (either name or full ARN).
            The Amazon SageMaker training jobs will use this role
            to access AWS resources.
        instance_config (dict): A dictionary that specify the resource
            configuration for the model training/evaluation job.
            Defaults to an empty dict.
        boto_session (boto3.session.Session): A session stores configuration
            state and allows you to create service clients and resources.
        algor_config (dict): A dictionary that specify the algorithm type
            and hyper parameters of the training/evaluation job.
            Defaults to an empty dict.
        train_state (str): State of the model training job.
        evaluation_job_name (str): Job name for Latest Evaluation Job
            for this model.
        eval_state (str): State of the model evaluation job.
        eval_scores (dict): Scores of past evaluation job(s) for this
            model. Defaults to an empty dict.
        input_model_id (str): A unique model id to specify which model to
            use as a pre-trained model for the model training job.
        rl_estimator (sagemaker.rl.estimator.RLEstimator): A Sagemaker
            RLEstimator entity that handle Reinforcement Learning (RL)
            execution within a SageMaker Training Job.
        input_data_s3_prefix (str): Input data path for the data source of
            the model training job.
        manifest_file_path (str): Manifest file path for the data source
            of the model training job.
        eval_data_s3_path (str): S3 data path for the data source of the
            model evaluation job.
        s3_model_output_path (str): Output data path of model artifact
            for the model training job.
        training_start_time (str): Starting timestamp of the model
            training job.
        training_end_time (str): Finished timestamp of the model
            training job.

    Returns:
        orchestrator.model_manager.ModelManager: A ``Model`` object
        associated with the given experiment.
    """
    # The original signature used mutable defaults ({}), which Python
    # shares across every call; use None sentinels and create fresh
    # dicts per instance instead.
    instance_config = {} if instance_config is None else instance_config
    algor_config = {} if algor_config is None else algor_config
    eval_scores = {} if eval_scores is None else eval_scores

    self.model_db_client = model_db_client
    self.experiment_id = experiment_id
    self.model_id = model_id

    # Currently we are not storing image/role and other model params in ModelDb
    self.image = image
    self.role = role
    self.instance_config = instance_config
    self.algor_config = algor_config

    # load configs (fall back to a cheap local run when unspecified)
    self.instance_type = self.instance_config.get("instance_type", "local")
    self.instance_count = self.instance_config.get("instance_count", 1)
    self.algor_params = self.algor_config.get("algorithms_parameters", {})

    # create a local ModelRecord object.
    self.model_record = ModelRecord(
        experiment_id,
        model_id,
        train_state,
        evaluation_job_name,
        eval_state,
        eval_scores,
        input_model_id,
        input_data_s3_prefix,
        manifest_file_path,
        eval_data_s3_path,
        s3_model_output_path,
        training_start_time,
        training_end_time,
    )

    # try to save this record file. if it throws RecordAlreadyExistsException
    # reload the record from ModelDb, and recreate
    try:
        self.model_db_client.create_new_model_record(
            self.model_record.to_ddb_record())
    except RecordAlreadyExistsException:
        logger.debug("Model already exists. Reloading from model record.")
        model_record = self.model_db_client.get_model_record(
            experiment_id, model_id)
        self.model_record = ModelRecord.load_from_ddb_record(model_record)
    except Exception as e:
        logger.error("Unhandled Exception! " + str(e))
        # chain the root cause so callers can see the original traceback
        raise UnhandledWorkflowException(
            "Something went wrong while creating a new model") from e

    if boto_session is None:
        boto_session = boto3.Session()
    self.boto_session = boto_session

    # "local" instance type runs training in-process via the SageMaker
    # local mode rather than launching a remote training job
    if self.instance_type == "local":
        self.sagemaker_session = LocalSession()
    else:
        self.sagemaker_session = sagemaker.session.Session(
            self.boto_session)
    self.sagemaker_client = self.sagemaker_session.sagemaker_client
def __init__(self,
             join_db_client: JoinDbClient,
             experiment_id,
             join_job_id,
             current_state=None,
             input_obs_data_s3_path=None,
             obs_start_time=None,
             obs_end_time=None,
             input_reward_data_s3_path=None,
             output_joined_train_data_s3_path=None,
             output_joined_eval_data_s3_path=None,
             join_query_ids=None,
             boto_session=None):
    """Initialize a joining job entity in the current experiment

    Args:
        join_db_client (JoinDbClient): A DynamoDB client
            to query the joining job table. The 'JoinJob' entity
            use this client to read/update the job state.
        experiment_id (str): A unique id for the experiment. The created/loaded
            joining job will be associated with the experiment.
        join_job_id (str): A unique id for the join job. The join job table
            uses join_job_id to manage associated job metadata.
        current_state (str): Current state of the joining job
        input_obs_data_s3_path (str): Input S3 data path for observation data
        obs_start_time (datetime): Datetime object to specify starting time
            of the observation data
        obs_end_time (datetime): Datetime object to specify ending time
            of the observation data
        input_reward_data_s3_path (str): S3 data path for rewards data
        output_joined_train_data_s3_path (str): Output S3 data path
            for training data split
        output_joined_eval_data_s3_path (str): Output S3 data path
            for evaluation data split
        join_query_ids (list): Athena join query ids for the joining
            requests. Defaults to an empty list.
        boto_session (boto3.session.Session): A session stores configuration
            state and allows you to create service clients and resources.

    Return:
        orchestrator.join_manager.JoinManager: A ``JoinJob`` object
        associated with the given experiment.
    """
    # The original signature used a mutable default ([]), which Python
    # shares across every call; use a None sentinel and create a fresh
    # list per instance instead.
    join_query_ids = [] if join_query_ids is None else join_query_ids

    self.join_db_client = join_db_client
    self.experiment_id = experiment_id
    self.join_job_id = join_job_id

    if boto_session is None:
        boto_session = boto3.Session()
    self.boto_session = boto_session

    # formatted athena table name
    self.obs_table_partitioned = self._formatted_table_name(
        f"obs-{experiment_id}-partitioned")
    self.obs_table_non_partitioned = self._formatted_table_name(
        f"obs-{experiment_id}")
    self.rewards_table = self._formatted_table_name(
        f"rewards-{experiment_id}")

    self.query_s3_output_bucket = self._create_athena_s3_bucket_if_not_exist()
    self.athena_client = self.boto_session.client("athena")

    # create a local JoinJobRecord object.
    self.join_job_record = JoinJobRecord(
        experiment_id,
        join_job_id,
        current_state,
        input_obs_data_s3_path,
        obs_start_time,
        obs_end_time,
        input_reward_data_s3_path,
        output_joined_train_data_s3_path,
        output_joined_eval_data_s3_path,
        join_query_ids)

    # create obs partitioned/non-partitioned table if not exists;
    # "local-join-does-not-apply" is a sentinel meaning Athena is not used
    if input_obs_data_s3_path and input_obs_data_s3_path != "local-join-does-not-apply":
        self._create_obs_table_if_not_exist()
    # create reward table if not exists
    if input_reward_data_s3_path and input_reward_data_s3_path != "local-join-does-not-apply":
        self._create_rewards_table_if_not_exist()
    # add partitions if input_obs_time_window is not None
    if obs_start_time and obs_end_time:
        self._add_time_partitions(obs_start_time, obs_end_time)

    # try to save this record file. if it throws RecordAlreadyExistsException
    # reload the record from JoinJobDb, and recreate
    try:
        self.join_db_client.create_new_join_job_record(
            self.join_job_record.to_ddb_record())
    except RecordAlreadyExistsException:
        logger.debug(
            "Join job already exists. Reloading from join job record.")
        join_job_record = self.join_db_client.get_join_job_record(
            experiment_id, join_job_id)
        self.join_job_record = JoinJobRecord.load_from_ddb_record(
            join_job_record)
    except Exception as e:
        logger.error("Unhandled Exception! " + str(e))
        # chain the root cause so callers can see the original traceback
        raise UnhandledWorkflowException(
            "Something went wrong while creating a new join job") from e