コード例 #1
0
 def _build_block(self, hp, output_node, block_type):
     """Build the backbone block selected by ``block_type`` on ``output_node``.

     Args:
         hp: Hyperparameters object forwarded to the chosen block's ``build``.
         output_node: Node the chosen block is built on top of.
         block_type: Expected to equal one of the constants ``RESNET``,
             ``XCEPTION`` or ``VANILLA`` (defined outside this method).

     Returns:
         The node returned by the selected block's ``build``.

     Raises:
         ValueError: If ``block_type`` matches none of the known constants.
     """
     if block_type == RESNET:
         return basic.ResNetBlock().build(hp, output_node)
     if block_type == XCEPTION:
         return basic.XceptionBlock().build(hp, output_node)
     if block_type == VANILLA:
         return basic.ConvBlock().build(hp, output_node)
     # Falling through would implicitly return None, which only fails later
     # (when the node is consumed) and hides the real cause: a bad block_type.
     raise ValueError(
         'Unknown block_type {!r}; expected one of {!r}, {!r} or {!r}.'.format(
             block_type, RESNET, XCEPTION, VANILLA))
コード例 #2
0
ファイル: wrapper.py プロジェクト: yifan2/autokeras
    def build(self, hp, inputs=None):
        """Stack optional preprocessing and an image backbone on ``inputs``.

        Any of ``block_type``, ``normalize`` or ``augment`` that is not
        configured on the instance is taken from ``hp`` instead.

        Args:
            hp: Hyperparameters object used for the choices above and
                forwarded to every sub-block's ``build``.
            inputs: (Possibly nested) inputs; only the first element of the
                flattened structure is used.

        Returns:
            The resulting output node. When ``block_type`` is none of
            ``'resnet'``, ``'xception'`` or ``'vanilla'``, no backbone is
            added and the (possibly preprocessed) node is returned as is.
        """
        node = nest.flatten(inputs)[0]

        # Any falsy configured value defers the choice to the tuner.
        block_type = self.block_type
        if not block_type:
            block_type = hp.Choice('block_type',
                                   ['resnet', 'xception', 'vanilla'],
                                   default='vanilla')

        def resolve_flag(configured, hp_name):
            # ``None`` means "not configured": ask the tuner for a boolean.
            if configured is not None:
                return configured
            return hp.Boolean(hp_name, default=False)

        # Resolve normalize before augment so hp is queried in this order.
        normalize = resolve_flag(self.normalize, 'normalize')
        augment = resolve_flag(self.augment, 'augment')

        if normalize:
            node = preprocessing.Normalization().build(hp, node)
        if augment:
            node = preprocessing.ImageAugmentation().build(hp, node)

        backbones = (
            ('resnet', basic.ResNetBlock),
            ('xception', basic.XceptionBlock),
            ('vanilla', basic.ConvBlock),
        )
        for backbone_name, backbone_cls in backbones:
            if block_type == backbone_name:
                return backbone_cls().build(hp, node)
        return node
コード例 #3
0
def test_resnet_block(init, build):
    """Run the shared block check on ResNetBlock and verify both hooks fired.

    ``init`` and ``build`` are presumably mocks injected by patch decorators
    that are not visible here -- TODO confirm.
    NOTE(review): another function named ``test_resnet_block`` appears later
    in this listing; if both live in one module, the later one shadows this.
    """
    block = basic.ResNetBlock()
    image_input = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)
    utils.block_basic_exam(block, image_input, ['version', 'pooling'])
    assert init.called and build.called
コード例 #4
0
def test_resnet_invalid_kwargs2():
    """Passing ``input_shape`` to ResNetBlock must raise ValueError."""
    # The pattern contains no regex metacharacters, so this is a plain
    # substring search on the exception message.
    with pytest.raises(ValueError, match='Argument "input_shape" is not'):
        basic.ResNetBlock(input_shape=(10,))
コード例 #5
0
def test_resnet_invalid_kwargs():
    """Passing ``include_top`` to ResNetBlock must raise ValueError."""
    # The pattern contains no regex metacharacters, so this is a plain
    # substring search on the exception message.
    with pytest.raises(ValueError, match='Argument "include_top" is not'):
        basic.ResNetBlock(include_top=True)
コード例 #6
0
def test_resnet_block():
    """Run the shared block check on ResNetBlock with a 32x32x3 float input.

    ``['version', 'pooling']`` is presumably the list of hyperparameter
    names the block is expected to register -- verify against
    ``utils.block_basic_exam``.
    """
    block = basic.ResNetBlock()
    image_input = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)
    hp_names = ['version', 'pooling']
    utils.block_basic_exam(block, image_input, hp_names)