コード例 #1
0
    def save_to_store(self):
        """Upload the local checkpoint directory to S3 and prune one old checkpoint.

        The upload is guarded by a lock object in S3. Any existing lock object
        is deleted, an empty lock object is written, the checkpoint files are
        uploaded, and the lock object is deleted again. The checkpoint metadata
        file (``CHECKPOINT_METADATA_FILENAME``, the file that records the latest
        checkpoint number) is uploaded only after every other file.

        After the upload, the S3 objects listed under the prefix
        ``<current checkpoint number - 4>_`` are deleted.

        Returns:
            True if old checkpoint objects were listed and deleted. None
            otherwise, including when there is no current checkpoint.

        Raises:
            FileNotFoundError: if no ``CHECKPOINT_METADATA_FILENAME`` file
                exists under ``self.params.checkpoint_dir``. This is checked
                before the lock is taken, so it never leaves a lock object
                behind in S3.
            Exceptions from the S3 client, ``utils.write_frozen_graph`` and the
            checkpoint helpers propagate unchanged.
        """
        s3_client = self._get_client()
        bucket = self.params.bucket
        checkpoint_dir = self.params.checkpoint_dir

        if self.graph_manager:
            # Presumably writes the frozen graph into checkpoint_dir so that the
            # walk below uploads it as well -- confirm against utils.
            utils.write_frozen_graph(self.graph_manager, checkpoint_dir)

        # Plan the upload before touching S3. A missing metadata file used to
        # surface as a TypeError only after all other files were uploaded,
        # and with the lock object still in place.
        file_paths = []
        metadata_path = None
        for root, _, files in os.walk(checkpoint_dir):
            for filename in files:
                abs_name = os.path.abspath(os.path.join(root, filename))
                if filename == CHECKPOINT_METADATA_FILENAME:
                    metadata_path = abs_name
                else:
                    file_paths.append(abs_name)
        if metadata_path is None:
            raise FileNotFoundError(
                "Checkpoint metadata file %s not found under %s"
                % (CHECKPOINT_METADATA_FILENAME, checkpoint_dir))

        lock_key = self._get_s3_key(self.params.lock_file)

        # Delete any existing lock file, then take the lock by writing an
        # empty lock file to the same location in S3.
        s3_client.delete_object(Bucket=bucket, Key=lock_key)
        s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                 Bucket=bucket,
                                 Key=lock_key)

        # Upload all checkpoint files first and the metadata (version) file
        # last, so the metadata never lands in S3 ahead of the files.
        for abs_name in file_paths + [metadata_path]:
            rel_name = os.path.relpath(abs_name, checkpoint_dir)
            s3_client.upload_file(Filename=abs_name,
                                  Bucket=bucket,
                                  Key=self._get_s3_key(rel_name))

        # Release the lock by deleting the lock file from S3.
        s3_client.delete_object(Bucket=bucket, Key=lock_key)

        # Prune the checkpoint four iterations older than the current one.
        checkpoint = self._get_current_checkpoint()
        if not checkpoint:
            return None
        checkpoint_number_to_delete = self._get_checkpoint_number(checkpoint) - 4
        response = s3_client.list_objects_v2(
            Bucket=bucket,
            Prefix=self._get_s3_key(str(checkpoint_number_to_delete) + "_"))
        if "Contents" not in response:
            return None

        old_objects = response["Contents"]
        for obj in old_objects:
            s3_client.delete_object(Bucket=bucket, Key=obj["Key"])
        print("Deleted %s model files from S3" % len(old_objects))
        # NOTE(review): True is returned only when something was pruned; kept
        # as-is because callers may depend on it.
        return True
コード例 #2
0
    def save_to_store(self):
        """Upload the current checkpoint, then the frozen graph, then prune S3.

        Steps:
          1. Resolve the current checkpoint number and collect the local files
             named ``<number>_...`` plus the ``CHECKPOINT_METADATA_FILENAME``
             file from ``self.params.checkpoint_dir``.
          2. Under an S3 lock object, upload the checkpoint files, then the
             metadata file last, then release the lock.
          3. If a graph manager is attached, write the frozen graph and upload
             ``utils.SM_MODEL_OUTPUT_DIR + "/model.pb"`` as
             ``model_<training_iteration>.pb``. This upload is never cleaned up.
          4. Delete the S3 objects listed under the prefix
             ``<current checkpoint number - 4>_``.

        Errors are not propagated. A ``botocore.exceptions.ClientError`` is
        logged as a user error (400), and any other exception as a system
        error (500). Both cases then call ``utils.simapp_exit_gracefully()``.
        A missing checkpoint or metadata file is reported through the 500
        path, before any lock is taken.
        """
        try:
            s3_client = self._get_client()
            bucket = self.params.bucket
            checkpoint_dir = self.params.checkpoint_dir

            # Resolve the checkpoint before locking. Previously a falsy result
            # left checkpoint_number unbound (NameError) with the lock held.
            checkpoint = self._get_current_checkpoint()
            if not checkpoint:
                raise ValueError(
                    "No current checkpoint to upload from {}".format(
                        checkpoint_dir))
            checkpoint_number = self._get_checkpoint_number(checkpoint)
            # Checkpoint files are named "<number>_..." (the cleanup below uses
            # the same prefix). Matching the bare number would make
            # checkpoint 1 also upload the files of checkpoints 10, 11, ...
            checkpoint_prefix = str(checkpoint_number) + "_"

            file_paths = []
            metadata_path = None
            for root, _, files in os.walk(checkpoint_dir):
                for filename in files:
                    if filename == CHECKPOINT_METADATA_FILENAME:
                        metadata_path = os.path.abspath(
                            os.path.join(root, filename))
                    elif filename.startswith(checkpoint_prefix):
                        file_paths.append(
                            os.path.abspath(os.path.join(root, filename)))
            if metadata_path is None:
                raise FileNotFoundError(
                    "Checkpoint metadata file {} not found under {}".format(
                        CHECKPOINT_METADATA_FILENAME, checkpoint_dir))

            def upload_checkpoint_file(abs_name):
                # The S3 key mirrors the path relative to checkpoint_dir.
                rel_name = os.path.relpath(abs_name, checkpoint_dir)
                s3_client.upload_file(Filename=abs_name,
                                      Bucket=bucket,
                                      Key=self._get_s3_key(rel_name))

            lock_key = self._get_s3_key(self.params.lock_file)

            # Delete any existing lock file, then take the lock by writing an
            # empty lock file to the same location in S3.
            s3_client.delete_object(Bucket=bucket, Key=lock_key)
            s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                     Bucket=bucket,
                                     Key=lock_key)

            for abs_name in file_paths:
                upload_checkpoint_file(abs_name)
            # Count across all walked directories, not just the last one.
            logger.info("Uploaded {} files for checkpoint {}".format(
                len(file_paths), checkpoint_number))

            # After all the checkpoint files have been uploaded, upload the
            # version file.
            upload_checkpoint_file(metadata_path)

            # Release the lock by deleting the lock file from S3.
            s3_client.delete_object(Bucket=bucket, Key=lock_key)

            # Upload the frozen graph which is used for deployment.
            if self.graph_manager:
                utils.write_frozen_graph(self.graph_manager)
                # NOTE: there's no cleanup of model_<ID>.pb, as we don't know
                # the best checkpoint.
                iteration_id = self.graph_manager.level_managers[0].agents[
                    'agent'].training_iteration
                frozen_graph_fpath = utils.SM_MODEL_OUTPUT_DIR + "/model.pb"
                frozen_graph_key = self._get_s3_key(
                    "model_%s.pb" % iteration_id)
                s3_client.upload_file(Filename=frozen_graph_fpath,
                                      Bucket=bucket,
                                      Key=frozen_graph_key)
                logger.info("saved intermediate frozen graph: {}".format(
                    frozen_graph_key))

            # Clean up the checkpoint four iterations older than the current
            # one. The checkpoint is re-read here, as before.
            checkpoint = self._get_current_checkpoint()
            if checkpoint:
                checkpoint_number_to_delete = (
                    self._get_checkpoint_number(checkpoint) - 4)
                response = s3_client.list_objects_v2(
                    Bucket=bucket,
                    Prefix=self._get_s3_key(
                        str(checkpoint_number_to_delete) + "_"))
                if "Contents" in response:
                    old_objects = response["Contents"]
                    for obj in old_objects:
                        s3_client.delete_object(Bucket=bucket,
                                                Key=obj["Key"])
                    logger.info("Deleted {} old model files from S3".format(
                        len(old_objects)))
        except botocore.exceptions.ClientError as e:
            utils.json_format_logger(
                "Unable to upload checkpoint to {}, {}".format(
                    self.params.bucket, e.response['Error']['Code']),
                **utils.build_user_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_400))
            utils.simapp_exit_gracefully()
        except Exception as e:
            utils.json_format_logger(
                "Unable to upload checkpoint to {}, {}".format(
                    self.params.bucket, e),
                **utils.build_system_error_dict(
                    utils.SIMAPP_S3_DATA_STORE_EXCEPTION,
                    utils.SIMAPP_EVENT_ERROR_CODE_500))
            utils.simapp_exit_gracefully()
コード例 #3
0
    def save_to_store(self):
        """Upload the current checkpoint, then the frozen graph, then prune S3.

        Steps:
          1. Resolve the current checkpoint number and collect the local files
             named ``<number>_...`` plus the ``CHECKPOINT_METADATA_FILENAME``
             file from ``self.params.checkpoint_dir``.
          2. Under an S3 lock object, upload the checkpoint files, then the
             metadata file last, then release the lock.
          3. If a graph manager is attached, write the frozen graph and upload
             ``utils.SM_MODEL_OUTPUT_DIR + "/model.pb"`` as
             ``model_<training_iteration>.pb``. This upload is never cleaned up.
          4. Delete the S3 objects listed under the prefix
             ``<current checkpoint number - 4>_``.

        Raises:
            ValueError: if there is no current checkpoint. Raised before the
                lock is taken.
            FileNotFoundError: if no ``CHECKPOINT_METADATA_FILENAME`` file
                exists under the checkpoint directory. Raised before the lock
                is taken.
            Exceptions from the S3 client and helpers propagate unchanged.
        """
        s3_client = self._get_client()
        bucket = self.params.bucket
        checkpoint_dir = self.params.checkpoint_dir

        # Resolve the checkpoint before locking. Previously a falsy result left
        # checkpoint_number unbound (NameError) with the lock held.
        checkpoint = self._get_current_checkpoint()
        if not checkpoint:
            raise ValueError(
                "No current checkpoint to upload from %s" % checkpoint_dir)
        checkpoint_number = self._get_checkpoint_number(checkpoint)
        # Checkpoint files are named "<number>_..." (the cleanup below uses the
        # same prefix). Matching the bare number would make checkpoint 1 also
        # upload the files of checkpoints 10, 11, ...
        checkpoint_prefix = str(checkpoint_number) + "_"

        file_paths = []
        metadata_path = None
        for root, _, files in os.walk(checkpoint_dir):
            for filename in files:
                if filename == CHECKPOINT_METADATA_FILENAME:
                    metadata_path = os.path.abspath(os.path.join(root, filename))
                elif filename.startswith(checkpoint_prefix):
                    file_paths.append(os.path.abspath(os.path.join(root, filename)))
        if metadata_path is None:
            raise FileNotFoundError(
                "Checkpoint metadata file %s not found under %s"
                % (CHECKPOINT_METADATA_FILENAME, checkpoint_dir))

        def upload_checkpoint_file(abs_name):
            # The S3 key mirrors the path relative to checkpoint_dir.
            rel_name = os.path.relpath(abs_name, checkpoint_dir)
            s3_client.upload_file(Filename=abs_name,
                                  Bucket=bucket,
                                  Key=self._get_s3_key(rel_name))

        lock_key = self._get_s3_key(self.params.lock_file)

        # Delete any existing lock file, then take the lock by writing an empty
        # lock file to the same location in S3.
        s3_client.delete_object(Bucket=bucket, Key=lock_key)
        s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                 Bucket=bucket,
                                 Key=lock_key)

        for abs_name in file_paths:
            upload_checkpoint_file(abs_name)
        # Count across all walked directories, not just the last one.
        print("Uploaded %s files for checkpoint %s" % (len(file_paths), checkpoint_number))

        # After all the checkpoint files have been uploaded, upload the version file.
        upload_checkpoint_file(metadata_path)

        # Release the lock by deleting the lock file from S3.
        s3_client.delete_object(Bucket=bucket, Key=lock_key)

        # Upload the frozen graph which is used for deployment.
        if self.graph_manager:
            utils.write_frozen_graph(self.graph_manager)
            # NOTE: there's no cleanup of model_<ID>.pb, as we don't know the
            # best checkpoint.
            iteration_id = self.graph_manager.level_managers[0].agents['agent'].training_iteration
            frozen_graph_fpath = utils.SM_MODEL_OUTPUT_DIR + "/model.pb"
            frozen_graph_key = self._get_s3_key("model_%s.pb" % iteration_id)
            s3_client.upload_file(Filename=frozen_graph_fpath,
                                  Bucket=bucket,
                                  Key=frozen_graph_key)
            print("saved intermediate frozen graph: ", frozen_graph_key)

        print("Trying to clean up old checkpoints.")
        # Clean up the checkpoint four iterations older than the current one.
        # The checkpoint is re-read here, as before.
        checkpoint = self._get_current_checkpoint()
        if checkpoint:
            checkpoint_number_to_delete = self._get_checkpoint_number(checkpoint) - 4
            response = s3_client.list_objects_v2(
                Bucket=bucket,
                Prefix=self._get_s3_key(str(checkpoint_number_to_delete) + "_"))
            if "Contents" in response:
                old_objects = response["Contents"]
                for obj in old_objects:
                    s3_client.delete_object(Bucket=bucket, Key=obj["Key"])
                print("Deleted %s old model files from S3" % len(old_objects))
            else:
                print("Cleanup was not required.")
コード例 #4
0
    def save_to_store(self):
        """Upload the current checkpoint to S3, write the frozen graph, prune S3.

        Steps:
          1. Resolve the current checkpoint number and collect the local files
             named ``<number>_...`` plus the ``CHECKPOINT_METADATA_FILENAME``
             file from ``self.params.checkpoint_dir``.
          2. Under an S3 lock object, upload the checkpoint files, then the
             metadata file last, then release the lock.
          3. If a graph manager is attached, call ``utils.write_frozen_graph``.
             This variant does not upload the frozen graph itself.
          4. Delete the S3 objects listed under the prefix
             ``<current checkpoint number - 4>_``.

        Raises:
            ValueError: if there is no current checkpoint. Raised before the
                lock is taken.
            FileNotFoundError: if no ``CHECKPOINT_METADATA_FILENAME`` file
                exists under the checkpoint directory. Raised before the lock
                is taken.
            Exceptions from the S3 client and helpers propagate unchanged.
        """
        s3_client = self._get_client()
        bucket = self.params.bucket
        checkpoint_dir = self.params.checkpoint_dir

        # Resolve the checkpoint before locking. Previously a falsy result left
        # checkpoint_number unbound (NameError) with the lock held.
        checkpoint = self._get_current_checkpoint()
        if not checkpoint:
            raise ValueError(
                "No current checkpoint to upload from %s" % checkpoint_dir)
        checkpoint_number = self._get_checkpoint_number(checkpoint)
        # The "_" keeps checkpoint 1 from also matching 10_, 11_, ...
        checkpoint_prefix = str(checkpoint_number) + "_"

        file_paths = []
        metadata_path = None
        for root, _, files in os.walk(checkpoint_dir):
            for filename in files:
                if filename == CHECKPOINT_METADATA_FILENAME:
                    metadata_path = os.path.abspath(os.path.join(root, filename))
                elif filename.startswith(checkpoint_prefix):
                    file_paths.append(os.path.abspath(os.path.join(root, filename)))
        if metadata_path is None:
            raise FileNotFoundError(
                "Checkpoint metadata file %s not found under %s"
                % (CHECKPOINT_METADATA_FILENAME, checkpoint_dir))

        def upload_checkpoint_file(abs_name):
            # The S3 key mirrors the path relative to checkpoint_dir.
            rel_name = os.path.relpath(abs_name, checkpoint_dir)
            s3_client.upload_file(Filename=abs_name,
                                  Bucket=bucket,
                                  Key=self._get_s3_key(rel_name))

        lock_key = self._get_s3_key(self.params.lock_file)

        # Delete any existing lock file, then take the lock by writing an empty
        # lock file to the same location in S3.
        s3_client.delete_object(Bucket=bucket, Key=lock_key)
        s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                 Bucket=bucket,
                                 Key=lock_key)

        for abs_name in file_paths:
            upload_checkpoint_file(abs_name)
        # Count across all walked directories, not just the last one.
        print("Uploaded %s files for checkpoint %s" % (len(file_paths), checkpoint_number))

        # After all the checkpoint files have been uploaded, upload the version file.
        upload_checkpoint_file(metadata_path)

        # Release the lock by deleting the lock file from S3.
        s3_client.delete_object(Bucket=bucket, Key=lock_key)

        # Write (not upload) the frozen graph used for deployment. Where it is
        # written is decided by utils.write_frozen_graph.
        if self.graph_manager:
            utils.write_frozen_graph(self.graph_manager)

        print("Trying to clean up old checkpoints.")
        # Clean up the checkpoint four iterations older than the current one.
        # The checkpoint is re-read here, as before.
        checkpoint = self._get_current_checkpoint()
        if checkpoint:
            checkpoint_number_to_delete = self._get_checkpoint_number(checkpoint) - 4
            response = s3_client.list_objects_v2(
                Bucket=bucket,
                Prefix=self._get_s3_key(str(checkpoint_number_to_delete) + "_"))
            if "Contents" in response:
                old_objects = response["Contents"]
                for obj in old_objects:
                    s3_client.delete_object(Bucket=bucket, Key=obj["Key"])
                print("Deleted %s old model files from S3" % len(old_objects))
            else:
                print("Cleanup was not required.")