Esempio n. 1
0
    def _assemble(self):
        """Assemble the Blocks based on the input output nodes.

        Each input node is routed through the Block associated with its input
        type in the table below.  When there are several inputs, their
        outputs are combined with a ``Merge`` block.  Every output head in
        ``self.outputs`` is then applied to that single (possibly merged)
        node.

        Returns:
            A ``graph_module.Graph`` built from the flattened inputs and the
            flattened outputs of the output heads.

        Raises:
            ValueError: If no input node matches a supported input type, so
                there is nothing to connect the output heads to.
        """
        inputs = nest.flatten(self.inputs)
        output_heads = nest.flatten(self.outputs)

        # The first matching entry wins, so each input node feeds exactly one
        # block.  Subclasses must come before their base classes: if
        # TimeseriesInput derives from StructuredDataInput (confirm in the
        # input module), independent checks would wire a timeseries input
        # into both a StructuredDataBlock and a TimeseriesBlock.
        input_to_block = (
            (input_module.TextInput, blocks.TextBlock),
            (input_module.ImageInput, blocks.ImageBlock),
            (input_module.TimeseriesInput, blocks.TimeseriesBlock),
            (input_module.StructuredDataInput, blocks.StructuredDataBlock),
        )

        middle_nodes = []
        for input_node in inputs:
            for input_type, block_class in input_to_block:
                if isinstance(input_node, input_type):
                    middle_nodes.append(block_class()(input_node))
                    break
            # NOTE(review): input nodes of any other type are skipped here,
            # exactly as before; confirm that is intended.

        if not middle_nodes:
            raise ValueError(
                'Expected at least one input node of a supported type '
                '(TextInput, ImageInput, StructuredDataInput or '
                'TimeseriesInput), but got: {inputs}'.format(inputs=inputs))

        # Merge the middle nodes so every output head is applied to the same
        # single node.
        if len(middle_nodes) > 1:
            output_node = blocks.Merge()(middle_nodes)
        else:
            output_node = middle_nodes[0]

        outputs = nest.flatten(
            [output_head(output_node) for output_head in output_heads])
        return graph_module.Graph(inputs=inputs, outputs=outputs)
Esempio n. 2
0
def test_image_get_config_has_all_attributes():
    """ImageBlock.get_config() must expose every ImageBlock.__init__ argument."""
    config_keys = blocks.ImageBlock().get_config().keys()
    init_args = test_utils.get_func_args(blocks.ImageBlock.__init__)

    assert init_args.issubset(config_keys)
Esempio n. 3
0
def test_image_build_return_tensor():
    """Building a default ImageBlock on a 32x32x3 float input yields one tensor."""
    image_block = blocks.ImageBlock()
    hyperparameters = kerastuner.HyperParameters()
    image_input = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)

    flat_outputs = nest.flatten(image_block.build(hyperparameters, image_input))

    assert len(flat_outputs) == 1
    assert isinstance(flat_outputs[0], tf.Tensor)
Esempio n. 4
0
def test_image_block_augment_return_tensor():
    """An ImageBlock built with augment=True produces exactly one output."""
    augmented_block = blocks.ImageBlock(augment=True)
    hp = keras_tuner.HyperParameters()
    image_input = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)

    result = augmented_block.build(hp, image_input)

    assert len(nest.flatten(result)) == 1
Esempio n. 5
0
def test_image_block_xception_return_tensor():
    """An ImageBlock configured with block_type="xception" builds one output."""
    xception_block = blocks.ImageBlock(block_type="xception")
    hp = kerastuner.HyperParameters()
    image_input = tf.keras.Input(shape=(32, 32, 3), dtype=tf.float32)

    result = xception_block.build(hp, image_input)

    assert len(nest.flatten(result)) == 1
Esempio n. 6
0
def test_image_deserialize_to_image():
    """Round-tripping an ImageBlock through serialize/deserialize keeps its type."""
    original = blocks.ImageBlock()
    restored = blocks.deserialize(blocks.serialize(original))

    assert isinstance(restored, blocks.ImageBlock)
Esempio n. 7
0
 def get_block(self):
     """Return a newly constructed ImageBlock built with no arguments."""
     image_block = blocks.ImageBlock()
     return image_block