from bentoml import api, env, BentoService, artifacts
from bentoml.artifact import TensorflowSavedModelArtifact, PickleArtifact
from bentoml.handlers import JsonHandler

import numpy as np
from scipy.special import softmax, expit

from aispace.datasets.tokenizer import BertTokenizer
from aispace.utils.hparams import Hparams

# `logging` was referenced here without ever being imported, which made the
# module fail with NameError on import; bring it into scope explicitly.
import logging

# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)


@artifacts([
    TensorflowSavedModelArtifact('model'),
    PickleArtifact('tokenizer'),
    PickleArtifact("hparams"),
])
@env(auto_pip_dependencies=True)
class BertTextClassificationService(BentoService):
    def preprocessing(self, one_json):
        texts = one_json.get("text", '')
        if isinstance(texts, (list, tuple)):
            if len(texts) >= 2:
                encode = self.artifacts.tokenizer.encode(texts[0], texts[1])
            elif len(texts) == 1:
                encode = self.artifacts.tokenizer.encode(texts[0])
            else:
                return None, None, None
        elif isinstance(texts, str):
import bentoml
import tensorflow as tf
import numpy as np
from PIL import Image

from bentoml.artifact import (
    TensorflowSavedModelArtifact,
)
from bentoml.handlers import TensorflowTensorHandler


# Human-readable label for each Fashion-MNIST class. The list position is the
# class id: predictions are mapped to names by indexing this list with the
# model's argmax output.
FASHION_MNIST_CLASSES = [
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
]


@bentoml.env(pip_dependencies=['tensorflow', 'numpy', 'pillow'])
@bentoml.artifacts([TensorflowSavedModelArtifact('model')])
class FashionMnistTensorflow(bentoml.BentoService):
    """BentoML service that labels Fashion-MNIST inputs using a TF SavedModel.

    The packed ``model`` artifact must expose a ``predict_image`` callable.
    Its output is reduced with ``argmax`` over axis 1 to obtain one class id
    per example.
    """

    @bentoml.api(TensorflowTensorHandler)
    def predict(self, inputs):
        """Return the predicted Fashion-MNIST class name for each example.

        Args:
            inputs: Batch delivered by ``TensorflowTensorHandler``. It is
                passed unchanged to ``model.predict_image``. Presumably a
                batch of 28x28 images; TODO confirm against the exported
                signature.

        Returns:
            list[str]: One label from ``FASHION_MNIST_CLASSES`` per example,
            in the order of the model output.
        """
        scores = self.artifacts.model.predict_image(inputs)
        # Highest-scoring class id for each example (axis 1 = class axis).
        class_ids = tf.math.argmax(scores, axis=1)
        # Copy the ids to host Python ints once. Iterating the EagerTensor
        # directly would run a TF slice op per example and would rely on
        # EagerTensor.__index__ to index a plain Python list.
        return [FASHION_MNIST_CLASSES[idx] for idx in class_ids.numpy().tolist()]