def test_tensorflow_wrapper_keras_subclass_decorator_compile_args():
    """Check that `compile_args` given to `keras_subclass` survive a byte round trip.

    Also checks that wrapping a keras subclass that lacks the decorator
    raises `ValueError`.
    """
    import tensorflow as tf

    class UndecoratedModel(tf.keras.Model):
        def call(self, inputs):
            return inputs

    # A keras subclass without the decorator must be rejected by the wrapper.
    with pytest.raises(ValueError):
        TensorFlowWrapper(UndecoratedModel())

    example_inputs = numpy.array([0.0, 0.0])
    example_outputs = numpy.array([0.5])
    compile_overrides = {"loss": "binary_crossentropy"}

    @keras_subclass(
        "TestModel",
        X=example_inputs,
        Y=example_outputs,
        input_shape=(2,),
        compile_args=compile_overrides,
    )
    class TestModel(tf.keras.Model):
        def call(self, inputs):
            return inputs

    wrapped = TensorFlowWrapper(TestModel())

    # Serialize, then load the bytes back through the same wrapper.
    serialized = wrapped.to_bytes()
    restored = wrapped.from_bytes(serialized)

    # The custom loss must still be set on the underlying keras model.
    keras_model = restored.shims[0]._model
    assert keras_model.loss == "binary_crossentropy"
    assert isinstance(restored, Model)
def test_tensorflow_wrapper_keras_subclass_decorator_capture_args_kwargs(
    X, Y, input_size, n_classes, answer
):
    """Check that constructor args/kwargs of a decorated keras subclass are captured.

    The wrapped keras model is expected to expose the captured values as
    `eg_args` (with `.args` and `.kwargs`), wrapping must raise `ValueError`
    when an argument cannot be serialized, and a `to_bytes`/`from_bytes`
    round trip must succeed.
    """
    import tensorflow as tf

    @keras_subclass(
        "TestModel", X=numpy.array([0.0, 0.0]), Y=numpy.array([0.5]), input_shape=(2,)
    )
    class TestModel(tf.keras.Model):
        def __init__(self, custom=False, **kwargs):
            # `super()` already binds `self`; passing it again would hand the
            # keras base initializer a stray positional argument.
            super().__init__()
            # This is to force the model to pass the captured arguments
            # or fail.
            assert custom is True
            assert kwargs.get("other", None) is not None

        def call(self, inputs):
            return inputs

    # Can wrap a decorated keras subclass model
    model = TensorFlowWrapper(TestModel(True, other=1337))

    assert hasattr(model.shims[0]._model, "eg_args")
    args_kwargs = model.shims[0]._model.eg_args
    assert True in args_kwargs.args
    assert "other" in args_kwargs.kwargs

    # Raises an error if the args/kwargs is not serializable
    # (a dict that contains itself is used as the unserializable value).
    obj = {}
    obj["key"] = obj
    with pytest.raises(ValueError):
        TensorFlowWrapper(TestModel(True, other=obj))

    # Provides the same arguments when copying a captured model: if the
    # round trip rebuilt TestModel without them, the asserts in `__init__`
    # above would fail.
    model = model.from_bytes(model.to_bytes())
def test_tensorflow_wrapper_serialize_model_subclass(
    X, Y, input_size, n_classes, answer
):
    """Train a wrapped keras subclass model and check it survives serialization.

    A small two-layer dense model is trained until it predicts `answer` for
    `X`, then round-tripped through `to_bytes`/`from_bytes`; the restored
    model must give the same prediction.
    """
    import tensorflow as tf

    input_shape = (1, input_size)
    ops = get_current_ops()

    @keras_subclass(
        "foo.v1",
        X=ops.alloc2f(*input_shape),
        Y=to_categorical(ops.asarray1i([1]), n_classes=n_classes),
        input_shape=input_shape,
    )
    class CustomKerasModel(tf.keras.Model):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.in_dense = tf.keras.layers.Dense(
                12, name="in_dense", input_shape=input_shape
            )
            self.out_dense = tf.keras.layers.Dense(
                n_classes, name="out_dense", activation="softmax"
            )

        def call(self, inputs) -> tf.Tensor:
            x = self.in_dense(inputs)
            return self.out_dense(x)

    model = TensorFlowWrapper(CustomKerasModel())
    # Train the model to predict the right single answer
    optimizer = Adam()
    for _ in range(50):
        guesses, backprop = model(X, is_train=True)
        # Gradient of the loss w.r.t. the guesses, averaged over the batch.
        d_guesses = (guesses - Y) / guesses.shape[0]
        backprop(d_guesses)
        model.finish_update(optimizer)
    predicted = model.predict(X).argmax()
    assert predicted == answer

    # Save then load the model from bytes. Rebind to the returned model, as
    # the sibling tests do, so the check below runs on the loaded model.
    model = model.from_bytes(model.to_bytes())

    # The from_bytes model gets the same answer
    assert model.predict(X).argmax() == answer