Example #1
 def download_file(self, s3_key, local_path):
     s3_client = self.get_client()
     try:
         s3_client.download_file(self.bucket, s3_key, local_path)
         return True
     except botocore.exceptions.ClientError as e:
         if e.response['Error']['Code'] == "404":
             logger.info(
                 "Exception [{}] occured on download file-{} from s3 bucket-{} key-{}"
                 .format(e.response['Error'], local_path, self.bucket,
                         s3_key))
             return False
         else:
             utils.json_format_logger(
                 "boto client exception error [{}] occured on download file-{} from s3 bucket-{} key-{}"
                 .format(e.response['Error'], local_path, self.bucket,
                         s3_key),
                 **utils.build_user_error_dict(
                     utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                     utils.SIMAPP_EVENT_ERROR_CODE_401))
             return False
     except Exception as e:
         utils.json_format_logger(
             "Exception [{}] occcured on download file-{} from s3 bucket-{} key-{}"
             .format(e, local_path, self.bucket, s3_key),
             **utils.build_user_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_401))
         return False
 def write_ip_config(self, ip):
     try:
         s3_client = self.get_client()
         data = {"IP": ip}
         json_blob = json.dumps(data)
         file_handle = io.BytesIO(json_blob.encode())
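         # A small 'done' marker object is uploaded after the config so pollers (see _wait_for_ip_upload) can tell the write completed.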
         file_handle_done = io.BytesIO(b'done')
         s3_client.upload_fileobj(file_handle, self.bucket, self.config_key)
         s3_client.upload_fileobj(file_handle_done, self.bucket,
                                  self.done_file_key)
     except botocore.exceptions.ClientError as e:
         utils.json_format_logger(
             "Write ip config failed to upload to {}, {}".format(
                 self.bucket, e.response['Error']['Code']),
             **utils.build_user_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_400))
         utils.simapp_exit_gracefully()
     except Exception as e:
         utils.json_format_logger(
             "Write ip config failed to upload to {}, {}".format(
                 self.bucket, e),
             **utils.build_system_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
         utils.simapp_exit_gracefully()
 def download_file(self, s3_key, local_path):
     s3_client = self.get_client()
     try:
         s3_client.download_file(self.bucket, s3_key, local_path)
         return True
     except botocore.exceptions.ClientError as e:
         # The file may simply not exist, in which case we return False and let the caller decide the next action
         if e.response['Error']['Code'] == "404":
             return False
         else:
             utils.json_format_logger(
                 "Unable to download {} from {}: {}".format(
                     s3_key, self.bucket, e.response['Error']['Code']),
                 **utils.build_user_error_dict(
                     utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                     utils.SIMAPP_EVENT_ERROR_CODE_400))
             utils.simapp_exit_gracefully()
     except Exception as e:
         utils.json_format_logger(
             "Unable to download {} from {}: {}".format(
                 s3_key, self.bucket, e),
             **utils.build_system_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
         utils.simapp_exit_gracefully()
 def store_ip(self, ip_address):
     try:
         s3_client = self._get_client()
         ip_data = {IP_KEY: ip_address}
         ip_data_json_blob = json.dumps(ip_data)
         ip_data_file_object = io.BytesIO(ip_data_json_blob.encode())
         ip_done_file_object = io.BytesIO(b'done')
         s3_client.upload_fileobj(ip_data_file_object, self.params.bucket,
                                  self.ip_data_key)
         s3_client.upload_fileobj(ip_done_file_object, self.params.bucket,
                                  self.ip_done_key)
     except botocore.exceptions.ClientError as e:
         utils.json_format_logger(
             "Unable to store ip to {}, {}".format(
                 self.params.bucket, e.response['Error']['Code']),
             **utils.build_user_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_400))
         utils.simapp_exit_gracefully()
     except Exception as e:
         utils.json_format_logger(
             "Unable to store ip to {}, {}".format(self.params.bucket, e),
             **utils.build_system_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
         utils.simapp_exit_gracefully()
    def download_preset_if_present(self, local_path):
        try:
            s3_client = self._get_client()
            response = s3_client.list_objects(Bucket=self.params.bucket,
                                              Prefix=self.preset_data_key)

            # If we don't find a preset, return false
            if "Contents" not in response:
                return False

            # boto3's download_file returns None, so report success explicitly.
            s3_client.download_file(Bucket=self.params.bucket,
                                    Key=self.preset_data_key,
                                    Filename=local_path)
            return True
        except botocore.exceptions.ClientError as e:
            utils.json_format_logger(
                "Unable to download presets from {}, {}".format(
                    self.params.bucket, e.response['Error']['Code']),
                **utils.build_user_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_400))
            utils.simapp_exit_gracefully()
        except Exception as e:
            utils.json_format_logger(
                "Unable to download presets from {}, {}".format(
                    self.params.bucket, e),
                **utils.build_system_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_500))
            utils.simapp_exit_gracefully()
 def upload_model(self, checkpoint_dir):
     try:
         s3_client = self.get_client()
         num_files = 0
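         # Mirror every file under the local checkpoint directory to "<s3_prefix>/<checkpoint_dir>/" in the bucket.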
         for root, _, files in os.walk("./" + checkpoint_dir):
             for filename in files:
                 abs_name = os.path.abspath(os.path.join(root, filename))
                 s3_client.upload_file(
                     abs_name, self.bucket, "%s/%s/%s" %
                     (self.s3_prefix, checkpoint_dir, filename))
                 num_files += 1
     except botocore.exceptions.ClientError as e:
         utils.json_format_logger(
             "Model failed to upload to {}, {}".format(
                 self.bucket, e.response['Error']['Code']),
             **utils.build_user_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_400))
         utils.simapp_exit_gracefully()
     except Exception as e:
         utils.json_format_logger(
             "Model failed to upload to {}, {}".format(self.bucket, e),
             **utils.build_system_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
         utils.simapp_exit_gracefully()
def download_customer_reward_function(s3_client, reward_file_s3_key):
    reward_function_local_path = os.path.join(CUSTOM_FILES_PATH,
                                              "customer_reward_function.py")
    success_reward_function_download = s3_client.download_file(
        s3_key=reward_file_s3_key, local_path=reward_function_local_path)

    if not success_reward_function_download:
        utils.json_format_logger(
            "Unable to download the reward function code.",
            **utils.build_user_error_dict(
                utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_400))
        traceback.print_exc()
        utils.simapp_exit_gracefully()
 def upload_finished_file(self):
     try:
         s3_client = self._get_client()
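         # The finished marker is an empty object; readers only check for its presence (see load_from_store).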
         s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                  Bucket=self.params.bucket,
                                  Key=self._get_s3_key(SyncFiles.FINISHED.value))
     except botocore.exceptions.ClientError as e:
         utils.json_format_logger("Unable to upload finished file to {}, {}"
                                  .format(self.params.bucket, e.response['Error']['Code']),
                                  **utils.build_user_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                utils.SIMAPP_EVENT_ERROR_CODE_400))
         utils.simapp_exit_gracefully()
     except Exception as e:
         utils.json_format_logger("Unable to upload finished file to {}, {}"
                                  .format(self.params.bucket, e),
                                  **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                  utils.SIMAPP_EVENT_ERROR_CODE_500))
         utils.simapp_exit_gracefully()
 def upload_file(self, s3_key, local_path):
     s3_client = self.get_client()
     try:
         s3_client.upload_file(Filename=local_path,
                               Bucket=self.bucket,
                               Key=s3_key)
         return True
     except botocore.exceptions.ClientError as e:
         utils.json_format_logger(
             "Unable to upload {} to {}: {}".format(
                 s3_key, self.bucket, e.response['Error']['Code']),
             **utils.build_user_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_400))
     except Exception as e:
         utils.json_format_logger(
             "Unable to upload {} to {}: {}".format(s3_key, self.bucket, e),
             **utils.build_system_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
     return False
    def _wait_for_ip_upload(self, timeout=600):
        s3_client = self.get_client()
        time_elapsed = 0

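        # Poll S3 once per second for the 'done' marker uploaded by write_ip_config, up to the timeout.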
        while time_elapsed < timeout:
            try:
                response = s3_client.list_objects(Bucket=self.bucket,
                                                  Prefix=self.done_file_key)
            except botocore.exceptions.ClientError as e:
                utils.json_format_logger(
                    "Unable to access {}: {}".format(
                        self.bucket, e.response['Error']['Code']),
                    **utils.build_user_error_dict(
                        utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                        utils.SIMAPP_EVENT_ERROR_CODE_400))
                utils.simapp_exit_gracefully()
            except Exception as e:
                utils.json_format_logger(
                    "Unable to access {}: {}".format(self.bucket, e),
                    **utils.build_system_error_dict(
                        utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                        utils.SIMAPP_EVENT_ERROR_CODE_500))
                utils.simapp_exit_gracefully()

            if "Contents" in response:
                return
            time.sleep(1)
            time_elapsed += 1
            if time_elapsed % 5 == 0:
                logger.info(
                    "Waiting for SageMaker Redis server IP... Time elapsed: %s seconds"
                    % time_elapsed)

        utils.json_format_logger(
            "Timed out while attempting to retrieve the Redis IP",
            **utils.build_system_error_dict(
                utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        utils.simapp_exit_gracefully()
 def upload_hyperparameters(self, hyperparams_json):
     try:
         s3_client = self.get_client()
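         # The hyperparameters arrive as a JSON string; wrap the encoded bytes so upload_fileobj can stream them.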
         file_handle = io.BytesIO(hyperparams_json.encode())
         s3_client.upload_fileobj(file_handle, self.bucket,
                                  self.hyperparameters_key)
     except botocore.exceptions.ClientError as e:
         utils.json_format_logger(
             "Hyperparameters failed to upload to {}, {}".format(
                 self.bucket, e.response['Error']['Code']),
             **utils.build_user_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_400))
         utils.simapp_exit_gracefully()
     except Exception as e:
         utils.json_format_logger(
             "Hyperparameters failed to upload to {}, {}".format(
                 self.bucket, e),
             **utils.build_system_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
         utils.simapp_exit_gracefully()
 def get_ip(self):
     s3_client = self.get_client()
     self._wait_for_ip_upload()
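     # Once the 'done' marker appears, fetch the JSON written by write_ip_config and return the Redis server IP.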
     try:
         s3_client.download_file(self.bucket, self.config_key, 'ip.json')
         with open("ip.json") as f:
             ip = json.load(f)["IP"]
         return ip
     except botocore.exceptions.ClientError as e:
         utils.json_format_logger(
             "Unable to retrieve redis ip from {}: {}".format(
                 self.bucket, e.response['Error']['Code']),
             **utils.build_user_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_400))
         utils.simapp_exit_gracefully()
     except Exception as e:
         utils.json_format_logger(
             "Unable to retrieve redis ip from {}: {}".format(
                 self.bucket, e),
             **utils.build_system_error_dict(
                 utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                 utils.SIMAPP_EVENT_ERROR_CODE_500))
         utils.simapp_exit_gracefully()
    def load_from_store(self, expected_checkpoint_number=-1):
        try:
            if not os.path.exists(self.params.checkpoint_dir):
                os.makedirs(self.params.checkpoint_dir)

            while True:
                s3_client = self._get_client()
                state_file = CheckpointStateFile(os.path.abspath(self.params.checkpoint_dir))

                # wait until lock is removed
                response = s3_client.list_objects_v2(Bucket=self.params.bucket,
                                                     Prefix=self._get_s3_key(SyncFiles.LOCKFILE.value))
                if "Contents" not in response:
                    try:
                        # fetch checkpoint state file from S3
                        s3_client.download_file(Bucket=self.params.bucket,
                                                Key=self._get_s3_key(state_file.filename),
                                                Filename=state_file.path)
                    except Exception as e:
                        time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                        continue
                else:
                    time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                    continue

                # check if there's a Finished file
                response = s3_client.list_objects_v2(Bucket=self.params.bucket,
                                                     Prefix=self._get_s3_key(SyncFiles.FINISHED.value))
                if "Contents" in response:
                    try:
                        finished_file_path = os.path.abspath(os.path.join(self.params.checkpoint_dir,
                                                                          SyncFiles.FINISHED.value))
                        s3_client.download_file(Bucket=self.params.bucket,
                                                Key=self._get_s3_key(SyncFiles.FINISHED.value),
                                                Filename=finished_file_path)
                    except Exception as e:
                        pass

                # check if there's a Ready file
                response = s3_client.list_objects_v2(Bucket=self.params.bucket,
                                                     Prefix=self._get_s3_key(SyncFiles.TRAINER_READY.value))
                if "Contents" in response:
                    try:
                        ready_file_path = os.path.abspath(os.path.join(self.params.checkpoint_dir,
                                                                       SyncFiles.TRAINER_READY.value))
                        s3_client.download_file(Bucket=self.params.bucket,
                                                Key=self._get_s3_key(SyncFiles.TRAINER_READY.value),
                                                Filename=ready_file_path)
                    except Exception as e:
                        pass

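                # Read the coach checkpoint state file that was downloaded above.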
                checkpoint_state = state_file.read()
                if checkpoint_state is not None:

                    # If we get a checkpoint that is older than the expected checkpoint,
                    # we wait for the new checkpoint to arrive.
                    if checkpoint_state.num < expected_checkpoint_number:
                        time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                        continue

                    response = s3_client.list_objects_v2(Bucket=self.params.bucket,
                                                         Prefix=self._get_s3_key(""))
                    if "Contents" in response:
                        # Check to see if the desired checkpoint is in the bucket
                        has_chkpnt = any(os.path.split(obj['Key'])[1].startswith(checkpoint_state.name)
                                         for obj in response['Contents'])
                        for obj in response["Contents"]:
                            full_key_prefix = os.path.normpath(self.key_prefix) + "/"
                            filename = os.path.abspath(os.path.join(self.params.checkpoint_dir,
                                                                    obj["Key"].replace(full_key_prefix, "")))
                            dirname, basename = os.path.split(filename)
                            # Download all the checkpoints but not the frozen models since they
                            # are not necessary
                            _, file_extension = os.path.splitext(obj["Key"])
                            if file_extension != '.pb' \
                            and (basename.startswith(checkpoint_state.name) or not has_chkpnt):
                                if not os.path.exists(dirname):
                                    os.makedirs(dirname)
                                s3_client.download_file(Bucket=self.params.bucket,
                                                        Key=obj["Key"],
                                                        Filename=filename)
                        # Change the coach checkpoint file to point to the latest available checkpoint,
                        # also log that we are changing the checkpoint.
                        if not has_chkpnt:
                            all_ckpnts = _filter_checkpoint_files(os.listdir(self.params.checkpoint_dir))
                            if all_ckpnts:
                                logger.info("%s not in s3 bucket, downloading all checkpoints \
                                            and using %s", checkpoint_state.name, all_ckpnts[-1])
                                state_file.write(all_ckpnts[-1])
                            else:
                                utils.json_format_logger("No checkpoint files found in {}".format(self.params.bucket),
                                                         **utils.build_user_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                                       utils.SIMAPP_EVENT_ERROR_CODE_400))
                                utils.simapp_exit_gracefully()
                return True

        except botocore.exceptions.ClientError as e:
            utils.json_format_logger("Unable to download checkpoint from {}, {}"
                                     .format(self.params.bucket, e.response['Error']['Code']),
                                     **utils.build_user_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                   utils.SIMAPP_EVENT_ERROR_CODE_400))
            utils.simapp_exit_gracefully()
        except Exception as e:
            utils.json_format_logger("Unable to download checkpoint from {}, {}"
                                     .format(self.params.bucket, e),
                                     **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                     utils.SIMAPP_EVENT_ERROR_CODE_500))
            utils.simapp_exit_gracefully()
    def save_to_store(self):
        try:
            s3_client = self._get_client()

            # Delete any existing lock file
            s3_client.delete_object(Bucket=self.params.bucket,
                                    Key=self._get_s3_key(
                                        self.params.lock_file))

            # We take a lock by writing a lock file to the same location in S3
            s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                     Bucket=self.params.bucket,
                                     Key=self._get_s3_key(
                                         self.params.lock_file))

            # Start writing the model checkpoints to S3
            checkpoint = self._get_current_checkpoint()
            if checkpoint:
                checkpoint_number = self._get_checkpoint_number(checkpoint)

            checkpoint_file = None
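            # Upload every file belonging to the current checkpoint; the metadata file is held back and uploaded last.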
            num_files_uploaded = 0
            for root, _, files in os.walk(self.params.checkpoint_dir):
                for filename in files:
                    # Skip the checkpoint file that has the latest checkpoint number
                    if filename == CHECKPOINT_METADATA_FILENAME:
                        checkpoint_file = (root, filename)
                        continue

                    if not filename.startswith(str(checkpoint_number)):
                        continue

                    # Upload all the other files from the checkpoint directory
                    abs_name = os.path.abspath(os.path.join(root, filename))
                    rel_name = os.path.relpath(abs_name,
                                               self.params.checkpoint_dir)
                    s3_client.upload_file(Filename=abs_name,
                                          Bucket=self.params.bucket,
                                          Key=self._get_s3_key(rel_name))
                    num_files_uploaded += 1
            logger.info("Uploaded {} files for checkpoint {}".format(
                num_files_uploaded, checkpoint_number))

            # After all the checkpoint files have been uploaded, we upload the version file.
            abs_name = os.path.abspath(
                os.path.join(checkpoint_file[0], checkpoint_file[1]))
            rel_name = os.path.relpath(abs_name, self.params.checkpoint_dir)
            s3_client.upload_file(Filename=abs_name,
                                  Bucket=self.params.bucket,
                                  Key=self._get_s3_key(rel_name))

            # Release the lock by deleting the lock file from S3
            s3_client.delete_object(Bucket=self.params.bucket,
                                    Key=self._get_s3_key(
                                        self.params.lock_file))

            # Upload the frozen graph which is used for deployment
            if self.graph_manager:
                utils.write_frozen_graph(self.graph_manager)
                # upload the model_<ID>.pb to S3. NOTE: there's no cleanup as we don't know the best checkpoint
                iteration_id = self.graph_manager.level_managers[0].agents[
                    'agent'].training_iteration
                frozen_graph_fpath = utils.SM_MODEL_OUTPUT_DIR + "/model.pb"
                frozen_graph_s3_name = "model_%s.pb" % iteration_id
                s3_client.upload_file(
                    Filename=frozen_graph_fpath,
                    Bucket=self.params.bucket,
                    Key=self._get_s3_key(frozen_graph_s3_name))
                logger.info("saved intermediate frozen graph: {}".format(
                    self._get_s3_key(frozen_graph_s3_name)))

            # Clean up old checkpoints
            checkpoint = self._get_current_checkpoint()
            if checkpoint:
                checkpoint_number = self._get_checkpoint_number(checkpoint)
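                # Keep the most recent checkpoints; files from four checkpoints back are deleted below.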
                checkpoint_number_to_delete = checkpoint_number - 4

                # List all the old checkpoint files to be deleted
                response = s3_client.list_objects_v2(
                    Bucket=self.params.bucket,
                    Prefix=self._get_s3_key(
                        str(checkpoint_number_to_delete) + "_"))
                if "Contents" in response:
                    num_files = 0
                    for obj in response["Contents"]:
                        s3_client.delete_object(Bucket=self.params.bucket,
                                                Key=obj["Key"])
                        num_files += 1
        except botocore.exceptions.ClientError as e:
            utils.json_format_logger(
                "Unable to upload checkpoint to {}, {}".format(
                    self.params.bucket, e.response['Error']['Code']),
                **utils.build_user_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_400))
            utils.simapp_exit_gracefully()
        except Exception as e:
            utils.json_format_logger(
                "Unable to upload checkpoint to {}, {}".format(
                    self.params.bucket, e),
                **utils.build_system_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_500))
            utils.simapp_exit_gracefully()
    def save_to_store(self):
        try:
            s3_client = self._get_client()
            checkpoint_dir = self.params.checkpoint_dir

            # remove lock file if it exists
            s3_client.delete_object(Bucket=self.params.bucket, Key=self._get_s3_key(SyncFiles.LOCKFILE.value))

            # acquire lock
            s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                     Bucket=self.params.bucket,
                                     Key=self._get_s3_key(SyncFiles.LOCKFILE.value))

            state_file = CheckpointStateFile(os.path.abspath(checkpoint_dir))
            ckpt_state = None
            if state_file.exists():
                ckpt_state = state_file.read()
                checkpoint_file = None
                num_files_uploaded = 0
                for root, _, files in os.walk(checkpoint_dir):
                    for filename in files:
                        if filename == CheckpointStateFile.checkpoint_state_filename:
                            checkpoint_file = (root, filename)
                            continue
                        if filename.startswith(ckpt_state.name):
                            abs_name = os.path.abspath(os.path.join(root, filename))
                            rel_name = os.path.relpath(abs_name, checkpoint_dir)
                            s3_client.upload_file(Filename=abs_name,
                                                  Bucket=self.params.bucket,
                                                  Key=self._get_s3_key(rel_name))
                            num_files_uploaded += 1
                logger.info("Uploaded {} files for checkpoint {}".format(num_files_uploaded, ckpt_state.num))

                abs_name = os.path.abspath(os.path.join(checkpoint_file[0], checkpoint_file[1]))
                rel_name = os.path.relpath(abs_name, checkpoint_dir)
                s3_client.upload_file(Filename=abs_name,
                                      Bucket=self.params.bucket,
                                      Key=self._get_s3_key(rel_name))

            # upload Finished if present
            if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value)):
                s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                         Bucket=self.params.bucket,
                                         Key=self._get_s3_key(SyncFiles.FINISHED.value))

            # upload Ready if present
            if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value)):
                s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                         Bucket=self.params.bucket,
                                         Key=self._get_s3_key(SyncFiles.TRAINER_READY.value))

            # release lock
            s3_client.delete_object(Bucket=self.params.bucket, Key=self._get_s3_key(SyncFiles.LOCKFILE.value))

            # Upload the frozen graph which is used for deployment
            if self.graph_manager:
                self.write_frozen_graph(self.graph_manager)
                # upload the model_<ID>.pb to S3. NOTE: there's no cleanup as we don't know the best checkpoint
                for agent_params in self.graph_manager.agents_params:
                    iteration_id = self.graph_manager.level_managers[0].agents[agent_params.name].training_iteration
                    frozen_graph_fpath = os.path.join(SM_MODEL_OUTPUT_DIR, agent_params.name, "model.pb")
                    frozen_name = "model_{}.pb".format(iteration_id)
                    frozen_graph_s3_name = frozen_name if len(self.graph_manager.agents_params) == 1 \
                        else os.path.join(agent_params.name, frozen_name)
                    s3_client.upload_file(Filename=frozen_graph_fpath,
                                          Bucket=self.params.bucket,
                                          Key=self._get_s3_key(frozen_graph_s3_name))
                    logger.info("saved intermediate frozen graph: {}".format(self._get_s3_key(frozen_graph_s3_name)))

            # Clean up old checkpoints
            if ckpt_state:
                checkpoint_number_to_delete = ckpt_state.num - NUM_MODELS_TO_KEEP

                # List all the old checkpoint files to be deleted
                response = s3_client.list_objects_v2(Bucket=self.params.bucket,
                                                     Prefix=self._get_s3_key(""))
                if "Contents" in response:
                    for obj in response["Contents"]:
                        _, basename = os.path.split(obj["Key"])
                        if basename.startswith("{}_".format(checkpoint_number_to_delete)):
                            s3_client.delete_object(Bucket=self.params.bucket,
                                                    Key=obj["Key"])

        except botocore.exceptions.ClientError as e:
            utils.json_format_logger("Unable to upload checkpoint to {}, {}"
                                     .format(self.params.bucket, e.response['Error']['Code']),
                                     **utils.build_user_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                   utils.SIMAPP_EVENT_ERROR_CODE_400))
            utils.simapp_exit_gracefully()
        except Exception as e:
            utils.json_format_logger("Unable to upload checkpoint to {}, {}"
                                     .format(self.params.bucket, e),
                                     **utils.build_system_error_dict(utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                                                                     utils.SIMAPP_EVENT_ERROR_CODE_500))
            utils.simapp_exit_gracefully()
    def download_model(self, checkpoint_dir):
        s3_client = self.get_client()
        filename = "None"
        try:
            filename = os.path.abspath(
                os.path.join(checkpoint_dir, "checkpoint"))
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)

            while True:
                response = s3_client.list_objects_v2(Bucket=self.bucket,
                                                     Prefix=self._get_s3_key(
                                                         self.lock_file))

                if "Contents" not in response:
                    # If no lock is found, try getting the checkpoint
                    try:
                        s3_client.download_file(
                            Bucket=self.bucket,
                            Key=self._get_s3_key("checkpoint"),
                            Filename=filename)
                    except Exception as e:
                        time.sleep(2)
                        continue
                else:
                    time.sleep(2)
                    continue

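                # Parse the downloaded 'checkpoint' file (a text-format protobuf) to find the latest checkpoint's relative path.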
                ckpt = CheckpointState()
                if os.path.exists(filename):
                    with open(filename, 'r') as f:
                        text_format.Merge(f.read(), ckpt)
                    rel_path = ckpt.model_checkpoint_path
                    checkpoint = int(rel_path.split('_Step')[0])

                    response = s3_client.list_objects_v2(
                        Bucket=self.bucket, Prefix=self._get_s3_key(rel_path))
                    if "Contents" in response:
                        num_files = 0
                        for obj in response["Contents"]:
                            filename = os.path.abspath(
                                os.path.join(
                                    checkpoint_dir, obj["Key"].replace(
                                        self.model_checkpoints_prefix, "")))
                            s3_client.download_file(Bucket=self.bucket,
                                                    Key=obj["Key"],
                                                    Filename=filename)
                            num_files += 1
                        return
        except botocore.exceptions.ClientError as e:
            utils.json_format_logger(
                "Unable to download model {} from {}: {}".format(
                    filename, self.bucket, e.response['Error']['Code']),
                **utils.build_user_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_400))
            utils.simapp_exit_gracefully()
        except Exception as e:
            utils.json_format_logger(
                "Unable to download model {} from {}: {}".format(
                    filename, self.bucket, e),
                **utils.build_system_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_500))
            utils.simapp_exit_gracefully()
    def load_from_store(self, expected_checkpoint_number=-1):
        try:
            filename = os.path.abspath(
                os.path.join(self.params.checkpoint_dir,
                             CHECKPOINT_METADATA_FILENAME))
            if not os.path.exists(self.params.checkpoint_dir):
                os.makedirs(self.params.checkpoint_dir)

            # CT: remove all prior checkpoint files to save local storage space
            if len(os.listdir(self.params.checkpoint_dir)) > 0:
                files_to_remove = os.listdir(self.params.checkpoint_dir)
                utils.json_format_logger('Removing %d old checkpoint files' %
                                         len(files_to_remove))
                for f in files_to_remove:
                    try:
                        os.unlink(
                            os.path.abspath(
                                os.path.join(self.params.checkpoint_dir, f)))
                    except Exception as e:
                        # swallow errors
                        pass

            while True:
                s3_client = self._get_client()
                # Check if there's a finished file
                response = s3_client.list_objects_v2(
                    Bucket=self.params.bucket,
                    Prefix=self._get_s3_key(SyncFiles.FINISHED.value))
                if "Contents" in response:
                    finished_file_path = os.path.abspath(
                        os.path.join(self.params.checkpoint_dir,
                                     SyncFiles.FINISHED.value))
                    s3_client.download_file(Bucket=self.params.bucket,
                                            Key=self._get_s3_key(
                                                SyncFiles.FINISHED.value),
                                            Filename=finished_file_path)
                    return False

                # Check if there's a lock file
                response = s3_client.list_objects_v2(
                    Bucket=self.params.bucket,
                    Prefix=self._get_s3_key(self.params.lock_file))

                if "Contents" not in response:
                    try:
                        # If no lock is found, try getting the checkpoint
                        s3_client.download_file(
                            Bucket=self.params.bucket,
                            Key=self._get_s3_key(CHECKPOINT_METADATA_FILENAME),
                            Filename=filename)
                    except Exception as e:
                        utils.json_format_logger(
                            "Sleeping {} seconds while checkpoint metadata is unavailable".
                            format(
                                SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND
                            ))
                        time.sleep(
                            SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND
                        )
                        continue
                else:
                    utils.json_format_logger(
                        "Sleeping {} seconds while lock file is present".
                        format(
                            SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND
                        ))
                    time.sleep(
                        SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND
                    )
                    continue

                checkpoint = self._get_current_checkpoint()
                if checkpoint:
                    checkpoint_number = self._get_checkpoint_number(checkpoint)

                    # If we get a checkpoint that is older than the expected checkpoint,
                    # we wait for the new checkpoint to arrive.
                    if checkpoint_number < expected_checkpoint_number:
                        time.sleep(
                            SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND
                        )
                        continue

                    # Found a checkpoint to be downloaded
                    response = s3_client.list_objects_v2(
                        Bucket=self.params.bucket,
                        Prefix=self._get_s3_key(
                            checkpoint.model_checkpoint_path))
                    if "Contents" in response:
                        num_files = 0
                        for obj in response["Contents"]:
                            # Get the local filename of the checkpoint file
                            full_key_prefix = os.path.normpath(
                                self.key_prefix) + "/"
                            filename = os.path.abspath(
                                os.path.join(
                                    self.params.checkpoint_dir,
                                    obj["Key"].replace(full_key_prefix, "")))
                            s3_client.download_file(Bucket=self.params.bucket,
                                                    Key=obj["Key"],
                                                    Filename=filename)
                            num_files += 1
                        return True

        except botocore.exceptions.ClientError as e:
            utils.json_format_logger(
                "Unable to download checkpoint from {}, {}".format(
                    self.params.bucket, e.response['Error']['Code']),
                **utils.build_user_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_400))
            utils.simapp_exit_gracefully()
        except Exception as e:
            utils.json_format_logger(
                "Unable to download checkpoint from {}, {}".format(
                    self.params.bucket, e),
                **utils.build_system_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_500))
            utils.simapp_exit_gracefully()
    def infer_reward_state(self, steering_angle, speed):
        try:
            self.set_next_state()
        except Exception as ex:
            utils.json_format_logger(
                "Unable to retrieve image from queue: {}".format(ex),
                **utils.build_system_error_dict(
                    utils.SIMAPP_ENVIRONMENT_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_500))

        # Read model state from Gazebo
        model_state = self.get_model_state('racecar', '')
        model_orientation = Rotation.from_quat([
            model_state.pose.orientation.x, model_state.pose.orientation.y,
            model_state.pose.orientation.z, model_state.pose.orientation.w
        ])
        model_location = np.array([
            model_state.pose.position.x,
            model_state.pose.position.y,
            model_state.pose.position.z]) + \
            model_orientation.apply(RELATIVE_POSITION_OF_FRONT_OF_CAR)
        model_point = Point(model_location[0], model_location[1])
        model_heading = model_orientation.as_euler('zyx')[0]

        # Read the wheel locations from Gazebo
        left_rear_wheel_state = self.get_link_state('racecar::left_rear_wheel',
                                                    '')
        left_front_wheel_state = self.get_link_state(
            'racecar::left_front_wheel', '')
        right_rear_wheel_state = self.get_link_state(
            'racecar::right_rear_wheel', '')
        right_front_wheel_state = self.get_link_state(
            'racecar::right_front_wheel', '')
        wheel_points = [
            Point(left_rear_wheel_state.link_state.pose.position.x,
                  left_rear_wheel_state.link_state.pose.position.y),
            Point(left_front_wheel_state.link_state.pose.position.x,
                  left_front_wheel_state.link_state.pose.position.y),
            Point(right_rear_wheel_state.link_state.pose.position.x,
                  right_rear_wheel_state.link_state.pose.position.y),
            Point(right_front_wheel_state.link_state.pose.position.x,
                  right_front_wheel_state.link_state.pose.position.y)
        ]

        # Project the current location onto the center line and find nearest points
        current_ndist = self.center_line.project(model_point, normalized=True)
        prev_index, next_index = self.find_prev_next_waypoints(current_ndist)
        distance_from_prev = model_point.distance(
            Point(self.center_line.coords[prev_index]))
        distance_from_next = model_point.distance(
            Point(self.center_line.coords[next_index]))
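        # Tuple indexing by a boolean: selects next_index when the next waypoint is closer.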
        closest_waypoint_index = (
            prev_index, next_index)[distance_from_next < distance_from_prev]

        # Compute distance from center and road width
        nearest_point_center = self.center_line.interpolate(current_ndist,
                                                            normalized=True)
        nearest_point_inner = self.inner_border.interpolate(
            self.inner_border.project(nearest_point_center))
        nearest_point_outer = self.outer_border.interpolate(
            self.outer_border.project(nearest_point_center))
        distance_from_center = nearest_point_center.distance(model_point)
        distance_from_inner = nearest_point_inner.distance(model_point)
        distance_from_outer = nearest_point_outer.distance(model_point)
        track_width = nearest_point_inner.distance(nearest_point_outer)
        is_left_of_center = (distance_from_outer < distance_from_inner) if self.reverse_dir \
            else (distance_from_inner < distance_from_outer)

        # Convert current progress to be [0,100] starting at the initial waypoint
        if self.reverse_dir:
            current_progress = self.start_ndist - current_ndist
        else:
            current_progress = current_ndist - self.start_ndist
        if current_progress < 0.0: current_progress = current_progress + 1.0
        current_progress = 100 * current_progress
        if current_progress < self.prev_progress:
            # Either: (1) we wrapped around and have finished the track,
            delta1 = current_progress + 100 - self.prev_progress
            # or (2) for some reason the car went backwards (this should be rare)
            delta2 = self.prev_progress - current_progress
            current_progress = (self.prev_progress, 100)[delta1 < delta2]

        # Car is off track if all wheels are outside the borders
        wheel_on_track = [self.road_poly.contains(p) for p in wheel_points]
        all_wheels_on_track = all(wheel_on_track)
        any_wheels_on_track = any(wheel_on_track)

        # Compute the reward
        if any_wheels_on_track:
            done = False
            params = {
                'all_wheels_on_track': all_wheels_on_track,
                'x': model_point.x,
                'y': model_point.y,
                'heading': model_heading * 180.0 / math.pi,
                'distance_from_center': distance_from_center,
                'progress': current_progress,
                'steps': self.steps,
                'speed': speed,
                'steering_angle': steering_angle * 180.0 / math.pi,
                'track_width': track_width,
                'waypoints': list(self.center_line.coords),
                'closest_waypoints': [prev_index, next_index],
                'is_left_of_center': is_left_of_center,
                'is_reversed': self.reverse_dir
            }
            try:
                reward = float(self.reward_function(params))
            except Exception as e:
                utils.json_format_logger(
                    "Exception {} in customer reward function. Job failed!".
                    format(e),
                    **utils.build_user_error_dict(
                        utils.SIMAPP_SIMULATION_WORKER_EXCEPTION,
                        utils.SIMAPP_EVENT_ERROR_CODE_400))
                traceback.print_exc()
                sys.exit(1)
        else:
            done = True
            reward = CRASHED

        # Reset if the car position hasn't changed in the last 2 steps
        prev_pnt_dist = min(model_point.distance(self.prev_point),
                            model_point.distance(self.prev_point_2))

        if prev_pnt_dist <= 0.0001 and self.steps % NUM_STEPS_TO_CHECK_STUCK == 0:
            done = True
            reward = CRASHED  # stuck

        # Simulation jobs are done when progress reaches 100
        if current_progress >= 100:
            done = True

        # Keep data from the previous step around
        self.prev_point_2 = self.prev_point
        self.prev_point = model_point
        self.prev_progress = current_progress

        # Set the reward and done flag
        self.reward = reward
        self.reward_in_episode += reward
        self.done = done

        # Trace logs to help us debug and visualize the training runs
        # btown TODO: This should be written to S3, not to CWL.
        logger.info(
            'SIM_TRACE_LOG:%d,%d,%.4f,%.4f,%.4f,%.2f,%.2f,%d,%.4f,%s,%s,%.4f,%d,%.2f,%s\n'
            %
            (self.episodes, self.steps, model_location[0], model_location[1],
             model_heading, self.steering_angle, self.speed, self.action_taken,
             self.reward, self.done, all_wheels_on_track, current_progress,
             closest_waypoint_index, self.track_length, time.time()))

        metrics = {
            "episodes": self.episodes,
            "steps": self.steps,
            "x": model_location[0],
            "y": model_location[1],
            "heading": model_heading,
            "steering_angle": self.steering_angle,
            "speed": self.speed,
            "action": self.action_taken,
            "reward": self.reward,
            "done": self.done,
            "all_wheels_on_track": all_wheels_on_track,
            "current_progress": current_progress,
            "closest_waypoint_index": closest_waypoint_index,
            "track_length": self.track_length,
            "time": time.time()
        }
        self.log_to_datadog(metrics)

        # Terminate this episode when ready
        if done and node_type == SIMULATION_WORKER:
            self.finish_episode(current_progress)