示例#1
0
def test_set_hp():
    """Check that an override hyperparameter ends up in the search space.

    A `Choice` named 'dense_block_1/num_layers' with the single value 6 is
    passed through `override_hps`. After the graph is built, the
    hyperparameter with that name in `hp.space` must hold exactly that value.
    """
    target_name = 'dense_block_1/num_layers'

    source = ak.Input((32, ))
    hidden = ak.DenseBlock()(source)
    regression_head = ak.RegressionHead()
    regression_head.output_shape = (1, )
    prediction = regression_head(hidden)

    override = hp_module.Choice(target_name, [6], default=6)
    hyper_graph = graph_module.HyperGraph(source,
                                          prediction,
                                          override_hps=[override])
    hp = kerastuner.HyperParameters()
    hyper_graph.hyper_build(hp).build_keras_graph().build(hp)

    # Only the first hyperparameter carrying the target name is inspected.
    found = next(
        (item for item in hp.space if item.name == target_name), None)
    assert found is not None
    assert len(found.values) == 1
    assert found.values[0] == 6
示例#2
0
def test_graph_save_load(tmp_dir):
    """Check a two-input, two-output graph after save/from_config/reload.

    # Arguments
        tmp_dir: Directory under which the graph is saved.
    """
    first_input = ak.Input()
    second_input = ak.Input()
    dense_out = ak.DenseBlock()(first_input)
    conv_out = ak.ConvBlock()(second_input)
    merged = ak.Merge()([dense_out, conv_out])
    regression_out = ak.RegressionHead()(merged)
    classification_out = ak.ClassificationHead()(merged)

    override = hp_module.Choice('dense_block_1/num_layers', [6], default=6)
    original = graph_module.HyperGraph(
        inputs=[first_input, second_input],
        outputs=[regression_out, classification_out],
        override_hps=[override])

    # Save to disk, rebuild a new graph from the config, then reload into it.
    save_path = os.path.join(tmp_dir, 'graph')
    original.save(save_path)
    restored = graph_module.HyperGraph.from_config(original.get_config())
    restored.reload(save_path)

    assert len(restored.inputs) == 2
    assert len(restored.outputs) == 2
    first_block = restored.inputs[0].out_blocks[0]
    second_block = restored.inputs[1].out_blocks[0]
    assert isinstance(first_block, ak.DenseBlock)
    assert isinstance(second_block, ak.ConvBlock)
    assert isinstance(restored.override_hps[0], hp_module.Choice)
示例#3
0
def test_input_missing():
    """Check that omitting one of two merged inputs raises a ValueError."""
    first_input = ak.Input()
    second_input = ak.Input()
    first_branch = ak.DenseBlock()(first_input)
    second_branch = ak.DenseBlock()(second_input)
    merged = ak.Merge()([first_branch, second_branch])
    head_output = ak.RegressionHead()(merged)

    expected_message = 'A required input is missing for HyperModel'
    # Only the first input is handed to the graph; the second is left out.
    with pytest.raises(ValueError) as info:
        graph_module.HyperGraph(first_input, head_output)
    assert expected_message in str(info.value)
示例#4
0
 def _meta_build(self, dataset):
     """Create `self.hyper_graph` from the model's inputs and outputs.

     Nothing is assigned when the outputs are neither all nodes nor all
     heads.

     # Arguments
         dataset: Forwarded to `meta_model.assemble` when the outputs are
             heads; unused when they are nodes.
     """
     # Functional API: the outputs are already nodes, so wrap them directly.
     if all(isinstance(item, base.Node) for item in self.outputs):
         self.hyper_graph = graph.HyperGraph(inputs=self.inputs,
                                             outputs=self.outputs)
         return

     # Input/output API: the outputs are heads, so the graph is assembled.
     if all(isinstance(item, base.Head) for item in self.outputs):
         self.hyper_graph = meta_model.assemble(inputs=self.inputs,
                                                outputs=self.outputs,
                                                dataset=dataset,
                                                seed=self.seed)
         # Replace the heads with the output nodes of the assembled graph.
         self.outputs = self.hyper_graph.outputs
示例#5
0
def test_hyper_graph_cycle():
    """Check that a graph containing a cycle is rejected with a ValueError."""
    first_input = ak.Input()
    second_input = ak.Input()
    first_branch = ak.DenseBlock()(first_input)
    second_branch = ak.DenseBlock()(second_input)
    merged = ak.Merge()([first_branch, second_branch])
    regression_head = ak.RegressionHead()
    final_output = regression_head(merged)
    # Point the head's outputs back at an upstream node; presumably this is
    # what introduces the cycle the graph is expected to detect.
    regression_head.outputs = first_branch

    expected_message = 'The network has a cycle.'
    with pytest.raises(ValueError) as info:
        graph_module.HyperGraph([first_input, second_input], final_output)
    assert expected_message in str(info.value)
示例#6
0
def test_input_output_disconnect():
    """Check that an input unrelated to the output raises a ValueError."""
    # First chain: an input feeding a dense block whose result is discarded.
    detached_input = ak.Input()
    ak.DenseBlock()(detached_input)

    # Second, separate chain: input -> dense block -> regression head.
    connected_input = ak.Input()
    hidden = ak.DenseBlock()(connected_input)
    head_output = ak.RegressionHead()(hidden)

    expected_message = 'Inputs and outputs not connected.'
    # Pair the input of the first chain with the output of the second one.
    with pytest.raises(ValueError) as info:
        graph_module.HyperGraph(detached_input, head_output)
    assert expected_message in str(info.value)
示例#7
0
def assemble(inputs, outputs, dataset, seed=None):
    """Assemble the HyperBlocks based on the dataset and input output nodes.

    # Arguments
        inputs: A list of InputNode. The input nodes of the AutoModel.
        outputs: A list of HyperHead. The heads of the AutoModel.
        dataset: tf.data.Dataset. The training dataset.
        seed: Int. Random seed.

    # Returns
        A HyperGraph connecting `inputs` to the nodes produced by calling
        each head, with the hyperparameters collected from the assemblers
        passed as `override_hps`.

    # Raises
        ValueError: If an input node is none of TextInput, ImageInput,
            StructuredDataInput or TimeSeriesInput.
    """
    inputs = nest.flatten(inputs)
    outputs = nest.flatten(outputs)

    # Assemblers are paired with inputs (and dataset columns) by position in
    # the zips below, so every input node must contribute an assembler.
    assemblers = []
    for input_node in inputs:
        count_before = len(assemblers)
        if isinstance(input_node, node.TextInput):
            assemblers.append(TextAssembler())
        if isinstance(input_node, node.ImageInput):
            assemblers.append(ImageAssembler(seed=seed))
        if isinstance(input_node, node.StructuredDataInput):
            assemblers.append(StructuredDataAssembler(seed=seed))
        if isinstance(input_node, node.TimeSeriesInput):
            assemblers.append(TimeSeriesAssembler())
        if len(assemblers) == count_before:
            # Without this check the unsupported node would be skipped and
            # every later input would be zipped with the wrong assembler.
            raise ValueError(
                'Unsupported input node type for assemble: {name}'.format(
                    name=type(input_node).__name__))

    # Iterate over the dataset to fit the assemblers.
    hps = []
    for x, _ in dataset:
        for temp_x, assembler in zip(x, assemblers):
            assembler.update(temp_x)
            # NOTE(review): hps are appended once per batch, so the same
            # assembler's hps may be collected repeatedly - confirm whether
            # this is intended before changing it.
            hps += assembler.hps

    # Assemble the model with assemblers.
    middle_nodes = []
    for input_node, assembler in zip(inputs, assemblers):
        middle_nodes.append(assembler.assemble(input_node))

    # Merge the middle nodes.
    if len(middle_nodes) > 1:
        output_node = block_module.Merge()(middle_nodes)
    else:
        output_node = middle_nodes[0]

    # Each head is called on the shared node to produce the graph outputs.
    outputs = nest.flatten([output_blocks(output_node)
                            for output_blocks in outputs])
    hm = graph.HyperGraph(inputs, outputs, override_hps=hps)
    return hm
示例#8
0
 def _meta_build(self, dataset):
     """Wrap the model's inputs and outputs in a `HyperGraph`.

     # Arguments
         dataset: Not used by this implementation.
     """
     graph_kwargs = {'inputs': self.inputs, 'outputs': self.outputs}
     self.hyper_graph = graph.HyperGraph(**graph_kwargs)