def test_embed_get_config_has_all_attributes():
    """Every argument of ``Embedding.__init__`` must appear in ``get_config()``."""
    embedding_block = blocks.Embedding()
    block_config = embedding_block.get_config()
    init_arg_names = utils.get_func_args(blocks.Embedding.__init__)
    # Subset check: the config may carry extra keys, but must not miss any
    # constructor argument.
    assert init_arg_names.issubset(block_config.keys())
def test_embed_build_return_tensor():
    """Building an ``Embedding`` block on a 32-wide input yields a single output."""
    embedding_block = blocks.Embedding()
    hyperparameters = keras_tuner.HyperParameters()
    input_node = tf.keras.Input(shape=(32,), dtype=tf.float32)
    built = embedding_block.build(hyperparameters, input_node)
    # nest.flatten normalizes a bare tensor or a nested structure to a list,
    # so the length check works for either return form.
    assert len(nest.flatten(built)) == 1
def test_embed_deserialize_to_embed():
    """A serialize/deserialize round trip must reconstruct an ``Embedding``."""
    serialized = blocks.serialize(blocks.Embedding())
    restored = blocks.deserialize(serialized)
    assert isinstance(restored, blocks.Embedding)