예제 #1
0
# limitations under the License.
#

from core.model import ModelWrapper
from maxfw.core import MAX_API, PredictAPI
from config import ERR_MSG

from flask import send_file, abort
from werkzeug.datastructures import FileStorage

import io
import numpy as np
import base64

# Request parser for the predict endpoint: one mandatory 'image' argument,
# read from the request's uploaded files (location='files') as a FileStorage.
input_parser = MAX_API.parser()
input_parser.add_argument(
    'image',
    required=True,
    location='files',
    type=FileStorage,
    help="Black and white JPEG or PNG image to colorize",
)


class ModelPredictAPI(PredictAPI):
    """Predict resource whose POST input is described by ``input_parser``."""

    # Class attribute: ModelWrapper() is constructed once, when this class body
    # executes, and the same wrapper is shared by every instance/request.
    model_wrapper = ModelWrapper()

    # Swagger documentation: operation id 'predict', request body described by
    # input_parser (an uploaded 'image' file).
    @MAX_API.doc('predict')
    @MAX_API.expect(input_parser)
    def post(self):
        """Make a prediction given input data"""
        # NOTE(review): the body ends right after the docstring, so as written
        # post() returns None and never touches model_wrapper or the parsed
        # input; the implementation appears to have been truncated — confirm.
예제 #2
0
from core.model import ModelWrapper
from flask_restplus import fields
from werkzeug.datastructures import FileStorage
from maxfw.core import MAX_API, PredictAPI

# Request parser for the predict endpoint
# (see http://flask-restplus.readthedocs.io/en/stable/parsing.html):
# a required 'image' argument taken from the uploaded files as a FileStorage.
input_parser = MAX_API.parser()
input_parser.add_argument(
    'image',
    location='files',
    type=FileStorage,
    required=True,
    help="An image file (RGB/HWC)",
)


# Schema for a single predicted class: an optional identifier, the label text
# and the probability assigned to it.
label_prediction = MAX_API.model(
    'LabelPrediction',
    {
        'label_id': fields.String(
            required=False,
            description='Class label identifier',
        ),
        'label': fields.String(
            required=True,
            description='Class label',
        ),
        'probability': fields.Float(
            required=True,
            description='Predicted probability for the class label',
        ),
    },
)


# Response envelope for the predict endpoint: a status message plus a list of
# nested LabelPrediction records.
predict_response = MAX_API.model(
    'ModelPredictResponse',
    {
        'status': fields.String(
            required=True,
            description='Response status message',
        ),
        'predictions': fields.List(
            fields.Nested(label_prediction),
            description='Predicted class labels and probabilities',
        ),
    },
)


class ModelPredictAPI(PredictAPI):

    model_wrapper = ModelWrapper()

    @MAX_API.doc('predict')
    @MAX_API.expect(input_parser)
    @MAX_API.marshal_with(predict_response)
    def post(self):
예제 #3
0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from maxfw.core import MAX_API, PredictAPI
from core.model import ModelWrapper

from flask import abort
from flask_restplus import fields
from werkzeug.datastructures import FileStorage

# Parser for image input: a required 'image' argument read from the request's
# uploaded files (location='files') as a FileStorage.
# NOTE(review): nothing visible here references image_parser (the predict
# resource expects video_parser) — confirm it is still needed.
image_parser = MAX_API.parser()
image_parser.add_argument(
    'image',
    help="An image file",
    required=True,
    location='files',
    type=FileStorage,
)

# Schema for a single generated caption: its rank index (optional), the
# caption text and the caption's probability.
label_prediction = MAX_API.model(
    'LabelPrediction', {
        'index':
        fields.String(required=False,
                      description='Labels ranked by highest probability'),
        'caption':
        fields.String(required=True, description='Caption generated by image'),
        'probability':
        fields.Float(required=True, description="Probability of the caption"),
    })

# Response envelope returned by the predict endpoint: a status message and a
# list of nested LabelPrediction records.
predict_response = MAX_API.model(
    'ModelPredictResponse',
    {
        'status': fields.String(
            required=True,
            description='Response status message',
        ),
        'predictions': fields.List(
            fields.Nested(label_prediction),
            description='Predicted labels and probabilities',
        ),
    },
)

# Set up parser for video input data: a required 'video' argument read from
# the request's uploaded files (location='files') as a FileStorage.
video_parser = MAX_API.parser()
video_parser.add_argument('video',
                          type=FileStorage,
                          location='files',
                          required=True,
                          help="MPEG-4 video file to run predictions on")


@MAX_API.route('/predict')
class ModelPredictAPI(PredictAPI):

    model_wrapper = ModelWrapper()

    @MAX_API.doc('predict')
    @MAX_API.expect(video_parser)
    @MAX_API.marshal_with(predict_response)