def get_chkpoint_num(self, agent_key):
    """Return the checkpoint number currently published in S3 for ``agent_key``.

    Returns -1 when a lock file exists under the agent's S3 prefix (the
    trainer holds the lock), or when the downloaded checkpoint state file
    cannot be read. Otherwise the checkpoint state file is downloaded into the
    local checkpoint directory (per-agent subdirectory when there is more than
    one agent) and its ``num`` is returned.

    A botocore ``ClientError`` is reported through ``log_and_exit`` with a 400
    code; any other exception is reported with a 500 code.
    """
    try:
        client = self._get_client()
        bucket = self.params.buckets[agent_key]
        lock_listing = client.list_objects_v2(
            Bucket=bucket,
            Prefix=self._get_s3_key(SyncFiles.LOCKFILE.value, agent_key))
        # A lock object means the trainer is mid-upload: report "no checkpoint".
        if "Contents" in lock_listing:
            return -1

        local_dir = self.params.base_checkpoint_dir
        if len(self.graph_manager.agents_params) != 1:
            local_dir = os.path.join(local_dir, agent_key)
        if not os.path.exists(local_dir):
            os.makedirs(local_dir)

        state_file = CheckpointStateFile(os.path.abspath(local_dir))
        client.download_file(Bucket=bucket,
                             Key=self._get_s3_key(state_file.filename, agent_key),
                             Filename=state_file.path)
        state = state_file.read()
        return -1 if state is None else state.num
    except botocore.exceptions.ClientError:
        log_and_exit("Unable to download checkpoint",
                     SIMAPP_S3_DATA_STORE_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_400)
    except Exception:
        log_and_exit("Unable to download checkpoint",
                     SIMAPP_S3_DATA_STORE_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
    def _save_to_store(self, checkpoint_dir):
        """
        save_to_store() uploads the policy checkpoint, gifs and videos to the S3 data store. It reads the checkpoint state files and
        uploads only the latest checkpoint files to S3. It is used by the trainer in Coach when used in the distributed mode.

        Sequence: (re)create the lock object, upload every file whose name starts
        with the current checkpoint name, upload the checkpoint state file last,
        mirror the FINISHED / TRAINER_READY marker files, release the lock, then
        upload .csv/.json experiment files plus the videos/ and gifs/ contents.
        Minio ``ResponseError`` is caught and printed.
        """
        try:
            # remove lock file if it exists
            self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)

            # Acquire lock
            self.mc.put_object(self.params.bucket_name, SyncFiles.LOCKFILE.value, io.BytesIO(b''), 0)

            state_file = CheckpointStateFile(os.path.abspath(checkpoint_dir))
            if state_file.exists():
                ckpt_state = state_file.read()
                checkpoint_file = None
                for root, _dirs, files in os.walk(checkpoint_dir):
                    for filename in files:
                        if filename == CheckpointStateFile.checkpoint_state_filename:
                            # Defer the state file so it is uploaded after the data files.
                            checkpoint_file = (root, filename)
                            continue
                        if filename.startswith(ckpt_state.name):
                            abs_name = os.path.abspath(os.path.join(root, filename))
                            rel_name = os.path.relpath(abs_name, checkpoint_dir)
                            self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)

                # Upload the state file last so readers never see it referencing
                # checkpoint files that are not in the store yet. Guard against the
                # walk not having found it (previously an unconditional index of None).
                if checkpoint_file is not None:
                    abs_name = os.path.abspath(os.path.join(checkpoint_file[0], checkpoint_file[1]))
                    rel_name = os.path.relpath(abs_name, checkpoint_dir)
                    self.mc.fput_object(self.params.bucket_name, rel_name, abs_name)

            # upload Finished if present
            if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value)):
                self.mc.put_object(self.params.bucket_name, SyncFiles.FINISHED.value, io.BytesIO(b''), 0)

            # upload Ready if present
            if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value)):
                self.mc.put_object(self.params.bucket_name, SyncFiles.TRAINER_READY.value, io.BytesIO(b''), 0)

            # release lock
            # NOTE(review): if an upload above raises, the lock object is left in
            # place and readers waiting on it will keep waiting — confirm intended.
            self.mc.remove_object(self.params.bucket_name, SyncFiles.LOCKFILE.value)

            if self.params.expt_dir and os.path.exists(self.params.expt_dir):
                for filename in os.listdir(self.params.expt_dir):
                    if filename.endswith((".csv", ".json")):
                        self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, filename))

            if self.params.expt_dir and os.path.exists(os.path.join(self.params.expt_dir, 'videos')):
                for filename in os.listdir(os.path.join(self.params.expt_dir, 'videos')):
                    self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'videos', filename))

            if self.params.expt_dir and os.path.exists(os.path.join(self.params.expt_dir, 'gifs')):
                for filename in os.listdir(os.path.join(self.params.expt_dir, 'gifs')):
                    self.mc.fput_object(self.params.bucket_name, filename, os.path.join(self.params.expt_dir, 'gifs', filename))

        except ResponseError as e:
            # print() does not apply %-formatting to extra args; format explicitly.
            print("Got exception: %s\n while saving to S3" % e)
Example #3
0
 def __init__(self,
              agent_name,
              s3_dict_metrics,
              deepracer_checkpoint_json,
              ckpnt_dir,
              run_phase_sink,
              use_model_picker=True):
     '''Initialise training-metrics state and register with the run phase sink.

     s3_dict_metrics - Dictionary containing the required s3 info for the metrics
                       bucket with keys specified by MetricsS3Keys
     deepracer_checkpoint_json - DeepracerCheckpointJson instance
     ckpnt_dir - Directory where the current checkpoint is to be stored
     run_phase_sink - Sink to receive notification of a change in run phase
     use_model_picker - Flag to whether to use model picker or not.

     Also creates the agent's local simtrace directory (when the simtrace path
     has a directory component) and registers the
     "/<agent_name>/mp4_video_metrics" ROS service.
     '''
     self._agent_name_ = agent_name
     self._deepracer_checkpoint_json = deepracer_checkpoint_json
     self._s3_metrics = Metrics(
         bucket=s3_dict_metrics[MetricsS3Keys.METRICS_BUCKET.value],
         s3_key=s3_dict_metrics[MetricsS3Keys.METRICS_KEY.value],
         region_name=s3_dict_metrics[MetricsS3Keys.REGION.value])
     self._start_time_ = time.time()
     self._episode_ = 0
     self._episode_reward_ = 0.0
     self._progress_ = 0.0
     self._episode_status = ''
     self._metrics_ = list()
     self._is_eval_ = True
     self._eval_trials_ = 0
     self._checkpoint_state_ = CheckpointStateFile(ckpnt_dir)
     self._use_model_picker = use_model_picker
     self._eval_stats_dict_ = {'chkpnt_name': None, 'avg_eval_metric': None}
     self._best_chkpnt_stats = {
         'name': None,
         'avg_eval_metric': None,
         'time_stamp': time.time()
     }
     self._current_eval_best_model_metric_list_ = list()
     self.is_save_simtrace_enabled = rospy.get_param(
         'SIMTRACE_S3_BUCKET', None)
     self._best_model_metric_type = BestModelMetricType(
         rospy.get_param('BEST_MODEL_METRIC',
                         BestModelMetricType.PROGRESS.value).lower())
     self.track_data = TrackData.get_instance()
     run_phase_sink.register(self)
     # Create the agent specific directories needed for storing the metric files
     self._simtrace_local_path = SIMTRACE_TRAINING_LOCAL_PATH_FORMAT.format(
         self._agent_name_)
     simtrace_dirname = os.path.dirname(self._simtrace_local_path)
     # Only create a directory when the path has one; makedirs('') raises, and
     # exist_ok keeps an already-present directory from being an error.
     if simtrace_dirname:
         os.makedirs(simtrace_dirname, exist_ok=True)
     self._current_sim_time = 0
     rospy.Service("/{}/{}".format(self._agent_name_, "mp4_video_metrics"),
                   VideoMetricsSrv, self._handle_get_video_metrics)
     self._video_metrics = Mp4VideoMetrics.get_empty_dict()
     AbstractTracker.__init__(self, TrackerPriority.HIGH)
    def load_from_store(self):
        """
        load_from_store() downloads a new checkpoint from the S3 data store when it is not available locally. It is used
        by the rollout workers when using Coach in distributed mode.

        Waits (polling every 10 s) until no lock object exists and the checkpoint
        state file can be fetched, mirrors the FINISHED / TRAINER_READY marker
        files when present (best effort), then downloads every object prefixed
        with the current checkpoint name that is not already present locally.
        Minio ``ResponseError`` is caught and printed.
        """
        try:
            state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))

            # wait until lock is removed
            while True:
                objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.LOCKFILE.value)

                if next(objects, None) is None:
                    try:
                        # fetch checkpoint state file from S3
                        self.mc.fget_object(self.params.bucket_name, state_file.filename, state_file.path)
                    except Exception:
                        # State file not available yet: fall through to the sleep
                        # instead of immediately re-polling S3 in a tight loop.
                        pass
                    else:
                        break
                time.sleep(10)

            # Check if there's a finished file
            objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.FINISHED.value)

            if next(objects, None) is not None:
                try:
                    self.mc.fget_object(
                        self.params.bucket_name, SyncFiles.FINISHED.value,
                        os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.FINISHED.value))
                    )
                except Exception:
                    # Best effort: a missing marker is retried on the next call.
                    pass

            # Check if there's a ready file
            objects = self.mc.list_objects_v2(self.params.bucket_name, SyncFiles.TRAINER_READY.value)

            if next(objects, None) is not None:
                try:
                    self.mc.fget_object(
                        self.params.bucket_name, SyncFiles.TRAINER_READY.value,
                        os.path.abspath(os.path.join(self.params.checkpoint_dir, SyncFiles.TRAINER_READY.value))
                    )
                except Exception:
                    # Best effort, as above.
                    pass

            checkpoint_state = state_file.read()
            if checkpoint_state is not None:
                objects = self.mc.list_objects_v2(self.params.bucket_name, prefix=checkpoint_state.name, recursive=True)
                for obj in objects:
                    filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, obj.object_name))
                    # Skip files already downloaded on a previous call.
                    if not os.path.exists(filename):
                        self.mc.fget_object(obj.bucket_name, obj.object_name, filename)

        except ResponseError as e:
            # print() does not apply %-formatting to extra args; format explicitly.
            print("Got exception: %s\n while loading from S3" % e)
Example #5
0
 def __init__(self,
              agent_name,
              s3_dict_metrics,
              s3_dict_model,
              ckpnt_dir,
              run_phase_sink,
              use_model_picker=True):
     '''Initialise metrics state, register with the run phase sink and make sure
     the agent's local simtrace directory exists.

     s3_dict_metrics - Dictionary containing the required s3 info for the metrics
                       bucket with keys specified by MetricsS3Keys
     s3_dict_model - Dictionary containing the required s3 info for the model
                     bucket, which is where the best model info will be saved with
                     keys specified by MetricsS3Keys
     ckpnt_dir - Directory where the current checkpoint is to be stored
     run_phase_sink - Sink to receive notification of a change in run phase
     use_model_picker - Flag to whether to use model picker or not.
     '''
     self._agent_name_ = agent_name
     self._s3_dict_metrics_ = s3_dict_metrics
     self._s3_dict_model_ = s3_dict_model
     self._start_time_ = time.time()

     # Running per-episode values.
     self._episode_ = 0
     self._episode_reward_ = 0.0
     self._progress_ = 0.0
     self._episode_status = ''
     self._metrics_ = []

     # Evaluation / best-model bookkeeping.
     self._is_eval_ = True
     self._eval_trials_ = 0
     self._checkpoint_state_ = CheckpointStateFile(ckpnt_dir)
     self._use_model_picker = use_model_picker
     self._eval_stats_dict_ = dict(chkpnt_name=None, avg_comp_pct=0.0)
     self._best_chkpnt_stats = dict(name=None,
                                    avg_comp_pct=0.0,
                                    time_stamp=time.time())
     self._current_eval_pct_list_ = []

     self.is_save_simtrace_enabled = rospy.get_param('SIMTRACE_S3_BUCKET', None)
     run_phase_sink.register(self)

     # Create the agent specific directories needed for storing the metric files
     agent_simtrace_dir = os.path.join(
         ITERATION_DATA_LOCAL_FILE_PATH,
         self._agent_name_,
         os.path.dirname(IterationDataLocalFileNames.SIM_TRACE_TRAINING_LOCAL_FILE.value))
     if not os.path.exists(agent_simtrace_dir):
         os.makedirs(agent_simtrace_dir)
    def get_latest_checkpoint(self):
        """Block until the trainer's checkpoint state is downloadable, then return its number.

        The state file is fetched into ``<checkpoint_dir>/latest_ckpt``. While a
        lock object is present, or while the download fails, the method sleeps
        for SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND and retries.
        Any exception escaping that loop is logged as a JSON system error and the
        method then returns None.
        """
        try:
            metadata_path = os.path.abspath(os.path.join(self.params.checkpoint_dir, "latest_ckpt"))
            if not os.path.exists(self.params.checkpoint_dir):
                os.makedirs(self.params.checkpoint_dir)

            while True:
                client = self._get_client()
                state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))

                # A lock object means the trainer is still uploading: wait.
                lock_listing = client.list_objects_v2(Bucket=self.params.bucket,
                                                      Prefix=self._get_s3_key(SyncFiles.LOCKFILE.value))
                if "Contents" in lock_listing:
                    time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                    continue

                try:
                    # fetch checkpoint state file from S3
                    client.download_file(Bucket=self.params.bucket,
                                         Key=self._get_s3_key(state_file.filename),
                                         Filename=metadata_path)
                except Exception:
                    time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                    continue

                return self._get_current_checkpoint_number(checkpoint_metadata_filepath=metadata_path)

        except Exception as e:
            utils.json_format_logger("Exception [{}] occured while getting latest checkpoint from S3.".format(e),
                                     **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION, utils.SIMAPP_EVENT_ERROR_CODE_503))
Example #7
0
def wait_for_checkpoints(checkpoint_dirs, data_store=None, timeout=10, poll_interval=10):
    """
    block until there is a checkpoint in all of the checkpoint_dirs

    Args:
        checkpoint_dirs: local checkpoint directories (one state file each).
        data_store: optional store; its ``load_from_store()`` is called before
            every check.
        timeout: number of polling attempts (not seconds).
        poll_interval: seconds to sleep between attempts. Defaults to 10.

    Returns as soon as every state file is readable. After ``timeout``
    attempts and one final check, reports the still-missing state files
    through ``log_and_exit``.
    """
    chkpt_state_files = [
        CheckpointStateFile(checkpoint_dir)
        for checkpoint_dir in checkpoint_dirs
    ]

    def _all_copied():
        return all(state_file.read() is not None for state_file in chkpt_state_files)

    for _ in range(timeout):
        if data_store:
            data_store.load_from_store()
        if _all_copied():
            return
        time.sleep(poll_interval)

    # one last time; collect what is still missing for the error report
    missing = [state_file.path for state_file in chkpt_state_files
               if state_file.read() is None]
    if not missing:
        return

    # Report the real wait (attempts * interval) and which files are missing,
    # rather than the attempt count labelled as seconds and an always-False flag.
    log_and_exit("Checkpoint never found in {} : missing {}, waited {} seconds."
                 .format(checkpoint_dirs, missing, timeout * poll_interval),
                 SIMAPP_SIMULATION_WORKER_EXCEPTION,
                 SIMAPP_EVENT_ERROR_CODE_500)
def rename_checkpoints(checkpoint_dir, agent_name):
    ''' Helper method that renames the variables of the checkpoint referenced by the
        CheckpointStateFile so that they are scoped with agent_name.

        checkpoint_dir - local checkpoint folder where the checkpoints and .checkpoint file is stored
        agent_name - name of the agent

        Variable-name segments "agent/" or "agent_<N>/" are replaced with
        "<agent_name>/". The renamed checkpoint is saved to TEMP_RENAME_FOLDER,
        the original checkpoint files are removed from checkpoint_dir, the
        renamed files are copied back, TEMP_RENAME_FOLDER is deleted and the
        TF default graph is reset.
    '''
    logger.info("Renaming checkpoint from checkpoint_dir: {} for agent: {}".format(checkpoint_dir, agent_name))
    state_file = CheckpointStateFile(os.path.abspath(checkpoint_dir))
    # NOTE(review): if read() returns None this becomes the string "None";
    # confirm callers only invoke this once a checkpoint exists.
    checkpoint_name = str(state_file.read())
    # Temporary TF "checkpoint" index so list_variables/load_variable resolve it.
    tf_checkpoint_file = os.path.join(checkpoint_dir, "checkpoint")
    with open(tf_checkpoint_file, "w") as outfile:
        outfile.write("model_checkpoint_path: \"{}\"".format(checkpoint_name))

    with tf.Session() as sess:
        for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
            # Load the variable
            value = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
            # Replace agent/ or agent_#/ with {agent_name}/ (raw string: '\d' is a regex class)
            new_name = re.sub(r'agent/|agent_\d+/', '{}/'.format(agent_name), var_name)
            # Creating the variable registers it in the default graph for the Saver below.
            tf.Variable(value, name=new_name)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        renamed_checkpoint_path = os.path.join(TEMP_RENAME_FOLDER, checkpoint_name)
        logger.info('Saving updated checkpoint to {}'.format(renamed_checkpoint_path))
        saver.save(sess, renamed_checkpoint_path)
    # Remove the tensorflow 'checkpoint' file
    os.remove(tf_checkpoint_file)
    # Remove the old checkpoint from the checkpoint dir. Match by prefix (as the
    # uploader does) so a name that merely contains this one, e.g. "11_Step-5.ckpt"
    # for "1_Step-5.ckpt", is not deleted.
    for file_name in os.listdir(checkpoint_dir):
        if file_name.startswith(checkpoint_name):
            os.remove(os.path.join(checkpoint_dir, file_name))

    # Copy the new checkpoint with renamed variable to the checkpoint dir
    for file_name in os.listdir(TEMP_RENAME_FOLDER):
        full_file_name = os.path.join(os.path.abspath(TEMP_RENAME_FOLDER), file_name)
        if os.path.isfile(full_file_name) and file_name != "checkpoint":
            shutil.copy(full_file_name, checkpoint_dir)

    # Remove files from temp_rename_folder
    shutil.rmtree(TEMP_RENAME_FOLDER)

    tf.reset_default_graph()
def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10):
    """
    Block until the checkpoint state file in ``checkpoint_dir`` is readable.

    The polling itself (including use of ``data_store`` and ``timeout``) is
    delegated to ``wait_for``; this function only supplies the readiness test.
    """
    state_file = CheckpointStateFile(checkpoint_dir)
    wait_for(lambda: state_file.read() is not None, data_store, timeout)
Example #10
0
    def __init__(
        self,
        bucket,
        s3_prefix,
        region_name="us-east-1",
        local_dir="./checkpoint/agent",
        max_retry_attempts=5,
        backoff_time_sec=1.0,
        log_and_cont: bool = False,
    ):
        """This class is for RL coach checkpoint file

        Args:
            bucket (str): S3 bucket string.
            s3_prefix (str): S3 prefix string.
            region_name (str): S3 region name.
                               Defaults to 'us-east-1'.
            local_dir (str, optional): Local file directory.
                                       Defaults to './checkpoint/agent'.
            max_retry_attempts (int, optional): Maximum number of retry attempts for S3 download/upload.
                                                Defaults to 5.
            backoff_time_sec (float, optional): Backoff second between each retry.
                                                Defaults to 1.0.
            log_and_cont (bool, optional): Log the error and continue with the flow.
                                           Defaults to False.
        """
        if not bucket or not s3_prefix:
            # Adjacent literals are joined before .format(); unlike a backslash
            # continuation inside the literal, this adds no run of spaces.
            log_and_exit(
                "checkpoint S3 prefix or bucket not available for S3. "
                "bucket: {}, prefix {}".format(bucket, s3_prefix),
                SIMAPP_SIMULATION_WORKER_EXCEPTION,
                SIMAPP_EVENT_ERROR_CODE_500,
            )
        self._bucket = bucket
        # coach checkpoint s3 key
        self._s3_key = os.path.normpath(
            os.path.join(s3_prefix, COACH_CHECKPOINT_POSTFIX))
        # coach checkpoint local path
        self._local_path = os.path.normpath(
            COACH_CHECKPOINT_LOCAL_PATH_FORMAT.format(local_dir))
        # coach checkpoint local temp path
        self._temp_local_path = os.path.normpath(
            TEMP_COACH_CHECKPOINT_LOCAL_PATH_FORMAT.format(local_dir))
        # old coach checkpoint s3 key to handle backward compatibility
        self._old_s3_key = os.path.normpath(
            os.path.join(s3_prefix, OLD_COACH_CHECKPOINT_POSTFIX))
        # old coach checkpoint local path to handle backward compatibility
        self._old_local_path = os.path.normpath(
            OLD_COACH_CHECKPOINT_LOCAL_PATH_FORMAT.format(local_dir))
        # coach checkpoint state file from rl coach
        self._coach_checkpoint_state_file = CheckpointStateFile(
            os.path.dirname(self._local_path))
        self._s3_client = S3Client(region_name, max_retry_attempts,
                                   backoff_time_sec, log_and_cont)
Example #11
0
def wait_for_checkpoint(checkpoint_dir, data_store=None, timeout=10, poll_interval=10):
    """
    block until there is a checkpoint in checkpoint_dir

    Args:
        checkpoint_dir: local directory holding the checkpoint state file.
        data_store: optional store; its ``load_from_store()`` is called before
            every check.
        timeout: number of polling attempts (not seconds).
        poll_interval: seconds to sleep between attempts. Defaults to 10.

    Raises:
        ValueError: if the state file is still unreadable after ``timeout``
            attempts and one final check.
    """
    chkpt_state_file = CheckpointStateFile(checkpoint_dir)
    for _ in range(timeout):
        if data_store:
            data_store.load_from_store()

        if chkpt_state_file.read() is not None:
            return
        time.sleep(poll_interval)

    # one last time
    if chkpt_state_file.read() is not None:
        return

    # Report the real wait time (attempts * interval), not the attempt count.
    raise ValueError(
        ('Waited {waited} seconds, but checkpoint never found in '
         '{checkpoint_dir}').format(
             waited=timeout * poll_interval,
             checkpoint_dir=checkpoint_dir,
         ))
    def __init__(self,
                 bucket,
                 s3_prefix,
                 region_name='us-east-1',
                 s3_endpoint_url=None,
                 local_dir='./checkpoint/agent',
                 max_retry_attempts=5,
                 backoff_time_sec=1.0):
        '''This class is for RL coach checkpoint file

        Args:
            bucket (str): S3 bucket string
            s3_prefix (str): S3 prefix string
            region_name (str): S3 region name
            s3_endpoint_url (str): S3 endpoint URL forwarded to S3Client
                (None presumably selects the default endpoint - confirm in S3Client)
            local_dir (str): local file directory
            max_retry_attempts (int): maximum number of retry attempts for S3 download/upload
            backoff_time_sec (float): backoff second between each retry
        '''
        if not bucket or not s3_prefix:
            # Adjacent literals are joined before .format(); unlike a backslash
            # continuation inside the literal, this adds no run of spaces.
            log_and_exit(
                "checkpoint S3 prefix or bucket not available for S3. "
                "bucket: {}, prefix {}".format(bucket, s3_prefix),
                SIMAPP_SIMULATION_WORKER_EXCEPTION,
                SIMAPP_EVENT_ERROR_CODE_500)
        self._bucket = bucket
        # coach checkpoint s3 key
        self._s3_key = os.path.normpath(
            os.path.join(s3_prefix, COACH_CHECKPOINT_POSTFIX))
        # coach checkpoint local path
        self._local_path = os.path.normpath(
            COACH_CHECKPOINT_LOCAL_PATH_FORMAT.format(local_dir))
        # coach checkpoint local temp path
        self._temp_local_path = os.path.normpath(
            TEMP_COACH_CHECKPOINT_LOCAL_PATH_FORMAT.format(local_dir))
        # old coach checkpoint s3 key to handle backward compatibility
        self._old_s3_key = os.path.normpath(
            os.path.join(s3_prefix, OLD_COACH_CHECKPOINT_POSTFIX))
        # old coach checkpoint local path to handle backward compatibility
        self._old_local_path = os.path.normpath(
            OLD_COACH_CHECKPOINT_LOCAL_PATH_FORMAT.format(local_dir))
        # coach checkpoint state file from rl coach
        self._coach_checkpoint_state_file = CheckpointStateFile(
            os.path.dirname(self._local_path))
        self._s3_client = S3Client(region_name, s3_endpoint_url,
                                   max_retry_attempts, backoff_time_sec)
Example #13
0
class TrainingMetrics(MetricsInterface, ObserverInterface, AbstractTracker):
    '''This class is responsible for uploading training metrics to s3'''
    def __init__(self,
                 agent_name,
                 s3_dict_metrics,
                 deepracer_checkpoint_json,
                 ckpnt_dir,
                 run_phase_sink,
                 use_model_picker=True):
        '''Initialise training-metrics state and register with the run phase sink.

           s3_dict_metrics - Dictionary containing the required s3 info for the metrics
                             bucket with keys specified by MetricsS3Keys
           deepracer_checkpoint_json - DeepracerCheckpointJson instance
           ckpnt_dir - Directory where the current checkpoint is to be stored
           run_phase_sink - Sink to receive notification of a change in run phase
           use_model_picker - Flag to whether to use model picker or not.

           Also creates the agent's local simtrace directory (when the simtrace
           path has a directory component) and registers the
           "/<agent_name>/mp4_video_metrics" ROS service.
        '''
        self._agent_name_ = agent_name
        self._deepracer_checkpoint_json = deepracer_checkpoint_json
        self._s3_metrics = Metrics(
            bucket=s3_dict_metrics[MetricsS3Keys.METRICS_BUCKET.value],
            s3_key=s3_dict_metrics[MetricsS3Keys.METRICS_KEY.value],
            region_name=s3_dict_metrics[MetricsS3Keys.REGION.value])
        self._start_time_ = time.time()
        self._episode_ = 0
        self._episode_reward_ = 0.0
        self._progress_ = 0.0
        self._episode_status = ''
        self._metrics_ = list()
        self._is_eval_ = True
        self._eval_trials_ = 0
        self._checkpoint_state_ = CheckpointStateFile(ckpnt_dir)
        self._use_model_picker = use_model_picker
        self._eval_stats_dict_ = {'chkpnt_name': None, 'avg_eval_metric': None}
        self._best_chkpnt_stats = {
            'name': None,
            'avg_eval_metric': None,
            'time_stamp': time.time()
        }
        self._current_eval_best_model_metric_list_ = list()
        self.is_save_simtrace_enabled = rospy.get_param(
            'SIMTRACE_S3_BUCKET', None)
        self._best_model_metric_type = BestModelMetricType(
            rospy.get_param('BEST_MODEL_METRIC',
                            BestModelMetricType.PROGRESS.value).lower())
        self.track_data = TrackData.get_instance()
        run_phase_sink.register(self)
        # Create the agent specific directories needed for storing the metric files
        self._simtrace_local_path = SIMTRACE_TRAINING_LOCAL_PATH_FORMAT.format(
            self._agent_name_)
        simtrace_dirname = os.path.dirname(self._simtrace_local_path)
        # Only create a directory when the path has one; makedirs('') raises, and
        # exist_ok keeps an already-present directory from being an error.
        if simtrace_dirname:
            os.makedirs(simtrace_dirname, exist_ok=True)
        self._current_sim_time = 0
        rospy.Service("/{}/{}".format(self._agent_name_, "mp4_video_metrics"),
                      VideoMetricsSrv, self._handle_get_video_metrics)
        self._video_metrics = Mp4VideoMetrics.get_empty_dict()
        AbstractTracker.__init__(self, TrackerPriority.HIGH)

    def update_tracker(self, delta_time, sim_time):
        """
        Callback when sim time is updated

        Stores the simulation clock as float seconds in ``_current_sim_time``;
        ``delta_time`` is not used.

        Args:
            delta_time (float): time diff from last call
            sim_time (Clock): simulation time
        """
        clock = sim_time.clock
        self._current_sim_time = clock.secs + clock.nsecs * 1.e-9

    def reset(self):
        """Begin a new episode: stamp its start with the current sim time and zero reward/progress."""
        self._start_time_ = self._current_sim_time
        self._episode_reward_, self._progress_ = 0.0, 0.0

    def append_episode_metrics(self):
        """Append a metrics record for the episode that just ended.

        The episode and trial counters advance only outside evaluation.
        Times are reported as integer milliseconds of simulation time.
        """
        if not self._is_eval_:
            self._episode_ += 1
            self._eval_trials_ += 1
        self._metrics_.append({
            'reward_score': int(round(self._episode_reward_)),
            'metric_time': int(round(self._current_sim_time * 1000)),
            'start_time': int(round(self._start_time_ * 1000)),
            'elapsed_time_in_milliseconds':
                int(round((self._current_sim_time - self._start_time_) * 1000)),
            'episode': int(self._episode_),
            'trial': int(self._eval_trials_),
            'phase': 'evaluation' if self._is_eval_ else 'training',
            'completion_percentage': int(self._progress_),
            'episode_status': EpisodeStatus.get_episode_status_label(self._episode_status),
        })

    def upload_episode_metrics(self):
        """Persist the accumulated metrics to S3.

        During evaluation, also record this episode's value of the configured
        best-model metric (reward for REWARD, otherwise progress) for use when
        selecting the best checkpoint.
        """
        payload = {
            'metrics': self._metrics_,
            'version': METRICS_VERSION,
            'best_model_metric': self._best_model_metric_type.value,
        }
        self._s3_metrics.persist(body=json.dumps(payload),
                                 s3_kms_extra_args=get_s3_kms_extra_args())
        if not self._is_eval_:
            return
        by_reward = self._best_model_metric_type == BestModelMetricType.REWARD
        self._current_eval_best_model_metric_list_.append(
            self._episode_reward_ if by_reward else self._progress_)

    def upload_step_metrics(self, metrics):
        """Fold one step's metrics into the running episode state.

        Outside evaluation, the episode number is written into ``metrics``, the
        dict is validated and sim-trace logged, and it is also appended to the
        local simtrace file when simtrace saving is enabled.
        """
        self._progress_ = metrics[StepMetrics.PROG.value]
        self._episode_status = metrics[StepMetrics.EPISODE_STATUS.value]
        self._episode_reward_ += metrics[StepMetrics.REWARD.value]
        if self._is_eval_:
            return
        metrics[StepMetrics.EPISODE.value] = self._episode_
        StepMetrics.validate_dict(metrics)
        sim_trace_log(metrics)
        if self.is_save_simtrace_enabled:
            write_simtrace_to_local_file(self._simtrace_local_path, metrics)

    def update(self, data):
        """Run-phase observer callback.

        Records whether the new phase is an evaluation phase. When switching
        to training with the model picker enabled, it averages the best-model
        metric gathered during the preceding evaluation and updates the best
        checkpoint if that mean is >= the best so far. It then persists best/last
        checkpoint stats and moves on to the checkpoint now being trained.
        """
        self._is_eval_ = data != RunPhase.TRAIN

        if not self._is_eval_ and self._use_model_picker:
            if self._eval_stats_dict_['chkpnt_name'] is None:
                self._eval_stats_dict_[
                    'chkpnt_name'] = self._checkpoint_state_.read().name

            self._eval_trials_ = 0
            mean_metric = statistics.mean(
                self._current_eval_best_model_metric_list_
            ) if self._current_eval_best_model_metric_list_ else None
            msg_format = '[BestModelSelection] Number of evaluations: {} Evaluation episode {}: {}'
            LOGGER.info(
                msg_format.format(
                    len(self._current_eval_best_model_metric_list_),
                    self._best_model_metric_type.value,
                    self._current_eval_best_model_metric_list_))
            LOGGER.info(
                '[BestModelSelection] Evaluation episode {} mean: {}'.format(
                    self._best_model_metric_type.value, mean_metric))
            self._current_eval_best_model_metric_list_.clear()

            time_stamp = self._current_sim_time
            best_metric = self._eval_stats_dict_['avg_eval_metric']
            # mean_metric is None when no evaluation episodes were recorded;
            # comparing None with a number raises TypeError, and an empty round
            # must not replace an already-established best checkpoint.
            if best_metric is None or \
                    (mean_metric is not None and mean_metric >= best_metric):
                msg_format = '[BestModelSelection] current {0} mean: {1} >= best {0} mean: {2}'
                LOGGER.info(
                    msg_format.format(
                        self._best_model_metric_type.value, mean_metric,
                        best_metric))
                msg_format = '[BestModelSelection] Updating the best checkpoint to "{}" from "{}".'
                LOGGER.info(
                    msg_format.format(self._eval_stats_dict_['chkpnt_name'],
                                      self._best_chkpnt_stats['name']))
                self._eval_stats_dict_['avg_eval_metric'] = mean_metric
                self._best_chkpnt_stats = {
                    'name': self._eval_stats_dict_['chkpnt_name'],
                    'avg_eval_metric': mean_metric,
                    'time_stamp': time_stamp
                }
            last_chkpnt_stats = {
                'name': self._eval_stats_dict_['chkpnt_name'],
                'avg_eval_metric': mean_metric,
                'time_stamp': time_stamp
            }
            self._deepracer_checkpoint_json.persist(
                body=json.dumps({
                    BEST_CHECKPOINT: self._best_chkpnt_stats,
                    LAST_CHECKPOINT: last_chkpnt_stats
                }),
                s3_kms_extra_args=get_s3_kms_extra_args())
            # Update the checkpoint name to the new checkpoint being used for training that will
            # then be evaluated, note this class gets notified when the system is put into a
            # training phase and assumes that a training phase only starts when a new checkpoint
            # is available
            self._eval_stats_dict_[
                'chkpnt_name'] = self._checkpoint_state_.read().name

    def update_mp4_video_metrics(self, metrics):
        """Refresh ``self._video_metrics`` with the values shown in the MP4 overlay.

        Lap counter, reset counter, throttle, steering, best lap time and total
        evaluation time are always written as 0 here. Completion percentage is
        taken from ``self._progress_``. DONE and the X/Y position are copied from
        the step metrics. Object locations are the (x, y) positions, with z
        forced to 0, of every tracked object whose name does not start with
        'racecar'.

        Args:
            metrics: mapping keyed by ``StepMetrics`` values; must contain the
                X, Y and DONE entries.
        """
        pos_x = metrics[StepMetrics.X.value]
        pos_y = metrics[StepMetrics.Y.value]
        video = self._video_metrics
        video[Mp4VideoMetrics.LAP_COUNTER.value] = 0
        video[Mp4VideoMetrics.COMPLETION_PERCENTAGE.value] = self._progress_
        # For continuous race, MP4 video will display the total reset counter for the entire race
        # For non-continuous race, MP4 video will display reset counter per lap
        # (this code path always reports 0 for the reset counter and the fields below).
        zeroed_fields = (Mp4VideoMetrics.RESET_COUNTER,
                         Mp4VideoMetrics.THROTTLE,
                         Mp4VideoMetrics.STEERING,
                         Mp4VideoMetrics.BEST_LAP_TIME,
                         Mp4VideoMetrics.TOTAL_EVALUATION_TIME)
        for field in zeroed_fields:
            video[field.value] = 0
        video[Mp4VideoMetrics.DONE.value] = metrics[StepMetrics.DONE.value]
        video[Mp4VideoMetrics.X.value] = pos_x
        video[Mp4VideoMetrics.Y.value] = pos_y

        locations = []
        for name, pose in self.track_data.object_poses.items():
            if name.startswith('racecar'):
                # Skip the cars themselves; only other track objects are reported.
                continue
            location = Point32()
            location.x, location.y, location.z = pose.position.x, pose.position.y, 0
            locations.append(location)
        video[Mp4VideoMetrics.OBJECT_LOCATIONS.value] = locations

    def _handle_get_video_metrics(self, req):
        """ROS service handler returning the cached MP4 video metrics.

        The request argument is not inspected. The response is built from
        ``self._video_metrics``, with values passed positionally in the order
        listed in ``field_order``.
        """
        field_order = (
            Mp4VideoMetrics.LAP_COUNTER,
            Mp4VideoMetrics.COMPLETION_PERCENTAGE,
            Mp4VideoMetrics.RESET_COUNTER,
            Mp4VideoMetrics.THROTTLE,
            Mp4VideoMetrics.STEERING,
            Mp4VideoMetrics.BEST_LAP_TIME,
            Mp4VideoMetrics.TOTAL_EVALUATION_TIME,
            Mp4VideoMetrics.DONE,
            Mp4VideoMetrics.X,
            Mp4VideoMetrics.Y,
            Mp4VideoMetrics.OBJECT_LOCATIONS,
            Mp4VideoMetrics.EPISODE_STATUS,
            Mp4VideoMetrics.PAUSE_DURATION,
        )
        cached = self._video_metrics
        return VideoMetricsSrvResponse(*[cached[field.value] for field in field_order])
    def save_to_store(self):
        """Push the newest local checkpoint and the sync marker files to S3.

        Order of operations:
          1. Delete any leftover lock object, then write an empty one (lock held).
          2. When the checkpoint state file exists, upload every file whose name
             starts with the current checkpoint name, then the state file itself.
          3. Upload empty FINISHED / TRAINER_READY objects for the markers that
             exist in the checkpoint directory.
          4. Delete the lock object.
          5. When a graph manager is attached, write the frozen graph and upload
             ``model_<training_iteration>.pb`` for every agent.
          6. Delete the S3 objects whose file name starts with
             ``"<checkpoint_num - NUM_MODELS_TO_KEEP>_"``.

        Failures are logged through ``utils.json_format_logger`` (a 400 user
        error for ``ClientError``, a 500 system error otherwise) and followed by
        ``utils.simapp_exit_gracefully()``.
        """
        try:
            client = self._get_client()
            local_dir = self.params.checkpoint_dir

            def put_marker(sync_file):
                # Marker objects are zero bytes; only their presence is meaningful.
                client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                      Bucket=self.params.bucket,
                                      Key=self._get_s3_key(sync_file.value))

            def put_local_file(location):
                # ``location`` is a (directory, file name) pair; the S3 key mirrors
                # the file's path relative to the checkpoint directory.
                absolute = os.path.abspath(os.path.join(location[0], location[1]))
                client.upload_file(Filename=absolute,
                                   Bucket=self.params.bucket,
                                   Key=self._get_s3_key(os.path.relpath(absolute, local_dir)))

            def drop_lock():
                client.delete_object(Bucket=self.params.bucket,
                                     Key=self._get_s3_key(SyncFiles.LOCKFILE.value))

            drop_lock()                     # clear a stale lock, if any
            put_marker(SyncFiles.LOCKFILE)  # acquire the lock

            state_file = CheckpointStateFile(os.path.abspath(local_dir))
            ckpt_state = None
            if state_file.exists():
                ckpt_state = state_file.read()
                state_file_location = None
                upload_count = 0
                for root, _, names in os.walk(local_dir):
                    for name in names:
                        if name == CheckpointStateFile.checkpoint_state_filename:
                            state_file_location = (root, name)
                        elif name.startswith(ckpt_state.name):
                            put_local_file((root, name))
                            upload_count += 1
                logger.info("Uploaded {} files for checkpoint {}".format(upload_count, ckpt_state.num))
                # The state file is uploaded only after the files it names
                # (presumably so readers never see a state pointing at missing files).
                put_local_file(state_file_location)

            for marker in (SyncFiles.FINISHED, SyncFiles.TRAINER_READY):
                if os.path.exists(os.path.join(local_dir, marker.value)):
                    put_marker(marker)

            drop_lock()  # release the lock

            # Upload the frozen graph which is used for deployment.
            # NOTE: there's no cleanup as we don't know the best checkpoint.
            if self.graph_manager:
                self.write_frozen_graph(self.graph_manager)
                agents_params = self.graph_manager.agents_params
                for agent_params in agents_params:
                    agent_name = agent_params.name
                    iteration_id = self.graph_manager.level_managers[0].agents[agent_name].training_iteration
                    frozen_name = "model_{}.pb".format(iteration_id)
                    if len(agents_params) == 1:
                        s3_name = frozen_name
                    else:
                        s3_name = os.path.join(agent_name, frozen_name)
                    client.upload_file(Filename=os.path.join(SM_MODEL_OUTPUT_DIR, agent_name, "model.pb"),
                                       Bucket=self.params.bucket,
                                       Key=self._get_s3_key(s3_name))
                    logger.info("saved intermediate frozen graph: {}".format(self._get_s3_key(s3_name)))

            # Clean up the checkpoint that just fell out of the retention window.
            if ckpt_state:
                stale_prefix = "{}_".format(ckpt_state.num - NUM_MODELS_TO_KEEP)
                listing = client.list_objects_v2(Bucket=self.params.bucket,
                                                 Prefix=self._get_s3_key(""))
                for obj in listing.get("Contents", []):
                    if os.path.basename(obj["Key"]).startswith(stale_prefix):
                        client.delete_object(Bucket=self.params.bucket, Key=obj["Key"])

        except botocore.exceptions.ClientError as e:
            utils.json_format_logger("Unable to upload checkpoint to {}, {}"
                                     .format(self.params.bucket, e.response['Error']['Code']),
                                     **utils.build_user_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                   utils.SIMAPP_EVENT_ERROR_CODE_400))
            utils.simapp_exit_gracefully()
        except Exception as e:
            utils.json_format_logger("Unable to upload checkpoint to {}, {}"
                                     .format(self.params.bucket, e),
                                     **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                     utils.SIMAPP_EVENT_ERROR_CODE_500))
            utils.simapp_exit_gracefully()
    def load_from_store(self, expected_checkpoint_number=-1):
        """Download the current checkpoint of every agent from its S3 bucket.

        For each ``(agent_key, bucket)`` in ``self.params.buckets`` this:
          1. Creates the local checkpoint directory (the base directory for a
             single agent, ``<base>/<agent_key>`` otherwise).
          2. Polls until the trainer's lock object is gone (or
             ``self.ignore_lock`` is set) and the checkpoint state file can be
             downloaded. With ``ignore_lock`` a failed download exits instead
             of retrying.
          3. Mirrors the FINISHED / TRAINER_READY marker objects locally when
             they exist (download failures there are deliberately ignored).
          4. Keeps polling while the state file's checkpoint number is below
             ``expected_checkpoint_number``.
          5. Downloads every non-``.pb`` object whose file name starts with the
             checkpoint name. If no object matches, every non-``.pb`` object is
             downloaded instead and the state file is rewritten to the last
             entry returned by ``_filter_checkpoint_files``.

        Args:
            expected_checkpoint_number: minimum checkpoint number to accept;
                -1 accepts whatever the state file names.

        Returns:
            True once all agents were processed. An S3 ``ClientError`` exits via
            ``log_and_exit`` with a 400 code; any other exception exits with 500.
        """
        try:
            base_checkpoint_dir = self.params.base_checkpoint_dir
            for agent_key, bucket in self.params.buckets.items():
                checkpoint_dir = base_checkpoint_dir if len(self.graph_manager.agents_params) == 1 \
                    else os.path.join(base_checkpoint_dir, agent_key)
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                while True:
                    s3_client = self._get_client()
                    state_file = CheckpointStateFile(os.path.abspath(checkpoint_dir))

                    # Wait until the lock is removed, unless told to ignore it.
                    response = s3_client.list_objects_v2(Bucket=bucket,
                                                         Prefix=self._get_s3_key(SyncFiles.LOCKFILE.value, agent_key))
                    if "Contents" in response and not self.ignore_lock:
                        time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                        continue
                    try:
                        checkpoint_file_path = os.path.abspath(os.path.join(checkpoint_dir,
                                                                            state_file.path))
                        # fetch checkpoint state file from S3
                        s3_client.download_file(Bucket=bucket,
                                                Key=self._get_s3_key(state_file.filename, agent_key),
                                                Filename=checkpoint_file_path)
                    except botocore.exceptions.ClientError:
                        if self.ignore_lock:
                            log_and_exit("Checkpoint not found",
                                         SIMAPP_S3_DATA_STORE_EXCEPTION,
                                         SIMAPP_EVENT_ERROR_CODE_400)
                        time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                        continue
                    except Exception:
                        if self.ignore_lock:
                            log_and_exit("Checkpoint not found",
                                         SIMAPP_S3_DATA_STORE_EXCEPTION,
                                         SIMAPP_EVENT_ERROR_CODE_500)
                        time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                        continue

                    # Mirror the Finished / Ready marker files locally when present.
                    for marker in (SyncFiles.FINISHED.value, SyncFiles.TRAINER_READY.value):
                        marker_key = self._get_s3_key(marker, agent_key)
                        response = s3_client.list_objects_v2(Bucket=bucket, Prefix=marker_key)
                        if "Contents" in response:
                            try:
                                marker_path = os.path.abspath(os.path.join(checkpoint_dir, marker))
                                s3_client.download_file(Bucket=bucket,
                                                        Key=marker_key,
                                                        Filename=marker_path)
                            except Exception:
                                # Best effort by design: a failed marker download is ignored.
                                pass

                    checkpoint_state = state_file.read()
                    if checkpoint_state is not None:
                        # If we get a checkpoint that is older than the expected checkpoint,
                        # wait for the new checkpoint to arrive.
                        if checkpoint_state.num < expected_checkpoint_number:
                            time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                            continue

                        response = s3_client.list_objects_v2(Bucket=bucket,
                                                             Prefix=self._get_s3_key("", agent_key))
                        if "Contents" in response:
                            contents = response["Contents"]
                            # Check to see if the desired checkpoint is in the bucket
                            has_chkpnt = any(os.path.split(obj['Key'])[1].startswith(checkpoint_state.name)
                                             for obj in contents)
                            full_key_prefix = os.path.normpath(self.key_prefixes[agent_key]) + "/"
                            for obj in contents:
                                key = obj["Key"]
                                # Strip only the leading S3 prefix to get the path relative
                                # to the checkpoint directory.
                                rel_key = key[len(full_key_prefix):] if key.startswith(full_key_prefix) else key
                                filename = os.path.abspath(os.path.join(checkpoint_dir, rel_key))
                                dirname, basename = os.path.split(filename)
                                # Download the checkpoint files but not the frozen models,
                                # which are not needed here.
                                _, file_extension = os.path.splitext(key)
                                if file_extension != '.pb' \
                                        and (basename.startswith(checkpoint_state.name) or not has_chkpnt):
                                    if not os.path.exists(dirname):
                                        os.makedirs(dirname)
                                    s3_client.download_file(Bucket=bucket,
                                                            Key=key,
                                                            Filename=filename)
                            # Point the coach checkpoint file at the latest available checkpoint
                            # and log the substitution.
                            if not has_chkpnt:
                                all_ckpnts = _filter_checkpoint_files(os.listdir(checkpoint_dir))
                                if all_ckpnts:
                                    LOG.info("%s not in s3 bucket, downloading all checkpoints "
                                             "and using %s", checkpoint_state.name, all_ckpnts[-1])
                                    state_file.write(all_ckpnts[-1])
                                else:
                                    log_and_exit("No checkpoint files",
                                                 SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                 SIMAPP_EVENT_ERROR_CODE_400)
                    break
            return True

        except botocore.exceptions.ClientError:
            log_and_exit("Unable to download checkpoint",
                         SIMAPP_S3_DATA_STORE_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_400)
        except Exception:
            log_and_exit("Unable to download checkpoint",
                         SIMAPP_S3_DATA_STORE_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
 def __init__(self,
              agent_name,
              s3_dict_metrics,
              deepracer_checkpoint_json,
              ckpnt_dir,
              run_phase_sink,
              use_model_picker=True):
     '''Initialize per-agent metric state, the simtrace directory and the MP4 video-metrics service.

        agent_name - Name of the agent; used in the simtrace local path and in the
                     ROS service name "/<agent_name>/mp4_video_metrics"
        s3_dict_metrics - Dictionary containing the required s3 info for the metrics
                          bucket with keys specified by MetricsS3Keys
        deepracer_checkpoint_json - DeepracerCheckpointJson instance
        ckpnt_dir - Directory where the current checkpoint is to be stored
        run_phase_sink - Sink to receive notification of a change in run phase
        use_model_picker - Flag for whether to use the model picker or not.
     '''
     self._agent_name_ = agent_name
     self._deepracer_checkpoint_json = deepracer_checkpoint_json
     self._s3_metrics = Metrics(
         bucket=s3_dict_metrics[MetricsS3Keys.METRICS_BUCKET.value],
         s3_key=s3_dict_metrics[MetricsS3Keys.METRICS_KEY.value],
         region_name=s3_dict_metrics[MetricsS3Keys.REGION.value],
         s3_endpoint_url=s3_dict_metrics[MetricsS3Keys.ENDPOINT_URL.value])
     self._start_time_ = time.time()
     self._episode_ = 0
     self._episode_reward_ = 0.0
     self._progress_ = 0.0
     self._episode_status = ''
     self._metrics_ = list()
     self._is_eval_ = True
     self._eval_trials_ = 0
     self._checkpoint_state_ = CheckpointStateFile(ckpnt_dir)
     self._use_model_picker = use_model_picker
     self._eval_stats_dict_ = {'chkpnt_name': None, 'avg_eval_metric': None}
     self._best_chkpnt_stats = {
         'name': None,
         'avg_eval_metric': None,
         'time_stamp': time.time()
     }
     self._current_eval_best_model_metric_list_ = list()
     self.is_save_simtrace_enabled = rospy.get_param(
         'SIMTRACE_S3_BUCKET', None)
     self._best_model_metric_type = BestModelMetricType(
         rospy.get_param('BEST_MODEL_METRIC',
                         BestModelMetricType.PROGRESS.value).lower())
     self.track_data = TrackData.get_instance()
     run_phase_sink.register(self)
     # Create the agent specific directories needed for storing the metric files
     self._simtrace_local_path = SIMTRACE_TRAINING_LOCAL_PATH_FORMAT.format(
         self._agent_name_)
     simtrace_dirname = os.path.dirname(self._simtrace_local_path)
     # addressing mkdir and check directory race condition:
     # https://stackoverflow.com/questions/12468022/python-fileexists-error-when-making-directory/30174982#30174982
     # TODO: change this to os.makedirs(simtrace_dirname, exist_ok=True) when we migrate off python 2.7
     try:
         os.makedirs(simtrace_dirname)
     except OSError as e:
         if e.errno != errno.EEXIST:
             raise
         # An existing directory is the expected outcome of the race above, not an error.
         LOGGER.info("Directory already exists %s", simtrace_dirname)
     self._current_sim_time = 0
     # The service handler reads self._video_metrics, and rospy may dispatch a
     # request on another thread as soon as the service is advertised, so the
     # state must exist before the service is created.
     self._video_metrics = Mp4VideoMetrics.get_empty_dict()
     rospy.Service("/{}/{}".format(self._agent_name_, "mp4_video_metrics"),
                   VideoMetricsSrv, self._handle_get_video_metrics)
     AbstractTracker.__init__(self, TrackerPriority.HIGH)
    def save_to_store(self):
        """Upload each agent's latest checkpoint, markers and frozen graph to S3.

        For every ``(agent_key, bucket)`` in ``self.params.buckets``:
          1. Replace any existing S3 lock object with a fresh empty one.
          2. If the local checkpoint state file exists, upload every file whose
             name starts with the current checkpoint name (multipart threshold
             of 1 byte), queue their keys for later cleanup, and upload the
             state file last.
          3. Upload empty FINISHED / TRAINER_READY objects when present locally.
          4. Delete the lock object.
          5. With a graph manager: write and upload ``model_<checkpoint_num>.pb``
             and copy the best frozen model into SM_MODEL_OUTPUT_DIR.
          6. While more than NUM_MODELS_TO_KEEP uploads are queued, delete queued
             checkpoint files (and their ``model_<id>.pb``), re-queueing the
             entry whose files all match the best checkpoint.

        An S3 ``ClientError`` exits via ``log_and_exit`` with a 400 code; any
        other exception exits with 500.
        """
        try:
            s3_client = self._get_client()
            base_checkpoint_dir = self.params.base_checkpoint_dir
            # Same transfer settings for every checkpoint file, so build them once.
            checkpoint_transfer_config = boto3.s3.transfer.TransferConfig(multipart_threshold=1)
            for agent_key, bucket in self.params.buckets.items():
                lock_key = self._get_s3_key(SyncFiles.LOCKFILE.value, agent_key)
                # remove lock file if it exists
                s3_client.delete_object(Bucket=bucket, Key=lock_key)
                # acquire lock
                s3_client.upload_fileobj(Fileobj=io.BytesIO(b''), Bucket=bucket, Key=lock_key)

                checkpoint_dir = base_checkpoint_dir if len(self.graph_manager.agents_params) == 1 else \
                    os.path.join(base_checkpoint_dir, agent_key)

                state_file = CheckpointStateFile(os.path.abspath(checkpoint_dir))
                ckpt_state = None
                check_point_key_list = []
                if state_file.exists():
                    ckpt_state = state_file.read()
                    checkpoint_file = None
                    num_files_uploaded = 0
                    start_time = time.time()
                    for root, _, files in os.walk(checkpoint_dir):
                        for filename in files:
                            if filename == CheckpointStateFile.checkpoint_state_filename:
                                # The state file is uploaded after the checkpoint files, below.
                                checkpoint_file = (root, filename)
                                continue
                            if filename.startswith(ckpt_state.name):
                                abs_name = os.path.abspath(os.path.join(root, filename))
                                s3_key = self._get_s3_key(os.path.relpath(abs_name, checkpoint_dir), agent_key)
                                s3_client.upload_file(Filename=abs_name,
                                                      Bucket=bucket,
                                                      Key=s3_key,
                                                      Config=checkpoint_transfer_config)
                                check_point_key_list.append(s3_key)
                                num_files_uploaded += 1
                    time_taken = time.time() - start_time
                    LOG.info("Uploaded %s files for checkpoint %s in %.2f seconds",
                             num_files_uploaded, ckpt_state.num, time_taken)
                    if check_point_key_list:
                        self.delete_queues[agent_key].put(check_point_key_list)

                    abs_name = os.path.abspath(os.path.join(checkpoint_file[0], checkpoint_file[1]))
                    rel_name = os.path.relpath(abs_name, checkpoint_dir)
                    s3_client.upload_file(Filename=abs_name,
                                          Bucket=bucket,
                                          Key=self._get_s3_key(rel_name, agent_key))

                # upload Finished / Ready markers if present
                for marker in (SyncFiles.FINISHED.value, SyncFiles.TRAINER_READY.value):
                    if os.path.exists(os.path.join(checkpoint_dir, marker)):
                        s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                                 Bucket=bucket,
                                                 Key=self._get_s3_key(marker, agent_key))

                # release lock
                s3_client.delete_object(Bucket=bucket, Key=lock_key)

                # Upload the frozen graph which is used for deployment
                if self.graph_manager:
                    # checkpoint state is always present for the checkpoint dir passed.
                    # We make same assumption while we get the best checkpoint in s3_metrics
                    checkpoint_num = ckpt_state.num
                    self.write_frozen_graph(self.graph_manager, agent_key, checkpoint_num)
                    frozen_name = "model_{}.pb".format(checkpoint_num)
                    frozen_graph_fpath = os.path.join(SM_MODEL_PB_TEMP_FOLDER, agent_key,
                                                      frozen_name)
                    frozen_graph_s3_name = frozen_name if len(self.graph_manager.agents_params) == 1 \
                        else os.path.join(agent_key, frozen_name)
                    # upload the model_<ID>.pb to S3.
                    s3_client.upload_file(Filename=frozen_graph_fpath,
                                          Bucket=bucket,
                                          Key=self._get_s3_key(frozen_graph_s3_name, agent_key))
                    LOG.info("saved intermediate frozen graph: %s", self._get_s3_key(frozen_graph_s3_name, agent_key))

                    # Copy the best checkpoint to the SM_MODEL_OUTPUT_DIR
                    copy_best_frozen_model_to_sm_output_dir(bucket,
                                                            self.params.s3_folders[agent_key],
                                                            self.params.aws_region,
                                                            os.path.join(SM_MODEL_PB_TEMP_FOLDER, agent_key),
                                                            os.path.join(SM_MODEL_OUTPUT_DIR, agent_key),
                                                            self.params.s3_endpoint_url)

                # Clean up old checkpoints
                if ckpt_state and self.delete_queues[agent_key].qsize() > NUM_MODELS_TO_KEEP:
                    delete_queue = self.delete_queues[agent_key]
                    best_checkpoint = get_best_checkpoint(bucket,
                                                          self.params.s3_folders[agent_key],
                                                          self.params.aws_region,
                                                          self.params.s3_endpoint_url)
                    # Bound the number of entries examined by the queue size at entry.
                    # Assuming a FIFO queue, the normal case finishes well before this
                    # bound; it only stops an endless get/put cycle when nothing but
                    # re-queued best-checkpoint entries remain (e.g. NUM_MODELS_TO_KEEP == 0).
                    remaining_checks = delete_queue.qsize()
                    while delete_queue.qsize() > NUM_MODELS_TO_KEEP and remaining_checks > 0:
                        remaining_checks -= 1
                        key_list = delete_queue.get()
                        file_names = [os.path.split(key)[-1] for key in key_list]
                        if best_checkpoint and all(best_checkpoint in name for name in file_names):
                            # Never delete the best checkpoint; put it back on the queue.
                            delete_queue.put(key_list)
                            continue
                        delete_iteration_ids = set()
                        for key, file_in_checkpoint_dir in zip(key_list, file_names):
                            s3_client.delete_object(Bucket=bucket, Key=key)
                            # The iteration id is the part of the file name before "_Step".
                            # (str.split always returns at least one element, so test for
                            # the separator explicitly.)
                            if "_Step" in file_in_checkpoint_dir:
                                delete_iteration_ids.add(file_in_checkpoint_dir.split("_Step")[0])
                        LOG.info("Deleting the frozen models in s3 for the iterations: %s",
                                 delete_iteration_ids)
                        # Delete the model_{}.pb files from the s3 bucket for the previous iterations
                        for iteration_id in delete_iteration_ids:
                            frozen_name = "model_{}.pb".format(iteration_id)
                            frozen_graph_s3_name = frozen_name if len(self.graph_manager.agents_params) == 1 \
                                else os.path.join(agent_key, frozen_name)
                            s3_client.delete_object(Bucket=bucket,
                                                    Key=self._get_s3_key(frozen_graph_s3_name, agent_key))
        except botocore.exceptions.ClientError:
            log_and_exit("Unable to upload checkpoint",
                         SIMAPP_S3_DATA_STORE_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_400)
        except Exception:
            log_and_exit("Unable to upload checkpoint",
                         SIMAPP_S3_DATA_STORE_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
    def load_from_store(self, expected_checkpoint_number=-1):
        """Download the current checkpoint from ``self.params.bucket`` into the checkpoint dir.

        Polls until the trainer's lock object is gone and the checkpoint state
        file can be downloaded. It then mirrors the FINISHED / TRAINER_READY
        markers locally when present (download failures there are deliberately
        ignored) and keeps polling while the state file's checkpoint number is
        below ``expected_checkpoint_number``.

        Finally it downloads every non-``.pb`` object whose file name starts
        with the checkpoint name. If none matches, every non-``.pb`` object is
        downloaded and the state file is rewritten to the last entry returned
        by ``_filter_checkpoint_files``.

        Args:
            expected_checkpoint_number: minimum checkpoint number to accept;
                -1 accepts whatever the state file names.

        Returns:
            True on completion. Failures are logged through
            ``utils.json_format_logger`` (400 for ``ClientError``, 500 otherwise)
            and followed by ``utils.simapp_exit_gracefully()``.
        """
        try:
            if not os.path.exists(self.params.checkpoint_dir):
                os.makedirs(self.params.checkpoint_dir)

            while True:
                s3_client = self._get_client()
                state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))

                # wait until lock is removed
                response = s3_client.list_objects_v2(Bucket=self.params.bucket,
                                                     Prefix=self._get_s3_key(SyncFiles.LOCKFILE.value))
                if "Contents" in response:
                    time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                    continue
                try:
                    # fetch checkpoint state file from S3
                    s3_client.download_file(Bucket=self.params.bucket,
                                            Key=self._get_s3_key(state_file.filename),
                                            Filename=state_file.path)
                except Exception:
                    # The state file may not be uploaded yet; retry after a pause.
                    time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                    continue

                # Mirror the Finished / Ready marker files locally when present.
                for marker in (SyncFiles.FINISHED.value, SyncFiles.TRAINER_READY.value):
                    marker_key = self._get_s3_key(marker)
                    response = s3_client.list_objects_v2(Bucket=self.params.bucket, Prefix=marker_key)
                    if "Contents" in response:
                        try:
                            marker_path = os.path.abspath(os.path.join(self.params.checkpoint_dir, marker))
                            s3_client.download_file(Bucket=self.params.bucket,
                                                    Key=marker_key,
                                                    Filename=marker_path)
                        except Exception:
                            # Best effort by design: a failed marker download is ignored.
                            pass

                checkpoint_state = state_file.read()
                if checkpoint_state is not None:

                    # If we get a checkpoint that is older than the expected checkpoint,
                    # wait for the new checkpoint to arrive.
                    if checkpoint_state.num < expected_checkpoint_number:
                        time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                        continue

                    response = s3_client.list_objects_v2(Bucket=self.params.bucket,
                                                         Prefix=self._get_s3_key(""))
                    if "Contents" in response:
                        contents = response["Contents"]
                        # Check to see if the desired checkpoint is in the bucket
                        has_chkpnt = any(os.path.split(obj['Key'])[1].startswith(checkpoint_state.name)
                                         for obj in contents)
                        full_key_prefix = os.path.normpath(self.key_prefix) + "/"
                        for obj in contents:
                            key = obj["Key"]
                            # Strip only the leading S3 prefix to get the path relative
                            # to the checkpoint directory.
                            rel_key = key[len(full_key_prefix):] if key.startswith(full_key_prefix) else key
                            filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, rel_key))
                            dirname, basename = os.path.split(filename)
                            # Download the checkpoint files but not the frozen models,
                            # which are not needed here.
                            _, file_extension = os.path.splitext(key)
                            if file_extension != '.pb' \
                                    and (basename.startswith(checkpoint_state.name) or not has_chkpnt):
                                if not os.path.exists(dirname):
                                    os.makedirs(dirname)
                                s3_client.download_file(Bucket=self.params.bucket,
                                                        Key=key,
                                                        Filename=filename)
                        # Point the coach checkpoint file at the latest available checkpoint
                        # and log the substitution.
                        if not has_chkpnt:
                            all_ckpnts = _filter_checkpoint_files(os.listdir(self.params.checkpoint_dir))
                            if all_ckpnts:
                                logger.info("%s not in s3 bucket, downloading all checkpoints "
                                            "and using %s", checkpoint_state.name, all_ckpnts[-1])
                                state_file.write(all_ckpnts[-1])
                            else:
                                utils.json_format_logger("No checkpoint files found in {}".format(self.params.bucket),
                                                         **utils.build_user_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                                       utils.SIMAPP_EVENT_ERROR_CODE_400))
                                utils.simapp_exit_gracefully()
                return True

        except botocore.exceptions.ClientError as e:
            utils.json_format_logger("Unable to download checkpoint from {}, {}"
                                     .format(self.params.bucket, e.response['Error']['Code']),
                                     **utils.build_user_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                   utils.SIMAPP_EVENT_ERROR_CODE_400))
            utils.simapp_exit_gracefully()
        except Exception as e:
            utils.json_format_logger("Unable to download checkpoint from {}, {}"
                                     .format(self.params.bucket, e),
                                     **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                     utils.SIMAPP_EVENT_ERROR_CODE_500))
            utils.simapp_exit_gracefully()
Example #19
0
class TrainingMetrics(MetricsInterface, ObserverInterface):
    '''This class is responsible for uploading training metrics to s3'''
    def __init__(self,
                 agent_name,
                 s3_dict_metrics,
                 s3_dict_model,
                 ckpnt_dir,
                 run_phase_sink,
                 use_model_picker=True):
        '''agent_name - Name of the agent; scopes the local sim-trace directory
           s3_dict_metrics - Dictionary containing the required s3 info for the metrics
                             bucket with keys specified by MetricsS3Keys
           s3_dict_model - Dictionary containing the required s3 info for the model
                           bucket, which is where the best model info will be saved with
                           keys specified by MetricsS3Keys
           ckpnt_dir - Directory where the current checkpoint is to be stored
           run_phase_sink - Sink to receive notification of a change in run phase
           use_model_picker - Flag to whether to use model picker or not.
        '''
        self._agent_name_ = agent_name
        self._s3_dict_metrics_ = s3_dict_metrics
        self._s3_dict_model_ = s3_dict_model
        self._start_time_ = time.time()
        self._episode_ = 0
        self._episode_reward_ = 0.0
        self._progress_ = 0.0
        self._episode_status = ''
        self._metrics_ = list()
        self._is_eval_ = True
        self._eval_trials_ = 0
        self._checkpoint_state_ = CheckpointStateFile(ckpnt_dir)
        self._use_model_picker = use_model_picker
        # 'chkpnt_name': checkpoint being evaluated;
        # 'avg_comp_pct': best mean completion percentage seen so far
        self._eval_stats_dict_ = {'chkpnt_name': None, 'avg_comp_pct': 0.0}
        # Best checkpoint stats, written to the model bucket under BEST_CHECKPOINT
        self._best_chkpnt_stats = {
            'name': None,
            'avg_comp_pct': 0.0,
            'time_stamp': time.time()
        }
        # Completion percentages gathered during the current evaluation phase
        self._current_eval_pct_list_ = list()
        self.is_save_simtrace_enabled = rospy.get_param(
            'SIMTRACE_S3_BUCKET', None)
        run_phase_sink.register(self)
        # Create the agent specific directories needed for storing the metric files
        simtrace_dirname = os.path.dirname(
            IterationDataLocalFileNames.SIM_TRACE_TRAINING_LOCAL_FILE.value)
        os.makedirs(
            os.path.join(ITERATION_DATA_LOCAL_FILE_PATH, self._agent_name_,
                         simtrace_dirname),
            exist_ok=True)

    def reset(self):
        '''Start a new episode: restart the wall-clock timer and clear reward/progress.'''
        self._start_time_ = time.time()
        self._episode_reward_ = 0.0
        self._progress_ = 0.0

    def append_episode_metrics(self):
        '''Append a summary dict for the just-finished episode to the metrics list.

           The episode and trial counters only advance during the training phase.
        '''
        self._episode_ += 1 if not self._is_eval_ else 0
        self._eval_trials_ += 1 if not self._is_eval_ else 0
        training_metric = dict()
        training_metric['reward_score'] = int(round(self._episode_reward_))
        training_metric['metric_time'] = int(round(time.time() * 1000))
        training_metric['start_time'] = int(round(self._start_time_ * 1000))
        training_metric['elapsed_time_in_milliseconds'] = \
            int(round((time.time() - self._start_time_) * 1000))
        training_metric['episode'] = int(self._episode_)
        training_metric['trial'] = int(self._eval_trials_)
        training_metric[
            'phase'] = 'evaluation' if self._is_eval_ else 'training'
        training_metric['completion_percentage'] = int(self._progress_)
        training_metric[
            'episode_status'] = EpisodeStatus.get_episode_status_label(
                self._episode_status)
        self._metrics_.append(training_metric)

    def upload_episode_metrics(self):
        '''Write the accumulated metrics list to the metrics S3 location.

           During evaluation, also remember this episode's progress for model picking.
        '''
        write_metrics_to_s3(
            self._s3_dict_metrics_[MetricsS3Keys.METRICS_BUCKET.value],
            self._s3_dict_metrics_[MetricsS3Keys.METRICS_KEY.value],
            self._s3_dict_metrics_[MetricsS3Keys.REGION.value],
            {'metrics': self._metrics_})
        if self._is_eval_:
            self._current_eval_pct_list_.append(self._progress_)

    def upload_step_metrics(self, metrics):
        '''Record one step's progress, status and reward.

           In training, also stamp the episode number, validate and log the
           sim trace, and optionally write it to the local sim-trace file.
           metrics - dict keyed by StepMetrics values (mutated in training)
        '''
        self._progress_ = metrics[StepMetrics.PROG.value]
        self._episode_status = metrics[StepMetrics.EPISODE_STATUS.value]
        # Accumulate the step reward exactly once for both phases; adding it
        # again inside the training branch double-counted the training reward.
        self._episode_reward_ += metrics[StepMetrics.REWARD.value]
        #! TODO have this work with new sim trace class
        if not self._is_eval_:
            metrics[StepMetrics.EPISODE.value] = self._episode_
            StepMetrics.validate_dict(metrics)
            sim_trace_log(metrics)
            if self.is_save_simtrace_enabled:
                write_simtrace_to_local_file(
                    os.path.join(
                        ITERATION_DATA_LOCAL_FILE_PATH,
                        self._agent_name_,
                        IterationDataLocalFileNames.
                        SIM_TRACE_TRAINING_LOCAL_FILE.value), metrics)

    def update(self, data):
        '''Observer callback for run-phase changes.

           When entering training with the model picker enabled, average the
           completion percentages of the preceding evaluation, update the best
           checkpoint if the mean is at least the best so far, and write the
           best/last checkpoint stats to the model bucket.
           data - the new run phase, compared against RunPhase.TRAIN
        '''
        self._is_eval_ = data != RunPhase.TRAIN

        if not self._is_eval_ and self._use_model_picker:
            if self._eval_stats_dict_['chkpnt_name'] is None:
                self._eval_stats_dict_[
                    'chkpnt_name'] = self._checkpoint_state_.read().name

            self._eval_trials_ = 0
            # With no evaluation episodes the mean is -1, which never beats the initial best of 0.0
            mean_pct = statistics.mean(self._current_eval_pct_list_ if \
                                       self._current_eval_pct_list_ else [-1])
            LOGGER.info(
                'Number of evaluations: {} Evaluation progresses: {}'.format(
                    len(self._current_eval_pct_list_),
                    self._current_eval_pct_list_))
            LOGGER.info('Evaluation progresses mean: {}'.format(mean_pct))
            self._current_eval_pct_list_.clear()

            time_stamp = time.time()
            if mean_pct >= self._eval_stats_dict_['avg_comp_pct']:
                LOGGER.info('Current mean: {} >= Current best mean: {}'.format(
                    mean_pct, self._eval_stats_dict_['avg_comp_pct']))
                LOGGER.info(
                    'Updating the best checkpoint to "{}" from "{}".'.format(
                        self._eval_stats_dict_['chkpnt_name'],
                        self._best_chkpnt_stats['name']))
                self._eval_stats_dict_['avg_comp_pct'] = mean_pct
                self._best_chkpnt_stats = {
                    'name': self._eval_stats_dict_['chkpnt_name'],
                    'avg_comp_pct': mean_pct,
                    'time_stamp': time_stamp
                }
            last_chkpnt_stats = {
                'name': self._eval_stats_dict_['chkpnt_name'],
                'avg_comp_pct': mean_pct,
                'time_stamp': time_stamp
            }
            write_metrics_to_s3(
                self._s3_dict_model_[MetricsS3Keys.METRICS_BUCKET.value],
                self._s3_dict_model_[MetricsS3Keys.METRICS_KEY.value],
                self._s3_dict_model_[MetricsS3Keys.REGION.value], {
                    BEST_CHECKPOINT: self._best_chkpnt_stats,
                    LAST_CHECKPOINT: last_chkpnt_stats
                })
            # Update the checkpoint name to the new checkpoint being used for training that will
            # then be evaluated, note this class gets notified when the system is put into a
            # training phase and assumes that a training phase only starts when a new check point
            # is available
            self._eval_stats_dict_[
                'chkpnt_name'] = self._checkpoint_state_.read().name
Exemple #20
0
def rename_checkpoints(checkpoint_dir, agent_name):
    ''' Helper method that renames the specific checkpoint in the CheckpointStateFile
        to be scoped with agent_name
        checkpoint_dir - local checkpoint folder where the checkpoints and .checkpoint file is stored
        agent_name - name of the agent
    '''
    try:
        logger.info("Renaming checkpoint from checkpoint_dir: {} for agent: {}".format(checkpoint_dir, agent_name))
        state_file = CheckpointStateFile(os.path.abspath(checkpoint_dir))
        checkpoint_name = str(state_file.read())
        # Write a TensorFlow 'checkpoint' file pointing at checkpoint_name so the
        # tf.contrib.framework helpers below read that checkpoint from checkpoint_dir
        tf_checkpoint_file = os.path.join(checkpoint_dir, "checkpoint")
        with open(tf_checkpoint_file, "w") as outfile:
            outfile.write("model_checkpoint_path: \"{}\"".format(checkpoint_name))

        with tf.Session() as sess:
            for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
                # Load the variable
                var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
                # Replace agent/ or agent_#/ with {agent_name}/ (raw string: '\d' is a regex class)
                new_name = re.sub(r'agent/|agent_\d+/', '{}/'.format(agent_name), var_name)
                # Re-create the variable under its new name in the default graph so the
                # Saver constructed below includes it
                tf.Variable(var, name=new_name)
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            renamed_checkpoint_path = os.path.join(TEMP_RENAME_FOLDER, checkpoint_name)
            logger.info('Saving updated checkpoint to {}'.format(renamed_checkpoint_path))
            saver.save(sess, renamed_checkpoint_path)
        # Remove the tensorflow 'checkpoint' file
        os.remove(tf_checkpoint_file)
        # Remove the old checkpoint from the checkpoint dir. Match the exact name or a
        # "<name>." prefix (e.g. <name>.index / .meta / .data-*); a plain substring test
        # would also delete other checkpoints whose names merely contain this one.
        old_prefix = checkpoint_name + "."
        for file_name in os.listdir(checkpoint_dir):
            if file_name == checkpoint_name or file_name.startswith(old_prefix):
                os.remove(os.path.join(checkpoint_dir, file_name))
        # Copy the new checkpoint with renamed variable to the checkpoint dir
        temp_dir = os.path.abspath(TEMP_RENAME_FOLDER)
        for file_name in os.listdir(TEMP_RENAME_FOLDER):
            full_file_name = os.path.join(temp_dir, file_name)
            if os.path.isfile(full_file_name) and file_name != "checkpoint":
                shutil.copy(full_file_name, checkpoint_dir)
        # Remove files from temp_rename_folder
        shutil.rmtree(TEMP_RENAME_FOLDER)
        tf.reset_default_graph()
    # If either of the checkpoint files (index, meta or data) not found
    except tf.errors.NotFoundError as err:
        log_and_exit("No checkpoint found: {}".format(err),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_400)
    # Thrown when user modifies model, checkpoints get corrupted/truncated
    except tf.errors.DataLossError as err:
        log_and_exit("User modified ckpt, unrecoverable dataloss or corruption: {}"
                     .format(err),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_400)
    except ValueError as err:
        if utils.is_error_bad_ckpnt(err):
            # Adjacent literals are joined at compile time, so .format applies to the whole
            # message; a backslash continuation inside the literal embedded the indentation.
            log_and_exit("Couldn't find 'checkpoint' file or checkpoints in given "
                         "directory ./checkpoint: {}".format(err),
                         SIMAPP_SIMULATION_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_400)
        else:
            log_and_exit("ValueError in rename checkpoint: {}".format(err),
                         SIMAPP_SIMULATION_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
    except Exception as ex:
        log_and_exit("Exception in rename checkpoint: {}".format(ex),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
Exemple #21
0
class TrainingMetrics(MetricsInterface, ObserverInterface, AbstractTracker):
    '''This class is responsible for uploading training metrics to s3'''
    def __init__(self, agent_name, s3_dict_metrics, s3_dict_model, ckpnt_dir, run_phase_sink, use_model_picker=True):
        '''agent_name - Name of the agent; scopes the sim-trace directory and video-metrics service
           s3_dict_metrics - Dictionary containing the required s3 info for the metrics
                             bucket with keys specified by MetricsS3Keys
           s3_dict_model - Dictionary containing the required s3 info for the model
                           bucket, which is where the best model info will be saved with
                           keys specified by MetricsS3Keys
           ckpnt_dir - Directory where the current checkpoint is to be stored
           run_phase_sink - Sink to receive notification of a change in run phase
           use_model_picker - Flag to whether to use model picker or not.
        '''
        self._agent_name_ = agent_name
        self._s3_dict_metrics_ = s3_dict_metrics
        self._s3_dict_model_ = s3_dict_model
        # Simulation time in seconds, refreshed by update_tracker
        self._current_sim_time = 0
        # Episode timing uses the simulation clock (see reset/append_episode_metrics),
        # so start on that clock rather than wall-clock time.time()
        self._start_time_ = self._current_sim_time
        self._episode_ = 0
        self._episode_reward_ = 0.0
        self._progress_ = 0.0
        self._episode_status = ''
        self._metrics_ = list()
        self._is_eval_ = True
        self._eval_trials_ = 0
        self._checkpoint_state_ = CheckpointStateFile(ckpnt_dir)
        self._use_model_picker = use_model_picker
        # 'chkpnt_name': checkpoint being evaluated;
        # 'avg_comp_pct': best mean completion percentage seen so far
        self._eval_stats_dict_ = {'chkpnt_name': None, 'avg_comp_pct': -1.0}
        self._best_chkpnt_stats = {'name': None, 'avg_comp_pct': -1.0, 'time_stamp': time.time()}
        # Completion percentages gathered during the current evaluation phase
        self._current_eval_pct_list_ = list()
        self.is_save_simtrace_enabled = rospy.get_param('SIMTRACE_S3_BUCKET', None)
        self.track_data = TrackData.get_instance()
        run_phase_sink.register(self)
        # Create the agent specific directories needed for storing the metric files
        simtrace_dirname = os.path.dirname(IterationDataLocalFileNames.SIM_TRACE_TRAINING_LOCAL_FILE.value)
        os.makedirs(os.path.join(ITERATION_DATA_LOCAL_FILE_PATH, self._agent_name_, simtrace_dirname),
                    exist_ok=True)
        # Initialize the state read by _handle_get_video_metrics BEFORE advertising the
        # service: a request can be served as soon as rospy.Service is created.
        self._video_metrics = Mp4VideoMetrics.get_empty_dict()
        rospy.Service("/{}/{}".format(self._agent_name_, "mp4_video_metrics"), VideoMetricsSrv,
                      self._handle_get_video_metrics)
        AbstractTracker.__init__(self, TrackerPriority.HIGH)

    def update_tracker(self, delta_time, sim_time):
        """
        Callback when sim time is updated

        Args:
            delta_time (float): time diff from last call
            sim_time (Clock): simulation time
        """
        self._current_sim_time = sim_time.clock.secs + 1.e-9 * sim_time.clock.nsecs

    def reset(self):
        '''Start a new episode at the current sim time and clear reward/progress.'''
        self._start_time_ = self._current_sim_time
        self._episode_reward_ = 0.0
        self._progress_ = 0.0

    def append_episode_metrics(self):
        '''Append a summary dict for the just-finished episode to the metrics list.

           Times are derived from simulation time; the episode and trial counters
           only advance during the training phase.
        '''
        self._episode_ += 1 if not self._is_eval_ else 0
        self._eval_trials_ += 1 if not self._is_eval_ else 0
        training_metric = dict()
        training_metric['reward_score'] = int(round(self._episode_reward_))
        training_metric['metric_time'] = int(round(self._current_sim_time * 1000))
        training_metric['start_time'] = int(round(self._start_time_ * 1000))
        training_metric['elapsed_time_in_milliseconds'] = \
            int(round((self._current_sim_time - self._start_time_) * 1000))
        training_metric['episode'] = int(self._episode_)
        training_metric['trial'] = int(self._eval_trials_)
        training_metric['phase'] = 'evaluation' if self._is_eval_ else 'training'
        training_metric['completion_percentage'] = int(self._progress_)
        training_metric['episode_status'] = EpisodeStatus.get_episode_status_label(self._episode_status)
        self._metrics_.append(training_metric)

    def upload_episode_metrics(self):
        '''Write the accumulated metrics list to the metrics S3 location.

           During evaluation, also remember this episode's progress for model picking.
        '''
        write_metrics_to_s3(self._s3_dict_metrics_[MetricsS3Keys.METRICS_BUCKET.value],
                            self._s3_dict_metrics_[MetricsS3Keys.METRICS_KEY.value],
                            self._s3_dict_metrics_[MetricsS3Keys.REGION.value],
                            {'metrics': self._metrics_},
                            self._s3_dict_metrics_[MetricsS3Keys.ENDPOINT_URL.value])
        if self._is_eval_:
            self._current_eval_pct_list_.append(self._progress_)

    def upload_step_metrics(self, metrics):
        '''Record one step's progress, status and reward and refresh the video metrics.

           In training, also stamp the episode number, validate and log the
           sim trace, and optionally write it to the local sim-trace file.
           metrics - dict keyed by StepMetrics values (mutated in training)
        '''
        self._progress_ = metrics[StepMetrics.PROG.value]
        self._episode_status = metrics[StepMetrics.EPISODE_STATUS.value]
        # Community fix to have reward for evaluation runs during training
        self._episode_reward_ += metrics[StepMetrics.REWARD.value]
        if not self._is_eval_:
            metrics[StepMetrics.EPISODE.value] = self._episode_
            StepMetrics.validate_dict(metrics)
            sim_trace_log(metrics)
            if self.is_save_simtrace_enabled:
                write_simtrace_to_local_file(
                    os.path.join(ITERATION_DATA_LOCAL_FILE_PATH, self._agent_name_,
                                 IterationDataLocalFileNames.SIM_TRACE_TRAINING_LOCAL_FILE.value),
                    metrics)
        self._update_mp4_video_metrics(metrics)

    def update(self, data):
        '''Observer callback for run-phase changes.

           When entering training with the model picker enabled, average the
           completion percentages of the preceding evaluation (0.0 if none),
           update the best checkpoint if the mean is at least the best so far,
           and write the best/last checkpoint stats to the model bucket.
           data - the new run phase, compared against RunPhase.TRAIN
        '''
        self._is_eval_ = data != RunPhase.TRAIN

        if not self._is_eval_ and self._use_model_picker:
            if self._eval_stats_dict_['chkpnt_name'] is None:
                self._eval_stats_dict_['chkpnt_name'] = self._checkpoint_state_.read().name

            self._eval_trials_ = 0
            mean_pct = statistics.mean(self._current_eval_pct_list_ if \
                                       self._current_eval_pct_list_ else [0.0])
            LOGGER.info('Number of evaluations: {} Evaluation progresses: {}'.format(len(self._current_eval_pct_list_),
                                                                                     self._current_eval_pct_list_))
            LOGGER.info('Evaluation progresses mean: {}'.format(mean_pct))
            self._current_eval_pct_list_.clear()

            time_stamp = self._current_sim_time
            if mean_pct >= self._eval_stats_dict_['avg_comp_pct']:
                LOGGER.info('Current mean: {} >= Current best mean: {}'.format(mean_pct,
                                                                               self._eval_stats_dict_['avg_comp_pct']))
                LOGGER.info('Updating the best checkpoint to "{}" from "{}".'.format(self._eval_stats_dict_['chkpnt_name'],
                                                                                     self._best_chkpnt_stats['name']))
                self._eval_stats_dict_['avg_comp_pct'] = mean_pct
                self._best_chkpnt_stats = {'name': self._eval_stats_dict_['chkpnt_name'],
                                           'avg_comp_pct': mean_pct,
                                           'time_stamp': time_stamp}
            last_chkpnt_stats = {'name': self._eval_stats_dict_['chkpnt_name'],
                                 'avg_comp_pct': mean_pct,
                                 'time_stamp': time_stamp}
            # NOTE(review): this model-bucket write takes its endpoint URL from the
            # *metrics* dict rather than self._s3_dict_model_ — confirm this is intended
            # (e.g. both buckets always share one endpoint).
            write_metrics_to_s3(self._s3_dict_model_[MetricsS3Keys.METRICS_BUCKET.value],
                                self._s3_dict_model_[MetricsS3Keys.METRICS_KEY.value],
                                self._s3_dict_model_[MetricsS3Keys.REGION.value],
                                {BEST_CHECKPOINT: self._best_chkpnt_stats,
                                 LAST_CHECKPOINT: last_chkpnt_stats},
                                self._s3_dict_metrics_[MetricsS3Keys.ENDPOINT_URL.value])
            # Update the checkpoint name to the new checkpoint being used for training that will
            # then be evaluated, note this class gets notified when the system is put into a
            # training phase and assumes that a training phase only starts when a new check point
            # is available
            self._eval_stats_dict_['chkpnt_name'] = self._checkpoint_state_.read().name

    def _update_mp4_video_metrics(self, metrics):
        '''Refresh the video-metrics dict from this step's metrics and track object poses.

           Lap counter, reset counter, throttle, steering, best lap time and total
           evaluation time are always reported as 0 here.
        '''
        agent_x, agent_y = metrics[StepMetrics.X.value], metrics[StepMetrics.Y.value]
        self._video_metrics[Mp4VideoMetrics.LAP_COUNTER.value] = 0
        self._video_metrics[Mp4VideoMetrics.COMPLETION_PERCENTAGE.value] = self._progress_
        # For continuous race, MP4 video will display the total reset counter for the entire race
        # For non-continuous race, MP4 video will display reset counter per lap
        self._video_metrics[Mp4VideoMetrics.RESET_COUNTER.value] = 0

        self._video_metrics[Mp4VideoMetrics.THROTTLE.value] = 0
        self._video_metrics[Mp4VideoMetrics.STEERING.value] = 0
        self._video_metrics[Mp4VideoMetrics.BEST_LAP_TIME.value] = 0
        self._video_metrics[Mp4VideoMetrics.TOTAL_EVALUATION_TIME.value] = 0
        self._video_metrics[Mp4VideoMetrics.DONE.value] = metrics[StepMetrics.DONE.value]
        self._video_metrics[Mp4VideoMetrics.X.value] = agent_x
        self._video_metrics[Mp4VideoMetrics.Y.value] = agent_y

        # Positions (z flattened to 0) of every tracked object except the racecars
        object_poses = [pose for object_name, pose in self.track_data.object_poses.items()
                        if not object_name.startswith('racecar')]
        object_locations = []
        for pose in object_poses:
            point = Point32()
            point.x, point.y, point.z = pose.position.x, pose.position.y, 0
            object_locations.append(point)
        self._video_metrics[Mp4VideoMetrics.OBJECT_LOCATIONS.value] = object_locations

    def _handle_get_video_metrics(self, req):
        '''Service handler: return the current video metrics (the request is unused).'''
        return VideoMetricsSrvResponse(self._video_metrics[Mp4VideoMetrics.LAP_COUNTER.value],
                                       self._video_metrics[Mp4VideoMetrics.COMPLETION_PERCENTAGE.value],
                                       self._video_metrics[Mp4VideoMetrics.RESET_COUNTER.value],
                                       self._video_metrics[Mp4VideoMetrics.THROTTLE.value],
                                       self._video_metrics[Mp4VideoMetrics.STEERING.value],
                                       self._video_metrics[Mp4VideoMetrics.BEST_LAP_TIME.value],
                                       self._video_metrics[Mp4VideoMetrics.TOTAL_EVALUATION_TIME.value],
                                       self._video_metrics[Mp4VideoMetrics.DONE.value],
                                       self._video_metrics[Mp4VideoMetrics.X.value],
                                       self._video_metrics[Mp4VideoMetrics.Y.value],
                                       self._video_metrics[Mp4VideoMetrics.OBJECT_LOCATIONS.value])