def _assemble(self):
    """Assemble the Blocks based on the input output nodes.

    Every flattened input node is checked against each supported input type
    and, for every type it is an instance of, the matching default block is
    applied to it. If more than one middle node results they are combined
    with ``blocks.Merge``; every output head is then applied to that node.

    Returns:
        A ``graph_module.Graph`` wired from the flattened inputs to the
        flattened head outputs.

    Raises:
        IndexError: if no input node matches any supported input type, since
            there is then no middle node to take.
    """
    flat_inputs = nest.flatten(self.inputs)
    output_heads = nest.flatten(self.outputs)

    # (input type, default block class) pairs. They are tested one after
    # another rather than as an if/elif chain, so a node that is an instance
    # of several input types receives one block per match, in this order.
    block_for_input = (
        (input_module.TextInput, blocks.TextBlock),
        (input_module.ImageInput, blocks.ImageBlock),
        (input_module.StructuredDataInput, blocks.StructuredDataBlock),
        (input_module.TimeseriesInput, blocks.TimeseriesBlock),
    )

    middle_nodes = []
    for node in flat_inputs:
        for input_type, block_cls in block_for_input:
            if isinstance(node, input_type):
                middle_nodes.append(block_cls()(node))

    # A single branch needs no Merge; several branches are merged into one.
    if len(middle_nodes) <= 1:
        merged = middle_nodes[0]
    else:
        merged = blocks.Merge()(middle_nodes)

    head_outputs = nest.flatten([head(merged) for head in output_heads])
    return graph_module.Graph(inputs=flat_inputs, outputs=head_outputs)
def test_merge_get_config_has_all_attributes():
    """Every argument of ``Merge.__init__`` must appear in ``get_config()``."""
    merge_config = blocks.Merge().get_config()
    init_arg_names = test_utils.get_func_args(blocks.Merge.__init__)
    assert init_arg_names.issubset(merge_config.keys())
def test_merge_single_input_return_tensor():
    """Building ``Merge`` on a lone (non-list) input yields one output."""
    merge = blocks.Merge()
    hp = keras_tuner.HyperParameters()
    lone_input = keras.Input(shape=(32,), dtype=tf.float32)

    result = merge.build(hp, lone_input)

    assert len(nest.flatten(result)) == 1
def test_merge_build_return_tensor():
    """Merging inputs of different shapes produces exactly one ``tf.Tensor``."""
    merge = blocks.Merge()
    hp = kerastuner.HyperParameters()
    input_tensors = [
        tf.keras.Input(shape=(32,), dtype=tf.float32),
        tf.keras.Input(shape=(4, 8), dtype=tf.float32),
    ]

    flat_outputs = nest.flatten(merge.build(hp, input_tensors))

    assert len(flat_outputs) == 1
    assert isinstance(flat_outputs[0], tf.Tensor)
def test_merge_inputs_with_same_shape_return_tensor():
    """Merging two identically shaped inputs produces a single output."""
    merge = blocks.Merge()
    hp = kerastuner.HyperParameters()
    same_shape_inputs = [
        tf.keras.Input(shape=(32,), dtype=tf.float32),
        tf.keras.Input(shape=(32,), dtype=tf.float32),
    ]

    result = merge.build(hp, same_shape_inputs)

    assert len(nest.flatten(result)) == 1
def _assemble(self):
    """Assemble the Blocks based on the input output nodes.

    Each flattened input node provides its own default block through
    ``get_block()``, which is applied to that node. When this yields more
    than one middle node they are combined with ``blocks.Merge``; every
    output head is then applied to the resulting node.

    Returns:
        A ``graph_module.Graph`` wired from the flattened inputs to the
        flattened head outputs.

    Raises:
        IndexError: if ``self.inputs`` flattens to nothing, since there is
            then no middle node to take.
    """
    flat_inputs = nest.flatten(self.inputs)
    output_heads = nest.flatten(self.outputs)

    middle_nodes = []
    for node in flat_inputs:
        default_block = node.get_block()
        middle_nodes.append(default_block(node))

    # A single branch needs no Merge; several branches are merged into one.
    merged = (
        blocks.Merge()(middle_nodes) if len(middle_nodes) > 1 else middle_nodes[0]
    )

    head_outputs = nest.flatten([head(merged) for head in output_heads])
    return graph_module.Graph(inputs=flat_inputs, outputs=head_outputs)
def test_merge_deserialize_to_merge():
    """A serialize/deserialize round trip of ``Merge`` yields a ``Merge``."""
    restored = blocks.deserialize(blocks.serialize(blocks.Merge()))
    assert isinstance(restored, blocks.Merge)