示例#1
0
    def get_config(self):
        """Return a plain-data description of this graph.

        Blocks are identified by their position in ``self.blocks``; nodes
        are identified by the integer ids stored in ``self._node_to_id``.

        Returns:
            A dict with the following entries:

            - ``'override_hps'``: list of serialized hyperparameters, one per
              entry of ``self.override_hps``.
            - ``'blocks'``: list of serialized blocks in ``self.blocks``
              order (the list index is the block id).
            - ``'nodes'``: dict ``{str(node_id): serialized node}`` covering
              only the nodes in ``self.inputs``.
            - ``'outputs'``: list of node ids for ``self.outputs``.
            - ``'block_inputs'``: dict ``{str(block_id): [node_id, ...]}``.
            - ``'block_outputs'``: dict ``{str(block_id): [node_id, ...]}``.

        Dict keys are stringified ids while list entries stay ints;
        presumably this keeps the config JSON-compatible (JSON object keys
        must be strings) -- TODO confirm against the matching from_config.
        """
        node_to_id = self._node_to_id

        def _node_ids(nodes):
            # Translate a sequence of nodes into their ids, keeping order.
            return [node_to_id[node] for node in nodes]

        blocks = [blocks_module.serialize(block) for block in self.blocks]
        # Only the input nodes are serialized; presumably the remaining nodes
        # are recreated when the blocks are rebuilt -- TODO confirm.
        nodes = {
            str(node_to_id[node]): nodes_module.serialize(node)
            for node in self.inputs
        }
        override_hps = [
            kerastuner.engine.hyperparameters.serialize(hp)
            for hp in self.override_hps
        ]
        # A single pass over the blocks fills both adjacency maps.
        block_inputs = {}
        block_outputs = {}
        for block_id, block in enumerate(self.blocks):
            block_inputs[str(block_id)] = _node_ids(block.inputs)
            block_outputs[str(block_id)] = _node_ids(block.outputs)

        outputs = _node_ids(self.outputs)

        return {
            'override_hps': override_hps,  # List of serialized hps.
            'blocks': blocks,  # List of serialized blocks; index = block id.
            'nodes': nodes,  # Dict {str(node_id): serialized node}.
            'outputs': outputs,  # List of node_ids.
            'block_inputs': block_inputs,  # Dict {str(block_id): node_ids}.
            'block_outputs': block_outputs,  # Dict {str(block_id): node_ids}.
        }
示例#2
0
def block_basic_exam(block, inputs, hp_names):
    """Round-trip a block through serialization, build it, and check HPs.

    Args:
        block: The block under test; it is serialized and deserialized
            before use, so the original object is not built.
        inputs: Passed straight to the restored block's ``build``.
        hp_names: Names that ``name_in_hps`` must find among the
            hyperparameters registered while building.

    Returns:
        Whatever the restored block's ``build`` returned.
    """
    hyperparams = kerastuner.HyperParameters()
    restored = blocks.deserialize(blocks.serialize(block))
    built = restored.build(hyperparams, inputs)

    for expected_name in hp_names:
        assert name_in_hps(expected_name, hyperparams)

    return built
示例#3
0
def test_segmentation():
    """Build a SegmentationHead after a serialize/deserialize round trip.

    The head is first configured from its adapter, fitted on a small label
    array. The test makes no explicit assertions; it passes as long as no
    step raises.
    """
    labels = np.array(['a', 'a', 'c', 'b'])
    seg_head = head_module.SegmentationHead(name='a')
    label_adapter = seg_head.get_adapter()
    label_adapter.fit_transform(labels)
    seg_head.config_from_adapter(label_adapter)

    hyperparams = kerastuner.HyperParameters()
    restored = blocks.deserialize(blocks.serialize(seg_head))
    restored.build(hyperparams, ak.Input(shape=(64, 64, 21)).build())
示例#4
0
def block_basic_exam(block, inputs, hp_names):
    """Round-trip, config-check, build, and HP-check a block.

    Args:
        block: The block under test; it is serialized and deserialized
            first, and only the restored copy is checked and built.
        inputs: Passed straight to the restored block's ``build``.
        hp_names: Names that ``name_in_hps`` must find among the
            hyperparameters registered while building.

    Returns:
        Whatever the restored block's ``build`` returned.
    """
    hyperparams = kerastuner.HyperParameters()
    restored = blocks.deserialize(blocks.serialize(block))
    # Keys that utils.config_tests is told to skip; presumably attributes
    # not expected to round-trip through the config -- TODO confirm.
    skipped_keys = [
        'inputs', 'outputs', 'build', '_build',
        'input_tensor', 'input_shape', 'include_top',
        '_num_output_node'
    ]
    utils.config_tests(restored, excluded_keys=skipped_keys)
    built = restored.build(hyperparams, inputs)

    for expected_name in hp_names:
        assert name_in_hps(expected_name, hyperparams)

    return built
示例#5
0
    def get_config(self):
        """Return a plain-data description of this graph.

        Blocks are identified by their position in ``self.blocks``; nodes
        are identified by the integer ids stored in ``self._node_to_id``.

        Returns:
            A dict with the following entries:

            - ``"blocks"``: list of serialized blocks in ``self.blocks``
              order (the list index is the block id).
            - ``"nodes"``: dict ``{str(node_id): serialized node}`` covering
              only the nodes in ``self.inputs``.
            - ``"outputs"``: list of node ids for ``self.outputs``.
            - ``"block_inputs"``: dict ``{str(block_id): [node_id, ...]}``.
            - ``"block_outputs"``: dict ``{str(block_id): [node_id, ...]}``.

        Dict keys are stringified ids while list entries stay ints;
        presumably this keeps the config JSON-compatible (JSON object keys
        must be strings) -- TODO confirm against the matching from_config.
        """
        node_to_id = self._node_to_id

        def _node_ids(nodes):
            # Translate a sequence of nodes into their ids, keeping order.
            return [node_to_id[node] for node in nodes]

        blocks = [blocks_module.serialize(block) for block in self.blocks]
        # Only the input nodes are serialized; presumably the remaining nodes
        # are recreated when the blocks are rebuilt -- TODO confirm.
        nodes = {
            str(node_to_id[node]): nodes_module.serialize(node)
            for node in self.inputs
        }
        # A single pass over the blocks fills both adjacency maps.
        block_inputs = {}
        block_outputs = {}
        for block_id, block in enumerate(self.blocks):
            block_inputs[str(block_id)] = _node_ids(block.inputs)
            block_outputs[str(block_id)] = _node_ids(block.outputs)

        outputs = _node_ids(self.outputs)

        return {
            "blocks": blocks,  # List of serialized blocks; index = block id.
            "nodes": nodes,  # Dict {str(node_id): serialized node}.
            "outputs": outputs,  # List of node_ids.
            "block_inputs": block_inputs,  # Dict {str(block_id): node_ids}.
            "block_outputs": block_outputs,  # Dict {str(block_id): node_ids}.
        }
示例#6
0
def test_resnet_deserialize_to_resnet():
    """A serialized ResNetBlock deserializes back into a ResNetBlock."""
    original = blocks.ResNetBlock()
    restored = blocks.deserialize(blocks.serialize(original))
    assert isinstance(restored, blocks.ResNetBlock)
示例#7
0
def test_dense_deserialize_to_dense():
    """A serialized DenseBlock deserializes back into a DenseBlock."""
    original = blocks.DenseBlock()
    restored = blocks.deserialize(blocks.serialize(original))
    assert isinstance(restored, blocks.DenseBlock)
示例#8
0
def test_embed_deserialize_to_embed():
    """A serialized Embedding block deserializes back into an Embedding."""
    original = blocks.Embedding()
    restored = blocks.deserialize(blocks.serialize(original))
    assert isinstance(restored, blocks.Embedding)
示例#9
0
def test_conv_deserialize_to_conv():
    """A serialized ConvBlock deserializes back into a ConvBlock."""
    original = blocks.ConvBlock()
    restored = blocks.deserialize(blocks.serialize(original))
    assert isinstance(restored, blocks.ConvBlock)
示例#10
0
def test_rnn_deserialize_to_rnn():
    """A serialized RNNBlock deserializes back into an RNNBlock."""
    original = blocks.RNNBlock()
    restored = blocks.deserialize(blocks.serialize(original))
    assert isinstance(restored, blocks.RNNBlock)
示例#11
0
def test_bert_deserialize_to_transformer():
    """A serialized BertBlock deserializes back into a BertBlock."""
    original = blocks.BertBlock()
    restored = blocks.deserialize(blocks.serialize(original))
    assert isinstance(restored, blocks.BertBlock)
def test_augment_deserialize_to_augment():
    """A serialized ImageAugmentation deserializes back into one."""
    original = blocks.ImageAugmentation()
    restored = blocks.deserialize(blocks.serialize(original))
    assert isinstance(restored, blocks.ImageAugmentation)
def test_ngram_deserialize_to_ngram():
    """A serialized TextToNgramVector deserializes back into one."""
    original = blocks.TextToNgramVector()
    restored = blocks.deserialize(blocks.serialize(original))
    assert isinstance(restored, blocks.TextToNgramVector)
def test_int_seq_deserialize_to_int_seq():
    """A serialized TextToIntSequence deserializes back into one."""
    original = blocks.TextToIntSequence()
    restored = blocks.deserialize(blocks.serialize(original))
    assert isinstance(restored, blocks.TextToIntSequence)
示例#15
0
def test_timeseries_deserialize_to_timeseries():
    """A serialized TimeseriesBlock deserializes back into one."""
    original = blocks.TimeseriesBlock()
    restored = blocks.deserialize(blocks.serialize(original))
    assert isinstance(restored, blocks.TimeseriesBlock)
示例#16
0
def test_temporal_deserialize_to_temporal():
    """A serialized TemporalReduction deserializes back into one."""
    original = blocks.TemporalReduction()
    restored = blocks.deserialize(blocks.serialize(original))
    assert isinstance(restored, blocks.TemporalReduction)
示例#17
0
def test_structured_deserialize_to_structured():
    """A serialized StructuredDataBlock deserializes back into one."""
    original = blocks.StructuredDataBlock()
    restored = blocks.deserialize(blocks.serialize(original))
    assert isinstance(restored, blocks.StructuredDataBlock)
示例#18
0
def test_text_deserialize_to_text():
    """A serialized TextBlock deserializes back into a TextBlock."""
    original = blocks.TextBlock()
    restored = blocks.deserialize(blocks.serialize(original))
    assert isinstance(restored, blocks.TextBlock)
示例#19
0
def test_spatial_deserialize_to_spatial():
    """A serialized SpatialReduction deserializes back into one."""
    original = blocks.SpatialReduction()
    restored = blocks.deserialize(blocks.serialize(original))
    assert isinstance(restored, blocks.SpatialReduction)
示例#20
0
def test_xception_deserialize_to_xception():
    """A serialized XceptionBlock deserializes back into an XceptionBlock."""
    original = blocks.XceptionBlock()
    restored = blocks.deserialize(blocks.serialize(original))
    assert isinstance(restored, blocks.XceptionBlock)
def test_cat_to_num_deserialize_to_cat_to_num():
    """A serialized CategoricalToNumerical deserializes back into one."""
    original = blocks.CategoricalToNumerical()
    restored = blocks.deserialize(blocks.serialize(original))
    assert isinstance(restored, blocks.CategoricalToNumerical)
示例#22
0
def test_image_deserialize_to_image():
    """A serialized ImageBlock deserializes back into an ImageBlock."""
    original = blocks.ImageBlock()
    restored = blocks.deserialize(blocks.serialize(original))
    assert isinstance(restored, blocks.ImageBlock)
示例#23
0
def test_merge_deserialize_to_merge():
    """A serialized Merge block deserializes back into a Merge block."""
    original = blocks.Merge()
    restored = blocks.deserialize(blocks.serialize(original))
    assert isinstance(restored, blocks.Merge)