コード例 #1
0
ファイル: strided_window_data.py プロジェクト: kiminh/SyReNN
def test_serialize():
    """Check that StridedWindowData survives a Protobuf round trip.

    A StridedWindowData is built with a fixed 32x64 input, a 4x2 window,
    (3, 1) strides and (1, 3) padding; the input and output channel counts
    are drawn from ``np.random.randint(1, 128)``. Every field of the
    serialized message is compared against the construction arguments, and
    deserializing then re-serializing must reproduce an equal message.
    """
    # Draw the input channel count first, then the output channel count, so
    # the random stream is consumed in the same order as before.
    in_channels = np.random.randint(1, 128)
    out_channels = np.random.randint(1, 128)
    in_shape = (32, 64, in_channels)

    window_data = StridedWindowData(in_shape, (4, 2), (3, 1), (1, 3),
                                    out_channels)
    serialized = window_data.serialize()

    # Expected value of each message field, keyed by field name.
    expected_fields = {
        "in_height": 32,
        "in_width": 64,
        "in_channels": in_channels,
        "window_height": 4,
        "window_width": 2,
        "out_channels": out_channels,
        "stride_height": 3,
        "stride_width": 1,
        "pad_height": 1,
        "pad_width": 3,
    }
    for field_name, expected in expected_fields.items():
        assert getattr(serialized, field_name) == expected

    # deserialize followed by serialize must give back an equal message.
    round_tripped = StridedWindowData.deserialize(serialized).serialize()
    assert round_tripped == serialized
コード例 #2
0
 def deserialize(cls, serialized):
     """Rebuild a layer from a Protobuf message carrying ``maxpool_data``.

     Args:
         serialized: Layer message whose ``layer_data`` oneof is inspected.

     Returns:
         ``cls(window_data)``, where ``window_data`` is decoded from
         ``serialized.maxpool_data.window_data`` by
         ``StridedWindowData.deserialize``. Returns None when the
         ``layer_data`` oneof holds anything other than ``maxpool_data``.
     """
     if serialized.WhichOneof("layer_data") != "maxpool_data":
         return None
     pool_window = StridedWindowData.deserialize(
         serialized.maxpool_data.window_data)
     return cls(pool_window)
コード例 #3
0
ファイル: conv2d_layer.py プロジェクト: 95616ARG/SyReNN
 def deserialize(cls, serialized):
     """Rebuild a layer from a Protobuf message carrying ``conv2d_data``.

     The window geometry is decoded from ``conv2d_data.window_data``. The
     ``filters`` field is converted to a NumPy array and reshaped to
     ``window_shape + (input_shape[2], out_channels)``. The ``biases`` field
     is converted to a NumPy array as-is.

     Args:
         serialized: Layer message whose ``layer_data`` oneof is inspected.

     Returns:
         ``cls(window_data, filters, biases)``, or None when the
         ``layer_data`` oneof holds anything other than ``conv2d_data``.
     """
     if serialized.WhichOneof("layer_data") != "conv2d_data":
         return None
     conv_data = serialized.conv2d_data
     window_data = StridedWindowData.deserialize(conv_data.window_data)
     # Restore the filter layout: the window dimensions, then input channels
     # (third entry of input_shape), then output channels. Tuple
     # concatenation is kept deliberately, so window_shape must be a tuple.
     in_channels = window_data.input_shape[2]
     kernel_shape = window_data.window_shape + (in_channels,
                                                window_data.out_channels,)
     kernel = np.array(conv_data.filters).reshape(kernel_shape)
     bias = np.array(conv_data.biases)
     return cls(window_data, kernel, bias)