Example #1
0
def test_dense_block():
    """Run `utils.block_basic_exam` on a default `basic.DenseBlock`.

    The block is paired with a `tf.keras.Input` of shape (32,) and dtype
    float32.
    """
    dense_block = basic.DenseBlock()
    input_node = tf.keras.Input(shape=(32,), dtype=tf.float32)
    # Third helper argument; presumably the hyperparameter names the block
    # is expected to register -- confirm in utils.block_basic_exam.
    expected_names = ['num_layers', 'use_batchnorm']
    utils.block_basic_exam(dense_block, input_node, expected_names)
Example #2
0
def test_embedding_block():
    """Run `utils.block_basic_exam` on a default `basic.Embedding`.

    The block is paired with a `tf.keras.Input` of shape (32,) and dtype
    float32.
    """
    embedding_block = basic.Embedding()
    input_node = tf.keras.Input(shape=(32,), dtype=tf.float32)
    # Third helper argument; presumably the hyperparameter names the block
    # is expected to register -- confirm in utils.block_basic_exam.
    expected_names = ['pretraining', 'embedding_dim']
    utils.block_basic_exam(embedding_block, input_node, expected_names)
Example #3
0
def test_resnet_block(init, build):
    """Run `utils.block_basic_exam` on `basic.ResNetBlock` and check mocks.

    Args:
        init: Object whose `called` attribute must be truthy afterwards.
            Presumably a mock injected by a patch decorator that is not
            visible here -- confirm against the full test module.
        build: Same contract as `init`.
    """
    resnet_block = basic.ResNetBlock()
    image_input = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)
    # Third helper argument; presumably the hyperparameter names the block
    # is expected to register -- confirm in utils.block_basic_exam.
    expected_names = ['version', 'pooling']
    utils.block_basic_exam(resnet_block, image_input, expected_names)

    # Both injected objects must have been invoked during the exam.
    assert init.called
    assert build.called
def test_merge():
    """Run `utils.block_basic_exam` on `reduction.Merge` with two inputs.

    The two `tf.keras.Input` nodes deliberately differ in shape:
    (32,) and (4, 8), both float32.
    """
    merge_block = reduction.Merge()
    flat_input = tf.keras.Input(shape=(32, ), dtype=tf.float32)
    matrix_input = tf.keras.Input(shape=(4, 8), dtype=tf.float32)
    # Third helper argument; presumably the hyperparameter names the block
    # is expected to register -- confirm in utils.block_basic_exam.
    expected_names = ['merge_type']
    utils.block_basic_exam(
        merge_block,
        [flat_input, matrix_input],
        expected_names,
    )
Example #5
0
def test_rnn_block():
    """Run `utils.block_basic_exam` on a default `basic.RNNBlock`.

    The block is paired with a `tf.keras.Input` of shape (32, 10) and
    dtype float32.
    """
    rnn_block = basic.RNNBlock()
    sequence_input = tf.keras.Input(shape=(32, 10), dtype=tf.float32)
    # Third helper argument; presumably the hyperparameter names the block
    # is expected to register -- confirm in utils.block_basic_exam.
    expected_names = ['bidirectional', 'layer_type', 'num_layers']
    utils.block_basic_exam(rnn_block, sequence_input, expected_names)
Example #6
0
def test_categorical_to_numerical():
    """Run `utils.block_basic_exam` on `CategoricalToNumerical`.

    The block is configured with a single column 'a' typed 'num' and is
    paired with a string `tf.keras.Input` of shape (1,). An empty list is
    passed as the helper's third argument.
    """
    encoder = preprocessing.CategoricalToNumerical()
    # Column metadata is set directly on the block before the exam.
    encoder.column_names = ['a']
    encoder.column_types = {'a': 'num'}

    string_input = tf.keras.Input(shape=(1, ), dtype=tf.string)
    utils.block_basic_exam(encoder, string_input, [])
Example #7
0
def test_conv_block():
    """Run `utils.block_basic_exam` on a default `basic.ConvBlock`.

    The block is paired with a `tf.keras.Input` of shape (32, 32, 3) and
    dtype float32.
    """
    conv_block = basic.ConvBlock()
    image_input = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)
    # Third helper argument; presumably the hyperparameter names the block
    # is expected to register -- confirm in utils.block_basic_exam.
    expected_names = ['kernel_size', 'num_blocks', 'separable']
    utils.block_basic_exam(conv_block, image_input, expected_names)
Example #8
0
def test_xception_block():
    """Run `utils.block_basic_exam` on a default `basic.XceptionBlock`.

    The block is paired with a `tf.keras.Input` of shape (32, 32, 3) and
    dtype float32.
    """
    xception_block = basic.XceptionBlock()
    image_input = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)
    # Third helper argument; presumably the hyperparameter names the block
    # is expected to register -- confirm in utils.block_basic_exam.
    expected_names = [
        'activation',
        'initial_strides',
        'num_residual_blocks',
        'pooling',
    ]
    utils.block_basic_exam(xception_block, image_input, expected_names)
Example #9
0
def test_image_block():
    """Run `utils.block_basic_exam` on `wrapper.ImageBlock`.

    The block is constructed with `normalize=None` and `augment=None` and
    paired with a `tf.keras.Input` of shape (32, 32, 3), dtype float32.
    """
    image_block = wrapper.ImageBlock(normalize=None, augment=None)
    image_input = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)
    # Third helper argument; presumably the hyperparameter names the block
    # is expected to register -- confirm in utils.block_basic_exam.
    expected_names = ['block_type', 'normalize', 'augment']
    utils.block_basic_exam(image_block, image_input, expected_names)
Example #10
0
def test_timeseries_block():
    """Run `utils.block_basic_exam` on `wrapper.TimeseriesBlock`.

    The block is configured with two columns, '0' and '1', both typed
    `adapters.NUMERICAL`, and paired with a float32 `tf.keras.Input` of
    shape (32, 2). The first flattened output must be a `tf.Tensor`.
    """
    timeseries_block = wrapper.TimeseriesBlock()
    columns = ['0', '1']
    timeseries_block.column_names = columns
    # Every column shares the same numerical type.
    timeseries_block.column_types = {
        name: adapters.NUMERICAL for name in columns
    }

    series_input = tf.keras.Input(shape=(32, 2), dtype=tf.float32)
    result = utils.block_basic_exam(timeseries_block, series_input, [])

    first_output = nest.flatten(result)[0]
    assert isinstance(first_output, tf.Tensor)
Example #11
0
def test_structured_data_block():
    """Run `utils.block_basic_exam` on `wrapper.StructuredDataBlock`.

    The block is configured with two columns, '0' and '1', both typed
    `adapters.NUMERICAL`, and paired with a string `tf.keras.Input` of
    shape (2,). The first flattened output must be a `tf.Tensor`.
    """
    structured_block = wrapper.StructuredDataBlock()
    columns = ['0', '1']
    structured_block.column_names = columns
    # Every column shares the same numerical type.
    structured_block.column_types = {
        name: adapters.NUMERICAL for name in columns
    }

    table_input = tf.keras.Input(shape=(2, ), dtype=tf.string)
    result = utils.block_basic_exam(structured_block, table_input, [])

    first_output = nest.flatten(result)[0]
    assert isinstance(first_output, tf.Tensor)
def test_spatial_reduction():
    """Run `utils.block_basic_exam` on `reduction.SpatialReduction`.

    The block is paired with a `tf.keras.Input` of shape (32, 32, 3) and
    dtype float32.
    """
    spatial_block = reduction.SpatialReduction()
    image_input = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)
    # Third helper argument; presumably the hyperparameter names the block
    # is expected to register -- confirm in utils.block_basic_exam.
    expected_names = ['reduction_type']
    utils.block_basic_exam(spatial_block, image_input, expected_names)
def test_temporal_reduction():
    """Run `utils.block_basic_exam` on `reduction.TemporalReduction`.

    The block is paired with a `tf.keras.Input` of shape (32, 10) and
    dtype float32.
    """
    temporal_block = reduction.TemporalReduction()
    sequence_input = tf.keras.Input(shape=(32, 10), dtype=tf.float32)
    # Third helper argument; presumably the hyperparameter names the block
    # is expected to register -- confirm in utils.block_basic_exam.
    expected_names = ['reduction_type']
    utils.block_basic_exam(temporal_block, sequence_input, expected_names)
Example #14
0
def test_text_block():
    """Run `utils.block_basic_exam` on a default `wrapper.TextBlock`.

    The block is paired with a string `tf.keras.Input` of shape (1,).
    """
    text_block = wrapper.TextBlock()
    text_input = tf.keras.Input(shape=(1, ), dtype=tf.string)
    # Third helper argument; presumably the hyperparameter names the block
    # is expected to register -- confirm in utils.block_basic_exam.
    expected_names = ['vectorizer']
    utils.block_basic_exam(text_block, text_input, expected_names)
Example #15
0
def test_text_to_int_sequence():
    """Run `utils.block_basic_exam` on `preprocessing.TextToIntSequence`.

    The block is paired with a string `tf.keras.Input` of shape (1,).
    """
    tokenizer_block = preprocessing.TextToIntSequence()
    text_input = tf.keras.Input(shape=(1, ), dtype=tf.string)
    # Third helper argument; presumably the hyperparameter names the block
    # is expected to register -- confirm in utils.block_basic_exam.
    expected_names = ['output_sequence_length']
    utils.block_basic_exam(tokenizer_block, text_input, expected_names)
Example #16
0
def test_text_to_ngram_vector():
    """Run `utils.block_basic_exam` on `preprocessing.TextToNgramVector`.

    The block is paired with a string `tf.keras.Input` of shape (1,).
    """
    ngram_block = preprocessing.TextToNgramVector()
    text_input = tf.keras.Input(shape=(1, ), dtype=tf.string)
    # Third helper argument; presumably the hyperparameter names the block
    # is expected to register -- confirm in utils.block_basic_exam.
    expected_names = ['ngrams']
    utils.block_basic_exam(ngram_block, text_input, expected_names)
Example #17
0
def test_image_augmentation():
    """Run `utils.block_basic_exam` on `preprocessing.ImageAugmentation`.

    The block is paired with a `tf.keras.Input` of shape (32, 32, 3) and
    dtype float32.
    """
    augmentation_block = preprocessing.ImageAugmentation()
    image_input = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)
    # Third helper argument; presumably the hyperparameter names the block
    # is expected to register -- confirm in utils.block_basic_exam.
    expected_names = ['vertical_flip', 'horizontal_flip']
    utils.block_basic_exam(augmentation_block, image_input, expected_names)
Example #18
0
def test_resnet_block():
    """Run `utils.block_basic_exam` on a default `basic.ResNetBlock`.

    The block is paired with a `tf.keras.Input` of shape (32, 32, 3) and
    dtype float32.
    """
    resnet_block = basic.ResNetBlock()
    image_input = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)
    # Third helper argument; presumably the hyperparameter names the block
    # is expected to register -- confirm in utils.block_basic_exam.
    expected_names = ['version', 'pooling']
    utils.block_basic_exam(resnet_block, image_input, expected_names)