def patch(self, model_id, prediction_id):
        """
        Allow updating of links in model
        ---
        produces:
            - application/json
        parameters:
            - in: path
              name: model_id
              description: ID of the Model
              required: true
              type: integer
            - in: path
              name: prediction_id
              description: ID of the Prediction
              required: true
              type: integer
        responses:
            200:
                description: Prediction updated successfully
            404:
                description: Prediction not found to update
            500:
                description: Internal Server Error
        """
        try:
            updated_prediction = request.get_json()

            if updated_prediction is None:
                return err(400, "prediction must be json object"), 400

            prediction_id = PredictionService.patch(prediction_id, updated_prediction)

            return {
                "model_id": model_id,
                "prediction_id": prediction_id
            }, 200
        except NotFound:
            return err(404, "prediction not found"), 404
        except Exception as e:
            error_msg = f'Unhandled error: {str(e)}'
            current_app.logger.error(error_msg)
            return err(500, error_msg), 500
    def post(self, model_id, prediction_id):
        """
        Attach a raw model to a given predition
        ---
        produces:
            - application/json
        responses:
            200:
                description: ID of the prediction
            400:
                description: Invalid Request
            500:
                description: Internal Server Error
        """

        if CONFIG.EnvironmentConfig.ENVIRONMENT != "aws":
            return err(501, "stack must be in 'aws' mode to use this endpoint"), 501

        if CONFIG.EnvironmentConfig.ASSET_BUCKET is None:
            return err(501, "Not Configured"), 501

        modeltype = request.args.get('type', 'model')
        if modeltype not in ["model", "tfrecord", "checkpoint"]:
            return err(400, "Unsupported type param"), 400

        key = "models/{0}/prediction/{1}/{2}.zip".format(
            model_id,
            prediction_id,
            modeltype
        )

        try:
            boto3.client('s3').head_object(
                Bucket=CONFIG.EnvironmentConfig.ASSET_BUCKET,
                Key=key
            )
        except:
            files = list(request.files.keys())
            if len(files) == 0:
                return err(400, "Model not found in request"), 400

            model = request.files[files[0]]

            # Save the model to S3
            try:
                boto3.resource('s3').Bucket(CONFIG.EnvironmentConfig.ASSET_BUCKET).put_object(
                    Key=key,
                    Body=model.stream
                )
            except Exception as e:
                error_msg = f'S3 Upload Error: {str(e)}'
                current_app.logger.error(error_msg)
                return err(500, "Failed to upload model to S3"), 500

            if modeltype == "checkpoint":
                try:
                    PredictionService.patch(prediction_id, {
                        "checkpointLink": CONFIG.EnvironmentConfig.ASSET_BUCKET + '/' + key
                    })
                except Exception as e:
                    error_msg = f'SaveLink Error: {str(e)}'
                    current_app.logger.error(error_msg)
                    return err(500, "Failed to save checkpoint state to DB"), 500

            if modeltype == "tfrecord":
                try:
                    PredictionService.patch(prediction_id, {
                        "tfrecordLink": CONFIG.EnvironmentConfig.ASSET_BUCKET + '/' + key
                    })
                except Exception as e:
                    error_msg = f'SaveLink Error: {str(e)}'
                    current_app.logger.error(error_msg)
                    return err(500, "Failed to save checkpoint state to DB"), 500

            if modeltype == "model":
                # Save the model link to ensure UI shows upload success
                try:
                    PredictionService.patch(prediction_id, {
                        "modelLink": CONFIG.EnvironmentConfig.ASSET_BUCKET + '/' + key
                    })
                except Exception as e:
                    error_msg = f'SaveLink Error: {str(e)}'
                    current_app.logger.error(error_msg)
                    return err(500, "Failed to save model state to DB"), 500

                try:
                    batch = boto3.client(
                        service_name='batch',
                        region_name='us-east-1',
                        endpoint_url='https://batch.us-east-1.amazonaws.com'
                    )

                    # Submit to AWS Batch to convert to ECR image
                    batch.submit_job(
                        jobName=CONFIG.EnvironmentConfig.STACK + 'ecr-build',
                        jobQueue=CONFIG.EnvironmentConfig.STACK + '-queue',
                        jobDefinition=CONFIG.EnvironmentConfig.STACK + '-job',
                        containerOverrides={
                            'environment': [{
                                'name': 'MODEL',
                                'value': CONFIG.EnvironmentConfig.ASSET_BUCKET + '/' + key
                            }]
                        }
                    )
                except Exception as e:
                    error_msg = f'Batch Error: {str(e)}'
                    current_app.logger.error(error_msg)
                    return err(500, "Failed to start ECR build"), 500

            return { "status": "model uploaded" }, 200
        else:
            return err(400, "model exists"), 400
Esempio n. 3
0
    def post(self, project_id, prediction_id):
        """
        Attach a raw model to a given predition
        ---
        produces:
            - application/json
        responses:
            200:
                description: ID of the prediction
            400:
                description: Invalid Request
            500:
                description: Internal Server Error
        """

        if CONFIG.EnvironmentConfig.ENVIRONMENT != "aws":
            return err(501, "stack must be in 'aws' mode to use this endpoint"), 501

        if CONFIG.EnvironmentConfig.ASSET_BUCKET is None:
            return err(501, "Not Configured"), 501

        modeltype = request.args.get("type", "model")
        if modeltype not in ["model", "tfrecord", "checkpoint"]:
            return err(400, "Unsupported type param"), 400

        key = "models/{0}/prediction/{1}/{2}.zip".format(
            project_id, prediction_id, modeltype
        )

        try:
            boto3.client("s3").head_object(
                Bucket=CONFIG.EnvironmentConfig.ASSET_BUCKET, Key=key
            )
        except Exception:
            files = list(request.files.keys())
            if len(files) == 0:
                return err(400, "Model not found in request"), 400

            model = request.files[files[0]]

            # Save the model to S3
            try:
                boto3.resource("s3").Bucket(
                    CONFIG.EnvironmentConfig.ASSET_BUCKET
                ).put_object(Key=key, Body=model.stream)
            except Exception:
                current_app.logger.error(traceback.format_exc())

                return err(500, "Failed to upload model to S3"), 500

            if modeltype == "checkpoint":
                try:
                    PredictionService.patch(
                        prediction_id,
                        {
                            "checkpointLink": CONFIG.EnvironmentConfig.ASSET_BUCKET
                            + "/"
                            + key
                        },
                    )
                except Exception:
                    current_app.logger.error(traceback.format_exc())

                    return err(500, "Failed to save checkpoint state to DB"), 500

            if modeltype == "tfrecord":
                try:
                    PredictionService.patch(
                        prediction_id,
                        {
                            "tfrecordLink": CONFIG.EnvironmentConfig.ASSET_BUCKET
                            + "/"
                            + key
                        },
                    )
                except Exception:
                    current_app.logger.error(traceback.format_exc())

                    return err(500, "Failed to save checkpoint state to DB"), 500

            if modeltype == "model":
                # Save the model link to ensure UI shows upload success
                try:
                    PredictionService.patch(
                        prediction_id,
                        {
                            "modelLink": CONFIG.EnvironmentConfig.ASSET_BUCKET
                            + "/"
                            + key
                        },
                    )
                except Exception:
                    current_app.logger.error(traceback.format_exc())

                    return err(500, "Failed to save model state to DB"), 500

                try:
                    batch = boto3.client(
                        service_name="batch",
                        region_name="us-east-1",
                        endpoint_url="https://batch.us-east-1.amazonaws.com",
                    )

                    # Submit to AWS Batch to convert to ECR image
                    job = batch.submit_job(
                        jobName=CONFIG.EnvironmentConfig.STACK + "ecr-build",
                        jobQueue=CONFIG.EnvironmentConfig.STACK + "-queue",
                        jobDefinition=CONFIG.EnvironmentConfig.STACK + "-build-job",
                        containerOverrides={
                            "environment": [
                                {
                                    "name": "MODEL",
                                    "value": CONFIG.EnvironmentConfig.ASSET_BUCKET
                                    + "/"
                                    + key,
                                }
                            ]
                        },
                    )

                    TaskService.create(
                        {
                            "pred_id": prediction_id,
                            "type": "ecr",
                            "batch_id": job.get("jobId"),
                        }
                    )
                except Exception:
                    current_app.logger.error(traceback.format_exc())

                    return err(500, "Failed to start ECR build"), 500

            return {"status": "model uploaded"}, 200
        else:
            return err(400, "asset exists"), 400