def test_temporal_global_max_return_tensor():
    """Build a ``global_max`` TemporalReduction on a (32, 10) input.

    The built block must produce exactly one output, and that output must
    be a ``tf.Tensor``.
    """
    block = blocks.TemporalReduction(reduction_type="global_max")

    # Use ``keras_tuner``, the current package name, to match the sibling
    # tests. ``kerastuner`` is the deprecated alias and may not be
    # imported here.
    outputs = block.build(
        keras_tuner.HyperParameters(),
        keras.Input(shape=(32, 10), dtype=tf.float32),
    )

    flat_outputs = nest.flatten(outputs)
    assert len(flat_outputs) == 1
    assert isinstance(flat_outputs[0], tf.Tensor)
def test_reduction_2d_tensor_return_input_node():
    """A default TemporalReduction on a ``shape=(32,)`` input is a no-op.

    Building the block must return a single output, and that output must
    be the very same input node object, not a copy.
    """
    reduction = blocks.TemporalReduction()
    node = keras.Input(shape=(32,), dtype=tf.float32)

    result = reduction.build(keras_tuner.HyperParameters(), node)
    flattened = nest.flatten(result)

    assert len(flattened) == 1
    # Identity check: the block must hand back the input node itself.
    assert flattened[0] is node
def test_temporal_deserialize_to_temporal():
    """Round-trip a TemporalReduction through serialize/deserialize.

    The deserialized object must be an instance of TemporalReduction.
    """
    restored = blocks.deserialize(
        blocks.serialize(blocks.TemporalReduction())
    )
    assert isinstance(restored, blocks.TemporalReduction)