Exemplo n.º 1
0
def test_rnn_build_return_tensor():
    """Check that RNNBlock.build produces a single tf.Tensor output.

    The block is built with fresh hyperparameters on a float32 Keras input
    declared with shape (32, 10); the flattened result must contain exactly
    one element, and that element must be a tf.Tensor.
    """
    rnn_block = blocks.RNNBlock()
    hyperparameters = kerastuner.HyperParameters()
    input_node = tf.keras.Input(shape=(32, 10), dtype=tf.float32)

    # Flatten once so both assertions inspect the same list.
    flat_outputs = nest.flatten(rnn_block.build(hyperparameters, input_node))

    assert len(flat_outputs) == 1
    assert isinstance(flat_outputs[0], tf.Tensor)
Exemplo n.º 2
0
def test_rnn_input_shape_one_dim_error():
    """Check that building on an input declared with shape (32,) raises ValueError.

    The raised error's message must contain the fragment
    "Expect the input tensor of RNNBlock".
    """
    expected_fragment = "Expect the input tensor of RNNBlock"
    rnn_block = blocks.RNNBlock()

    # Only the build call sits inside the context manager, so a ValueError
    # raised while constructing the block would not satisfy the test.
    with pytest.raises(ValueError) as raised:
        hyperparameters = kerastuner.HyperParameters()
        input_node = tf.keras.Input(shape=(32, ), dtype=tf.float32)
        rnn_block.build(hyperparameters, input_node)

    error_message = str(raised.value)
    assert expected_fragment in error_message
Exemplo n.º 3
0
def test_rnn_deserialize_to_rnn():
    """Check that a serialize/deserialize round trip yields an RNNBlock instance."""
    # Round trip: default RNNBlock -> blocks.serialize -> blocks.deserialize.
    restored = blocks.deserialize(blocks.serialize(blocks.RNNBlock()))

    assert isinstance(restored, blocks.RNNBlock)
Exemplo n.º 4
0
def test_rnn_get_config_has_all_attributes():
    """Check that RNNBlock.get_config() covers the arguments of RNNBlock.__init__.

    Every name returned by test_utils.get_func_args for RNNBlock.__init__
    must appear among the keys of the config of a default-constructed block.
    """
    config_keys = blocks.RNNBlock().get_config().keys()
    init_arg_names = test_utils.get_func_args(blocks.RNNBlock.__init__)

    assert init_arg_names.issubset(config_keys)