def test_not_keras(self, model):
    """Check that ``must_keras`` raises ``TypeError`` with the expected message for ``model``."""
    # NOTE(review): ``model`` is presumably a (parametrized) non-Keras fixture — confirm.
    expected_error = r"^model must be either a Keras Sequential or Functional model.*"
    with pytest.raises(TypeError, match=expected_error):
        model_validator.must_keras(model)
def create_standard_model_from_keras(
    self,
    obj,
    environment,
    model_api=None,
    name=None,
    desc=None,
    labels=None,
    attrs=None,
    lock_level=None,
):
    """Create a Standard Verta Model version from a TensorFlow-backend Keras model.

    .. versionadded:: 0.18.2

    Parameters
    ----------
    obj : `tf.keras.Sequential <https://keras.io/guides/sequential_model/>`__ or `functional API keras.Model <https://keras.io/guides/functional_api/>`__
        Keras model.
    environment : :class:`~verta.environment.Python`
        pip and apt dependencies.
    model_api : :class:`~verta.utils.ModelAPI`, optional
        Model API specifying the model's expected input and output.
    name : str, optional
        Name of the model version. If no name is provided, one will be generated.
    desc : str, optional
        Description of the model version.
    labels : list of str, optional
        Labels of the model version.
    attrs : dict of str to {None, bool, float, int, str}, optional
        Attributes of the model version.
    lock_level : :mod:`~verta.registry.lock`, default :class:`~verta.registry.lock.Open`
        Lock level to set when creating this model version.

    Returns
    -------
    :class:`~verta.registry.entities.RegisteredModelVersion`

    Raises
    ------
    TypeError
        If `obj` is not a Keras Sequential or Functional model.

    Examples
    --------
    .. code-block:: python

        from tensorflow import keras
        from verta.environment import Python

        inputs = keras.Input(shape=(3,))
        x = keras.layers.Dense(2, activation="relu")(inputs)
        outputs = keras.layers.Dense(1, activation="sigmoid")(x)
        model = keras.Model(inputs=inputs, outputs=outputs)
        train(model, data)

        model_ver = reg_model.create_standard_model_from_keras(
            model, Python(["tensorflow"]),
        )
        endpoint.update(model_ver, wait=True)
        endpoint.get_deployed_model().predict(input)

    """
    # Validate up front so a non-Keras object is rejected before any
    # model version is created.
    model_validator.must_keras(obj)

    # All remaining work (upload, metadata, lock level) is delegated to the
    # shared spec-based constructor.
    model_version = self._create_standard_model_from_spec(
        model=obj,
        environment=environment,
        model_api=model_api,
        name=name,
        desc=desc,
        labels=labels,
        attrs=attrs,
        lock_level=lock_level,
    )
    return model_version
def test_keras(self, model):
    """Check that ``must_keras`` accepts ``model`` and returns a truthy value."""
    # NOTE(review): ``model`` is presumably a (parametrized) Keras model fixture — confirm.
    is_valid = model_validator.must_keras(model)
    assert is_valid