Пример #1
0
    def build(self, hp, inputs=None):
        """Build the image pipeline: optional preprocessing, then one block.

        Any of ``block_type``, ``normalize`` and ``augment`` that is not set
        on the instance is resolved through ``hp.Choice``.

        Args:
            hp: Hyperparameter container; only its ``Choice`` method is used.
            inputs: A node or nested structure of nodes. Only the first
                entry of the flattened structure is consumed.

        Returns:
            The node produced after the optional ``Normalization`` and
            ``ImageAugmentation`` preprocessors and the selected block. If
            the resolved block type is none of ``'resnet'``, ``'xception'``
            or ``'vanilla'``, the node is returned without any block applied.
        """
        node = nest.flatten(inputs)[0]

        # A falsy configured block type (e.g. None) defers to the tuner.
        block_type = self.block_type
        if not block_type:
            block_type = hp.Choice(
                'block_type', ['resnet', 'xception', 'vanilla'],
                default='resnet')

        # Only an explicit None defers these flags to the tuner; a
        # configured False is respected as-is.
        use_normalization = self.normalize
        if use_normalization is None:
            use_normalization = hp.Choice(
                'normalize', [True, False], default=True)
        use_augmentation = self.augment
        if use_augmentation is None:
            use_augmentation = hp.Choice(
                'augment', [True, False], default=True)

        if use_normalization:
            node = preprocessor.Normalization()(node)
        if use_augmentation:
            augmenter = preprocessor.ImageAugmentation(seed=self.seed)
            node = augmenter(node)

        sub_block_name = self.name + '_' + block_type
        if block_type == 'resnet':
            return block.ResNetBlock(name=sub_block_name)(node)
        if block_type == 'xception':
            return block.XceptionBlock(name=sub_block_name)(node)
        if block_type == 'vanilla':
            return block.ConvBlock(name=sub_block_name)(node)
        # Unrecognised block type: hand back the node without a block.
        return node
Пример #2
0
def test_augment():
    """Check that ImageAugmentation run via run_preprocessor yields a Dataset."""
    source_data = common.generate_data(dtype='dataset')
    augmenter = preprocessor.ImageAugmentation(seed=common.SEED)
    # A second, independently generated dataset is passed as the third argument.
    extra_data = common.generate_data(dtype='dataset')

    result = run_preprocessor(augmenter, source_data, extra_data, tf.float32)

    assert isinstance(result, tf.data.Dataset)
Пример #3
0
def test_augment():
    """Exercise ImageAugmentation.transform directly and through Dataset.map."""
    images = tf.random.normal([1000, 32, 32, 3], mean=-1, stddev=4)
    augmenter = preprocessor.ImageAugmentation(seed=5)
    image_ds = tf.data.Dataset.from_tensor_slices(images)

    augmenter.set_hp(kerastuner.HyperParameters())
    # Round-trip the configuration through get_config/set_config.
    augmenter.set_config(augmenter.get_config())

    # First pass: call transform on every element directly.
    for image in image_ds:
        augmenter.transform(image, True)

    # Second pass: the same transform wrapped in py_function inside a map.
    transform_with_fit = functools.partial(augmenter.transform, fit=True)

    def augment_element(x):
        return tf.py_function(transform_with_fit,
                              inp=[x],
                              Tout=(tf.float32, ))

    mapped_ds = image_ds.map(augment_element)
    # Drain the mapped dataset so the py_function runs on each element.
    for _ in mapped_ds:
        pass
    assert isinstance(mapped_ds, tf.data.Dataset)