コード例 #1
0
ファイル: strided_window_data.py プロジェクト: kiminh/SyReNN
def test_serialize():
    """Check that StridedWindowData survives a serialize/deserialize round trip.

    The spatial input size, window, stride and padding are fixed. The input
    channel count and the output channel count are each drawn with
    ``np.random.randint(1, 128)``, i.e. from [1, 128).
    """
    in_height, in_width = 32, 64
    in_shape = (in_height, in_width, np.random.randint(1, 128))
    out_channels = np.random.randint(1, 128)
    window_shape = (4, 2)
    strides = (3, 1)
    pad = (1, 3)

    window_data = StridedWindowData(
        in_shape, window_shape, strides, pad, out_channels)
    serialized = window_data.serialize()

    # Each serialized field should mirror the (height, width) component of
    # the constructor argument it was built from.
    expected_fields = {
        "in_height": in_shape[0],
        "in_width": in_shape[1],
        "in_channels": in_shape[2],
        "window_height": window_shape[0],
        "window_width": window_shape[1],
        "out_channels": out_channels,
        "stride_height": strides[0],
        "stride_width": strides[1],
        "pad_height": pad[0],
        "pad_width": pad[1],
    }
    for field_name, expected in expected_fields.items():
        assert getattr(serialized, field_name) == expected

    # Deserializing and then re-serializing must reproduce the same message.
    round_trip = StridedWindowData.deserialize(serialized).serialize()
    assert round_trip == serialized
コード例 #2
0
def test_serialize():
    """Check MaxPoolLayer serialization, deserialization and oneof rejection.

    A non-overlapping pooling window is used: the stride equals the window
    size and there is no padding. The serialized layer must carry its window
    data under ``maxpool_data`` and must round-trip through ``deserialize``.
    Once the message's ``layer_data`` oneof no longer selects
    ``maxpool_data``, ``MaxPoolLayer.deserialize`` must return ``None``.
    """
    in_shape = tuple(np.random.choice([8, 16, 32, 64, 128], size=3))
    window = tuple(np.random.choice([2, 4, 8], size=2))
    window_data = StridedWindowData(in_shape, window, window, (0, 0),
                                    in_shape[2])

    layer_proto = MaxPoolLayer(window_data).serialize()
    assert layer_proto.WhichOneof("layer_data") == "maxpool_data"
    assert layer_proto.maxpool_data.window_data == window_data.serialize()

    round_tripped = MaxPoolLayer.deserialize(layer_proto)
    assert round_tripped.serialize() == layer_proto

    # Mark relu_data as present. relu_data is presumably another member of
    # the layer_data oneof, so this switches the message away from
    # maxpool_data, and MaxPoolLayer should refuse to deserialize it.
    layer_proto.relu_data.SetInParent()
    assert MaxPoolLayer.deserialize(layer_proto) is None