Example 1
0
    def test_works_with_registered(self):
        """A type registered via `register_symbolic_tensor_type` is classified
        like any other tensor: never symbolic in eager mode, always symbolic
        in graph mode."""

        # A user type that converts to a tensor through a registered
        # conversion function.
        class CustomClass(object):
            def value(self):
                return tf.convert_to_tensor(42.)

        tf.register_tensor_conversion_function(
            CustomClass, lambda value, **_: value.value())

        tf_utils.register_symbolic_tensor_type(CustomClass)

        # Pick the assertion once: eager values are never symbolic,
        # graph-mode values always are.
        if tf.executing_eagerly():
            check = self.assertFalse
        else:
            check = self.assertTrue

        candidates = [
            tf.Variable(name='blah', initial_value=0.),
            tf.convert_to_tensor(0.),
            tf.SparseTensor(indices=[[0, 0], [1, 2]],
                            values=[1, 2],
                            dense_shape=[3, 4]),
            CustomClass(),
        ]
        for candidate in candidates:
            check(tf_utils.is_symbolic_tensor(candidate))
Example 2
0
    def test_enables_nontensor_plumbing(self):
        """A non-Tensor type registered as symbolic can flow through a Keras
        model end-to-end (graph mode only)."""
        if tf.executing_eagerly():
            self.skipTest('`compile` functionality changed.')

        # Setup: a wrapper type that carries a tensor payload and is
        # convertible to a tensor via a registered conversion function.
        class Foo(object):
            def __init__(self, input_):
                self._input = input_
                self.value = tf.convert_to_tensor([[42.]])

            @property
            def dtype(self):
                return self.value.dtype

        tf.register_tensor_conversion_function(
            Foo, lambda x, *args, **kwargs: x.value)
        tf_utils.register_symbolic_tensor_type(Foo)

        class PlumbingLayer(keras.layers.Lambda):
            def __init__(self, fn, **kwargs):
                # Wrap `fn` so its result also reports tensor-like shape
                # attributes, and keep the raw tensor alongside it.
                def _wrapped(*fn_args, **fn_kwargs):
                    wrapped = fn(*fn_args, **fn_kwargs)
                    as_tensor = tf.convert_to_tensor(wrapped)
                    wrapped.shape = as_tensor.shape
                    wrapped.get_shape = as_tensor.get_shape
                    return wrapped, as_tensor

                super().__init__(_wrapped, **kwargs)
                self._enter_dunder_call = False

            def __call__(self, inputs, *args, **kwargs):
                # Flag so `call` knows to hand back both pieces.
                self._enter_dunder_call = True
                wrapped, _ = super().__call__(inputs, *args, **kwargs)
                self._enter_dunder_call = False
                return wrapped

            def call(self, inputs, *args, **kwargs):
                wrapped, as_tensor = super().call(inputs, *args, **kwargs)
                return (wrapped, as_tensor) if self._enter_dunder_call \
                    else wrapped

        # User-land.
        model = keras.Sequential([
            keras.layers.InputLayer((1,)),
            PlumbingLayer(Foo),  # Produces a `Foo` object.
        ])
        # Compose the models to ensure Keras graph history is preserved.
        model = keras.Model(model.inputs, model(model.outputs))
        # Instantiating the model must yield a `Foo`, not a `Tensor`.
        output = model(tf.convert_to_tensor([[7.]]))
        self.assertIsInstance(output, Foo)

        # A (custom) loss must also see the `Foo` instance, not a Tensor.
        captured_prediction = [None]

        def custom_loss(y_obs, y_pred):
            del y_obs
            captured_prediction[0] = y_pred
            return y_pred

        # `compile` evaluates the loss function enough to trigger the
        # capture side-effect above.
        model.compile('SGD', loss=custom_loss)
        self.assertIsInstance(captured_prediction[0], Foo)