def test_embedding_block():
    """Run ``utils.block_basic_exam`` on a default ``basic.Embedding`` block.

    The block is fed a float32 Keras input of shape (32,) and examined against
    the names 'pretraining' and 'embedding_dim'.
    """
    embedding = basic.Embedding()
    input_node = tf.keras.Input(shape=(32,), dtype=tf.float32)
    # Names handed to the exam helper (presumably hyperparameter names).
    names = ['pretraining', 'embedding_dim']
    utils.block_basic_exam(embedding, input_node, names)
def test_resnet_block(init, build):
    """Exam a default ``basic.ResNetBlock`` on 32x32x3 float32 inputs.

    After the exam, both ``init`` and ``build`` must report ``.called``.
    NOTE(review): ``init`` and ``build`` are presumably mock objects injected by
    patch decorators or fixtures that are not visible here; confirm against the
    full file.
    """
    resnet = basic.ResNetBlock()
    image_input = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)
    names = ['version', 'pooling']
    utils.block_basic_exam(resnet, image_input, names)

    assert init.called
    assert build.called
def test_dense_block():
    """Exam a default ``basic.DenseBlock`` on float32 inputs of shape (32,)."""
    dense = basic.DenseBlock()
    flat_input = tf.keras.Input(shape=(32,), dtype=tf.float32)
    names = ['num_layers', 'use_batchnorm']
    utils.block_basic_exam(dense, flat_input, names)
def test_merge():
    """Exam ``reduction.Merge`` on two float32 inputs of different ranks.

    The inputs have shapes (32,) and (4, 8); they are passed as a list.
    """
    merge = reduction.Merge()
    first = tf.keras.Input(shape=(32, ), dtype=tf.float32)
    second = tf.keras.Input(shape=(4, 8), dtype=tf.float32)
    utils.block_basic_exam(merge, [first, second], ['merge_type'])
def test_rnn_block():
    """Exam a default ``basic.RNNBlock`` on float32 inputs of shape (32, 10)."""
    rnn = basic.RNNBlock()
    sequence_input = tf.keras.Input(shape=(32, 10), dtype=tf.float32)
    names = ['bidirectional', 'layer_type', 'num_layers']
    utils.block_basic_exam(rnn, sequence_input, names)
def test_conv_block():
    """Exam a default ``basic.ConvBlock`` on 32x32x3 float32 inputs."""
    conv = basic.ConvBlock()
    image_input = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)
    names = ['kernel_size', 'num_blocks', 'separable']
    utils.block_basic_exam(conv, image_input, names)
def test_image_block():
    """Exam ``wrapper.ImageBlock`` on 32x32x3 float32 inputs.

    The block is constructed with ``normalize=None`` and ``augment=None``.
    """
    image_block = wrapper.ImageBlock(normalize=None, augment=None)
    image_input = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)
    names = ['block_type', 'normalize', 'augment']
    utils.block_basic_exam(image_block, image_input, names)
def test_xception_block(init, build):
    """Exam a default ``basic.XceptionBlock`` on 32x32x3 float32 inputs.

    After the exam, both ``init`` and ``build`` must report ``.called``.
    NOTE(review): ``init`` and ``build`` are presumably mock objects injected by
    patch decorators or fixtures that are not visible here; confirm against the
    full file.
    """
    xception = basic.XceptionBlock()
    image_input = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)
    names = [
        'activation',
        'initial_strides',
        'num_residual_blocks',
        'pooling',
    ]
    utils.block_basic_exam(xception, image_input, names)

    assert init.called
    assert build.called
def test_timeseries_block():
    """Exam ``wrapper.TimeseriesBlock`` configured with two numerical columns.

    The input is float32 with shape (32, 2) and no names are examined. The
    first flattened output of the exam must be a ``tf.Tensor``.
    """
    block = wrapper.TimeseriesBlock()
    columns = ['0', '1']
    block.column_names = columns
    # Every configured column is declared numerical.
    block.column_types = {name: adapters.NUMERICAL for name in columns}

    series_input = tf.keras.Input(shape=(32, 2), dtype=tf.float32)
    outputs = utils.block_basic_exam(block, series_input, [])

    first_output = nest.flatten(outputs)[0]
    assert isinstance(first_output, tf.Tensor)
def test_structured_data_block():
    """Exam ``wrapper.StructuredDataBlock`` configured with two numerical columns.

    The input is a string-typed Keras input of shape (2,) and no names are
    examined. The first flattened output of the exam must be a ``tf.Tensor``.
    """
    block = wrapper.StructuredDataBlock()
    columns = ['0', '1']
    block.column_names = columns
    # Every configured column is declared numerical, even though the input
    # tensor is string-typed.
    block.column_types = {name: adapters.NUMERICAL for name in columns}

    raw_input = tf.keras.Input(shape=(2, ), dtype=tf.string)
    outputs = utils.block_basic_exam(block, raw_input, [])

    first_output = nest.flatten(outputs)[0]
    assert isinstance(first_output, tf.Tensor)
def test_spatial_reduction():
    """Exam a default ``reduction.SpatialReduction`` on 32x32x3 float32 inputs."""
    spatial = reduction.SpatialReduction()
    image_input = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)
    utils.block_basic_exam(spatial, image_input, ['reduction_type'])
def test_temporal_reduction():
    """Exam a default ``reduction.TemporalReduction`` on float32 inputs of shape (32, 10)."""
    temporal = reduction.TemporalReduction()
    sequence_input = tf.keras.Input(shape=(32, 10), dtype=tf.float32)
    utils.block_basic_exam(temporal, sequence_input, ['reduction_type'])
def test_text_block():
    """Exam a default ``wrapper.TextBlock`` on a string input of shape (1,)."""
    text_block = wrapper.TextBlock()
    text_input = tf.keras.Input(shape=(1, ), dtype=tf.string)
    utils.block_basic_exam(text_block, text_input, ['vectorizer'])
def test_image_augmentation():
    """Exam a default ``preprocessing.ImageAugmentation`` on 32x32x3 float32 inputs."""
    augmentation = preprocessing.ImageAugmentation()
    image_input = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)
    names = ['vertical_flip', 'horizontal_flip']
    utils.block_basic_exam(augmentation, image_input, names)
def test_text_to_ngram_vector():
    """Exam a default ``preprocessing.TextToNgramVector`` on a string input of shape (1,)."""
    vectorizer = preprocessing.TextToNgramVector()
    text_input = tf.keras.Input(shape=(1, ), dtype=tf.string)
    utils.block_basic_exam(vectorizer, text_input, ['ngrams'])