def test_rnn_build_return_tensor():
    """Building an RNNBlock on a (32, 10) per-sample input yields one tf.Tensor."""
    rnn_block = blocks.RNNBlock()

    built = rnn_block.build(
        kerastuner.HyperParameters(),
        tf.keras.Input(shape=(32, 10), dtype=tf.float32),
    )
    # Flatten once and reuse; build() may return a nested structure.
    flat_outputs = nest.flatten(built)

    assert len(flat_outputs) == 1
    assert isinstance(flat_outputs[0], tf.Tensor)
def test_rnn_input_shape_one_dim_error():
    """RNNBlock.build rejects an input whose per-sample shape is one-dimensional."""
    rnn_block = blocks.RNNBlock()

    with pytest.raises(ValueError) as exc_info:
        hyper_params = kerastuner.HyperParameters()
        one_dim_input = tf.keras.Input(shape=(32,), dtype=tf.float32)
        rnn_block.build(hyper_params, one_dim_input)

    error_text = str(exc_info.value)
    assert "Expect the input tensor of RNNBlock" in error_text
def test_rnn_deserialize_to_rnn():
    """An RNNBlock round-tripped through blocks.serialize/deserialize is still an RNNBlock."""
    restored = blocks.deserialize(blocks.serialize(blocks.RNNBlock()))

    assert isinstance(restored, blocks.RNNBlock)
def test_rnn_get_config_has_all_attributes():
    """get_config() has a key for every argument name test_utils reports for RNNBlock.__init__."""
    config_keys = blocks.RNNBlock().get_config().keys()
    # get_func_args must return a set-like object, since .issubset is called on it.
    init_arg_names = test_utils.get_func_args(blocks.RNNBlock.__init__)

    assert init_arg_names.issubset(config_keys)