Beispiel #1
0
def test_combine_input_sizes_tuples():
    """Check combine_buffer_structures on structures with integer entries.

    In every multi-input case the structures agree on all entries except
    the last one, and the expected result carries the sum of the last
    entries.
    """
    # A single structure is expected to come back unchanged.
    assert (combine_buffer_structures([BufferStructure(1, 4)]) ==
            BufferStructure(1, 4))

    # Last entries 1 + 3 + 6 -> 10, leading entry stays 4.
    assert (combine_buffer_structures([BufferStructure(4, 1),
                                       BufferStructure(4, 3),
                                       BufferStructure(4, 6)]) ==
            BufferStructure(4, 10))

    # Last entries 2 + 3 + 2 -> 7, leading entries (4, 3) stay as they are.
    assert (combine_buffer_structures([BufferStructure(4, 3, 2),
                                       BufferStructure(4, 3, 3),
                                       BufferStructure(4, 3, 2)]) ==
            BufferStructure(4, 3, 7))
def test_combine_input_sizes_tuple_templates():
    """Check combine_buffer_structures on structures with string entries.

    The structures here start with the string markers 'B' or 'T', 'B'.
    The expected results keep those markers and every other leading entry
    untouched, and carry the sum of the inputs' last entries.
    """
    # Single structures are expected to come back unchanged.
    assert (combine_buffer_structures([BufferStructure('B', 4)]) ==
            BufferStructure('B', 4))
    assert (combine_buffer_structures([BufferStructure('T', 'B', 4)]) ==
            BufferStructure('T', 'B', 4))

    # Two inputs: last entries 4 + 3 -> 7.
    assert (combine_buffer_structures([BufferStructure('B', 4),
                                       BufferStructure('B', 3)]) ==
            BufferStructure('B', 7))
    assert (combine_buffer_structures([BufferStructure('T', 'B', 4),
                                       BufferStructure('T', 'B', 3)]) ==
            BufferStructure('T', 'B', 7))

    # Extra integer entries (3, 2) before the last one are preserved.
    assert (combine_buffer_structures([BufferStructure('T', 'B', 3, 2, 4),
                                       BufferStructure('T', 'B', 3, 2, 3)]) ==
            BufferStructure('T', 'B', 3, 2, 7))
Beispiel #5
0
def instantiate_layers_from_architecture(architecture):
    """Build one layer implementation object per entry of ``architecture``.

    The architecture is validated first. Layers are then created in the
    order given by ``get_canonical_layer_order``; each layer's input shapes
    are looked up on the already instantiated layers its incoming
    connections start from.

    Args:
        architecture: Mapping from layer name to a layer description that
            has at least an ``'@type'`` key.

    Returns:
        OrderedDict mapping layer name to the instantiated layer object,
        in instantiation order.
    """
    validate_architecture(architecture)
    instantiated = OrderedDict()
    all_connections = collect_all_connections(architecture)

    for name in get_canonical_layer_order(architecture):
        description = architecture[name]
        # Implementation classes are looked up as '<type>LayerImpl'.
        impl_class = get_layer_class_from_typename(
            description['@type'] + 'LayerImpl')

        incoming = {con for con in all_connections if con.end_layer == name}
        outgoing = {con for con in all_connections if con.start_layer == name}

        in_shapes = {}
        for input_name in {con.input_name for con in incoming}:
            # Gather the shape of every output that feeds this input ...
            source_shapes = []
            for con in incoming:
                if con.input_name != input_name:
                    continue
                producer = instantiated[con.start_layer]
                source_shapes.append(producer.get_shape(
                    get_normalized_path('outputs', con.output_name)))
            # ... and merge them into the single structure seen by the input.
            in_shapes[input_name] = combine_buffer_structures(source_shapes)

        instantiated[name] = impl_class(name, in_shapes, incoming, outgoing,
                                        **get_kwargs(description))
    return instantiated
Beispiel #6
0
def test_combine_input_sizes_mismatch(sizes):
    """Check that combine_buffer_structures rejects ``sizes`` with ValueError.

    NOTE(review): ``sizes`` has to be supplied by pytest (a parametrize
    decorator or a fixture); neither is visible in this part of the file,
    so confirm that one is attached to this test.
    """
    with pytest.raises(ValueError):
        combine_buffer_structures(sizes)