def test_serialize():
    """Checks StridedWindowData serialization and the deserialize round-trip.

    Builds a StridedWindowData with fixed spatial sizes and randomly drawn
    channel counts, verifies every field of the serialized message, and then
    confirms that deserializing and re-serializing reproduces that message.
    """
    # Draw order matters for reproducibility under a fixed seed: input
    # channels first, then output channels.
    in_channels = np.random.randint(1, 128)
    in_shape = (32, 64, in_channels)
    out_channels = np.random.randint(1, 128)
    window_shape = (4, 2)
    strides = (3, 1)
    pad = (1, 3)

    window_data = StridedWindowData(
        in_shape, window_shape, strides, pad, out_channels)
    serialized = window_data.serialize()

    # (field name on the serialized message, value it must hold).
    expected_fields = (
        ("in_height", 32),
        ("in_width", 64),
        ("in_channels", in_shape[2]),
        ("window_height", 4),
        ("window_width", 2),
        ("out_channels", out_channels),
        ("stride_height", 3),
        ("stride_width", 1),
        ("pad_height", 1),
        ("pad_width", 3),
    )
    for field_name, expected_value in expected_fields:
        assert getattr(serialized, field_name) == expected_value

    # Round trip: deserialize -> serialize must give back the same message.
    round_tripped = StridedWindowData.deserialize(serialized)
    assert round_tripped.serialize() == serialized
def test_serialize():
    """Checks that MaxPoolLayer serializes and deserializes consistently.

    A layer is built from randomly sized StridedWindowData whose strides
    equal its window shape and whose padding is zero. The test verifies the
    serialized message is tagged as "maxpool_data", embeds the serialized
    window data, survives a deserialize/serialize round trip, and that
    deserialization returns None once relu_data has been set on the message.
    """
    height, width, channels = np.random.choice([8, 16, 32, 64, 128], size=3)
    window_height, window_width = np.random.choice([2, 4, 8], size=2)

    input_shape = (height, width, channels)
    # The same (height, width) pair serves as both window shape and strides.
    window_shape = (window_height, window_width)
    no_padding = (0, 0)
    window_data = StridedWindowData(
        input_shape, window_shape, window_shape, no_padding, channels)

    layer = MaxPoolLayer(window_data)
    serialized = layer.serialize()

    # The message must identify itself as a max-pool layer ...
    assert serialized.WhichOneof("layer_data") == "maxpool_data"
    # ... and carry the window data exactly as StridedWindowData serializes it.
    serialized_window_data = serialized.maxpool_data.window_data
    assert serialized_window_data == window_data.serialize()

    # Round trip: deserialize -> serialize must give back the same message.
    deserialized = MaxPoolLayer.deserialize(serialized)
    assert deserialized.serialize() == serialized

    # NOTE(review): presumably relu_data shares the "layer_data" oneof, so
    # setting it makes the message stop describing a max-pool layer --
    # confirm against the proto definition.
    serialized.relu_data.SetInParent()
    assert MaxPoolLayer.deserialize(serialized) is None