def test_cast_to_float32_return_float32_tensor(tmp_path):
    """Check that CastToFloat32 turns a string tensor into a float32 tensor.

    Only the dtype of the layer's output is asserted, not its values.
    The ``tmp_path`` fixture is requested but not used by this test.
    """
    cast_layer = layer_module.CastToFloat32()
    string_input = tf.constant(["0.3"], dtype=tf.string)

    output = cast_layer(string_input)

    assert output.dtype == tf.float32
Example #2
0
 def build(self, hp, inputs=None):
     """Apply a CastToFloat32 layer to the first of the given inputs.

     Args:
         hp: Not used in this method (presumably a hyperparameter
             container supplied by the caller; confirm at call site).
         inputs: A single node or a nested structure of nodes. It is
             flattened and only the first element is used.

     Returns:
         The output of ``keras_layers.CastToFloat32()`` applied to that
         first input node.
     """
     # Only the first flattened input is cast; any others are ignored.
     first_input = nest.flatten(inputs)[0]
     cast_layer = keras_layers.CastToFloat32()
     return cast_layer(first_input)